diff --git a/.coveragerc b/.coveragerc index 0b5d5bf0ad4..7792266b114 100644 --- a/.coveragerc +++ b/.coveragerc @@ -6,3 +6,6 @@ omit = site-packages [report] exclude_also = if TYPE_CHECKING + assert False + : \.\.\.(\s*#.*)?$ + ^ +\.\.\.$ diff --git a/.github/dependabot.yml b/.github/dependabot.yml index 3b392a34b3b..d1898c69e6e 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -18,17 +18,17 @@ updates: interval: "daily" open-pull-requests-limit: 10 - # Maintain dependencies for GitHub Actions aiohttp 3.9 + # Maintain dependencies for GitHub Actions aiohttp backport - package-ecosystem: "github-actions" directory: "/" labels: - dependencies - target-branch: "3.9" + target-branch: "3.10" schedule: interval: "daily" open-pull-requests-limit: 10 - # Maintain dependencies for Python aiohttp 3.10 + # Maintain dependencies for Python aiohttp backport - package-ecosystem: "pip" directory: "/" labels: diff --git a/.github/workflows/ci-cd.yml b/.github/workflows/ci-cd.yml index 0b9c1dbcb96..93d4575da2d 100644 --- a/.github/workflows/ci-cd.yml +++ b/.github/workflows/ci-cd.yml @@ -45,7 +45,7 @@ jobs: - name: Setup Python uses: actions/setup-python@v5 with: - python-version: 3.9 + python-version: 3.11 - name: Cache PyPI uses: actions/cache@v4.0.2 with: @@ -162,7 +162,8 @@ jobs: - name: Get pip cache dir id: pip-cache run: | - echo "::set-output name=dir::$(pip cache dir)" # - name: Cache + echo "dir=$(pip cache dir)" >> "${GITHUB_OUTPUT}" + shell: bash - name: Cache PyPI uses: actions/cache@v4.0.2 with: @@ -221,7 +222,7 @@ jobs: run: | python -m coverage xml - name: Upload coverage - uses: codecov/codecov-action@v3 + uses: codecov/codecov-action@v4 with: file: ./coverage.xml flags: >- @@ -350,7 +351,7 @@ jobs: run: | make cythonize - name: Build wheels - uses: pypa/cibuildwheel@v2.17.0 + uses: pypa/cibuildwheel@v2.19.2 env: CIBW_ARCHS_MACOS: x86_64 arm64 universal2 - uses: actions/upload-artifact@v3 @@ -405,7 +406,7 @@ jobs: uses: 
pypa/gh-action-pypi-publish@release/v1 - name: Sign the dists with Sigstore - uses: sigstore/gh-action-sigstore-python@v2.1.1 + uses: sigstore/gh-action-sigstore-python@v3.0.0 with: inputs: >- ./dist/*.tar.gz @@ -415,7 +416,7 @@ jobs: # Confusingly, this action also supports updating releases, not # just creating them. This is what we want here, since we've manually # created the release above. - uses: softprops/action-gh-release@v1 + uses: softprops/action-gh-release@v2 with: # dist/ contains the built packages, which smoketest-artifacts/ # contains the signatures and certificates. diff --git a/.github/workflows/labels.yml b/.github/workflows/labels.yml index a4e961e88af..8d9c0f6f4a2 100644 --- a/.github/workflows/labels.yml +++ b/.github/workflows/labels.yml @@ -9,6 +9,7 @@ jobs: backport: runs-on: ubuntu-latest name: Backport label added + if: ${{ github.event.pull_request.user.type != 'Bot' }} steps: - uses: actions/github-script@v7 with: diff --git a/.mypy.ini b/.mypy.ini index 86b5c86f345..78001c36e8f 100644 --- a/.mypy.ini +++ b/.mypy.ini @@ -12,6 +12,8 @@ disallow_untyped_defs = True extra_checks = True implicit_reexport = False no_implicit_optional = True +pretty = True +show_column_numbers = True show_error_codes = True show_error_code_links = True strict_equality = True @@ -26,12 +28,6 @@ warn_return_any = True disallow_untyped_calls = False disallow_untyped_defs = False -[mypy-aiodns] -ignore_missing_imports = True - -[mypy-asynctest] -ignore_missing_imports = True - [mypy-brotli] ignore_missing_imports = True @@ -40,6 +36,3 @@ ignore_missing_imports = True [mypy-gunicorn.*] ignore_missing_imports = True - -[mypy-python_on_whales] -ignore_missing_imports = True diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d11ab1bfa32..dc3e65cf52f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -48,24 +48,24 @@ repos: entry: ./tools/check_changes.py pass_filenames: false - repo: https://github.com/pre-commit/pre-commit-hooks 
- rev: 'v4.0.1' + rev: 'v4.6.0' hooks: - id: check-merge-conflict - repo: https://github.com/asottile/yesqa - rev: v1.3.0 + rev: v1.5.0 hooks: - id: yesqa - repo: https://github.com/PyCQA/isort - rev: '5.11.5' + rev: '5.13.2' hooks: - id: isort - repo: https://github.com/psf/black - rev: '22.3.0' + rev: '24.4.0' hooks: - id: black language_version: python3 # Should be a command that runs python - repo: https://github.com/pre-commit/pre-commit-hooks - rev: 'v4.0.1' + rev: 'v4.6.0' hooks: - id: end-of-file-fixer exclude: >- @@ -97,12 +97,12 @@ repos: - id: detect-private-key exclude: ^examples/ - repo: https://github.com/asottile/pyupgrade - rev: 'v2.29.0' + rev: 'v3.15.2' hooks: - id: pyupgrade args: ['--py37-plus'] - repo: https://github.com/PyCQA/flake8 - rev: '4.0.1' + rev: '7.0.0' hooks: - id: flake8 additional_dependencies: diff --git a/CHANGES.rst b/CHANGES.rst index 5b02623067a..0150c95494c 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -10,6 +10,466 @@ .. towncrier release notes start +3.10.2 (2024-08-08) +=================== + +Bug fixes +--------- + +- Fixed server checks for circular symbolic links to be compatible with Python 3.13 -- by :user:`steverep`. + + + *Related issues and pull requests on GitHub:* + :issue:`8565`. + + + +- Fixed request body not being read when ignoring an Upgrade request -- by :user:`Dreamsorcerer`. + + + *Related issues and pull requests on GitHub:* + :issue:`8597`. + + + +- Fixed an edge case where shutdown would wait for timeout when the handler was already completed -- by :user:`Dreamsorcerer`. + + + *Related issues and pull requests on GitHub:* + :issue:`8611`. + + + +- Fixed connecting to ``npipe://``, ``tcp://``, and ``unix://`` urls -- by :user:`bdraco`. + + + *Related issues and pull requests on GitHub:* + :issue:`8632`. + + + +- Fixed WebSocket ping tasks being prematurely garbage collected -- by :user:`bdraco`. 
+ + There was a small risk that WebSocket ping tasks would be prematurely garbage collected because the event loop only holds a weak reference to the task. The garbage collection risk has been fixed by holding a strong reference to the task. Additionally, the task is now scheduled eagerly with Python 3.12+ to increase the chance it can be completed immediately and avoid having to hold any references to the task. + + + *Related issues and pull requests on GitHub:* + :issue:`8641`. + + + +- Fixed incorrectly following symlinks for compressed file variants -- by :user:`steverep`. + + + *Related issues and pull requests on GitHub:* + :issue:`8652`. + + + + +Removals and backward incompatible breaking changes +--------------------------------------------------- + +- Removed ``Request.wait_for_disconnection()``, which was mistakenly added briefly in 3.10.0 -- by :user:`Dreamsorcerer`. + + + *Related issues and pull requests on GitHub:* + :issue:`8636`. + + + + +Contributor-facing changes +-------------------------- + +- Fixed monkey patches for ``Path.stat()`` and ``Path.is_dir()`` for Python 3.13 compatibility -- by :user:`steverep`. + + + *Related issues and pull requests on GitHub:* + :issue:`8551`. + + + + +Miscellaneous internal changes +------------------------------ + +- Improved WebSocket performance when messages are sent or received frequently -- by :user:`bdraco`. + + The WebSocket heartbeat scheduling algorithm was improved to reduce the ``asyncio`` scheduling overhead by decreasing the number of ``asyncio.TimerHandle`` creations and cancellations. + + + *Related issues and pull requests on GitHub:* + :issue:`8608`. + + + +- Minor improvements to various type annotations -- by :user:`Dreamsorcerer`. + + + *Related issues and pull requests on GitHub:* + :issue:`8634`. 
+ + + + +---- + + +3.10.1 (2024-08-03) +======================== + +Bug fixes +--------- + +- Fixed WebSocket server heartbeat timeout logic to terminate :py:meth:`~aiohttp.ClientWebSocketResponse.receive` and return :py:class:`~aiohttp.ServerTimeoutError` -- by :user:`arcivanov`. + + When a WebSocket pong message was not received, the :py:meth:`~aiohttp.ClientWebSocketResponse.receive` operation did not terminate. This change causes ``_pong_not_received`` to feed the ``reader`` an error message, causing pending :py:meth:`~aiohttp.ClientWebSocketResponse.receive` to terminate and return the error message. The error message contains the exception :py:class:`~aiohttp.ServerTimeoutError`. + + + *Related issues and pull requests on GitHub:* + :issue:`8540`. + + + +- Fixed url dispatcher index not matching when a variable is preceded by a fixed string after a slash -- by :user:`bdraco`. + + + *Related issues and pull requests on GitHub:* + :issue:`8566`. + + + + +Removals and backward incompatible breaking changes +--------------------------------------------------- + +- Creating :py:class:`aiohttp.TCPConnector`, :py:class:`aiohttp.ClientSession`, :py:class:`~aiohttp.resolver.ThreadedResolver` :py:class:`aiohttp.web.Server`, or :py:class:`aiohttp.CookieJar` instances without a running event loop now raises a :exc:`RuntimeError` -- by :user:`asvetlov`. + + Creating these objects without a running event loop was deprecated in :issue:`3372` which was released in version 3.5.0. + + This change first appeared in version 3.10.0 as :issue:`6378`. + + + *Related issues and pull requests on GitHub:* + :issue:`8555`, :issue:`8583`. + + + + +---- + + +3.10.0 (2024-07-30) +======================== + +Bug fixes +--------- + +- Fixed server response headers for ``Content-Type`` and ``Content-Encoding`` for + static compressed files -- by :user:`steverep`. + + Server will now respond with a ``Content-Type`` appropriate for the compressed + file (e.g. 
``"application/gzip"``), and omit the ``Content-Encoding`` header. + Users should expect that most clients will no longer decompress such responses + by default. + + + *Related issues and pull requests on GitHub:* + :issue:`4462`. + + + +- Fixed duplicate cookie expiration calls in the CookieJar implementation + + + *Related issues and pull requests on GitHub:* + :issue:`7784`. + + + +- Adjusted ``FileResponse`` to check file existence and access when preparing the response -- by :user:`steverep`. + + The :py:class:`~aiohttp.web.FileResponse` class was modified to respond with + 403 Forbidden or 404 Not Found as appropriate. Previously, it would cause a + server error if the path did not exist or could not be accessed. Checks for + existence, non-regular files, and permissions were expected to be done in the + route handler. For static routes, this now permits a compressed file to exist + without its uncompressed variant and still be served. In addition, this + changes the response status for files without read permission to 403, and for + non-regular files from 404 to 403 for consistency. + + + *Related issues and pull requests on GitHub:* + :issue:`8182`. + + + +- Fixed ``AsyncResolver`` to match ``ThreadedResolver`` behavior + -- by :user:`bdraco`. + + On system with IPv6 support, the :py:class:`~aiohttp.resolver.AsyncResolver` would not fallback + to providing A records when AAAA records were not available. + Additionally, unlike the :py:class:`~aiohttp.resolver.ThreadedResolver`, the :py:class:`~aiohttp.resolver.AsyncResolver` + did not handle link-local addresses correctly. + + This change makes the behavior consistent with the :py:class:`~aiohttp.resolver.ThreadedResolver`. + + + *Related issues and pull requests on GitHub:* + :issue:`8270`. + + + +- Fixed ``ws_connect`` not respecting `receive_timeout`` on WS(S) connection. + -- by :user:`arcivanov`. + + + *Related issues and pull requests on GitHub:* + :issue:`8444`. 
+ + + +- Removed blocking I/O in the event loop for static resources and refactored + exception handling -- by :user:`steverep`. + + File system calls when handling requests for static routes were moved to a + separate thread to potentially improve performance. Exception handling + was tightened in order to only return 403 Forbidden or 404 Not Found responses + for expected scenarios; 500 Internal Server Error would be returned for any + unknown errors. + + + *Related issues and pull requests on GitHub:* + :issue:`8507`. + + + + +Features +-------- + +- Added a Request.wait_for_disconnection() method, as means of allowing request handlers to be notified of premature client disconnections. + + + *Related issues and pull requests on GitHub:* + :issue:`2492`. + + + +- Added 5 new exceptions: :py:exc:`~aiohttp.InvalidUrlClientError`, :py:exc:`~aiohttp.RedirectClientError`, + :py:exc:`~aiohttp.NonHttpUrlClientError`, :py:exc:`~aiohttp.InvalidUrlRedirectClientError`, + :py:exc:`~aiohttp.NonHttpUrlRedirectClientError` + + :py:exc:`~aiohttp.InvalidUrlRedirectClientError`, :py:exc:`~aiohttp.NonHttpUrlRedirectClientError` + are raised instead of :py:exc:`ValueError` or :py:exc:`~aiohttp.InvalidURL` when the redirect URL is invalid. Classes + :py:exc:`~aiohttp.InvalidUrlClientError`, :py:exc:`~aiohttp.RedirectClientError`, + :py:exc:`~aiohttp.NonHttpUrlClientError` are base for them. + + The :py:exc:`~aiohttp.InvalidURL` now exposes a ``description`` property with the text explanation of the error details. + + -- by :user:`setla`, :user:`AraHaan`, and :user:`bdraco` + + + *Related issues and pull requests on GitHub:* + :issue:`2507`, :issue:`3315`, :issue:`6722`, :issue:`8481`, :issue:`8482`. + + + +- Added a feature to retry closed connections automatically for idempotent methods. -- by :user:`Dreamsorcerer` + + + *Related issues and pull requests on GitHub:* + :issue:`7297`. 
+ + + +- Implemented filter_cookies() with domain-matching and path-matching on the keys, instead of testing every single cookie. + This may break existing cookies that have been saved with `CookieJar.save()`. Cookies can be migrated with this script:: + + import pickle + with file_path.open("rb") as f: + cookies = pickle.load(f) + + morsels = [(name, m) for c in cookies.values() for name, m in c.items()] + cookies.clear() + for name, m in morsels: + cookies[(m["domain"], m["path"].rstrip("/"))][name] = m + + with file_path.open("wb") as f: + pickle.dump(cookies, f, pickle.HIGHEST_PROTOCOL) + + + *Related issues and pull requests on GitHub:* + :issue:`7583`, :issue:`8535`. + + + +- Separated connection and socket timeout errors, from ServerTimeoutError. + + + *Related issues and pull requests on GitHub:* + :issue:`7801`. + + + +- Implemented happy eyeballs + + + *Related issues and pull requests on GitHub:* + :issue:`7954`. + + + +- Added server capability to check for static files with Brotli compression via a ``.br`` extension -- by :user:`steverep`. + + + *Related issues and pull requests on GitHub:* + :issue:`8062`. + + + + +Removals and backward incompatible breaking changes +--------------------------------------------------- + +- The shutdown logic in 3.9 waited on all tasks, which caused issues with some libraries. + In 3.10 we've changed this logic to only wait on request handlers. This means that it's + important for developers to correctly handle the lifecycle of background tasks using a + library such as ``aiojobs``. If an application is using ``handler_cancellation=True`` then + it is also a good idea to ensure that any :func:`asyncio.shield` calls are replaced with + :func:`aiojobs.aiohttp.shield`. 
+ + Please read the updated documentation on these points: \ + https://docs.aiohttp.org/en/stable/web_advanced.html#graceful-shutdown \ + https://docs.aiohttp.org/en/stable/web_advanced.html#web-handler-cancellation + + -- by :user:`Dreamsorcerer` + + + *Related issues and pull requests on GitHub:* + :issue:`8495`. + + + + +Improved documentation +---------------------- + +- Added documentation for ``aiohttp.web.FileResponse``. + + + *Related issues and pull requests on GitHub:* + :issue:`3958`. + + + +- Improved the docs for the `ssl` params. + + + *Related issues and pull requests on GitHub:* + :issue:`8403`. + + + + +Contributor-facing changes +-------------------------- + +- Enabled HTTP parser tests originally intended for 3.9.2 release -- by :user:`pajod`. + + + *Related issues and pull requests on GitHub:* + :issue:`8088`. + + + + +Miscellaneous internal changes +------------------------------ + +- Improved URL handler resolution time by indexing resources in the UrlDispatcher. + For applications with a large number of handlers, this should increase performance significantly. + -- by :user:`bdraco` + + + *Related issues and pull requests on GitHub:* + :issue:`7829`. + + + +- Added `nacl_middleware `_ to the list of middlewares in the third party section of the documentation. + + + *Related issues and pull requests on GitHub:* + :issue:`8346`. + + + +- Minor improvements to static typing -- by :user:`Dreamsorcerer`. + + + *Related issues and pull requests on GitHub:* + :issue:`8364`. + + + +- Added a 3.11-specific overloads to ``ClientSession`` -- by :user:`max-muoto`. + + + *Related issues and pull requests on GitHub:* + :issue:`8463`. + + + +- Simplified path checks for ``UrlDispatcher.add_static()`` method -- by :user:`steverep`. + + + *Related issues and pull requests on GitHub:* + :issue:`8491`. + + + +- Avoided creating a future on every websocket receive -- by :user:`bdraco`. + + + *Related issues and pull requests on GitHub:* + :issue:`8498`. 
+ + + +- Updated identity checks for all ``WSMsgType`` type compares -- by :user:`bdraco`. + + + *Related issues and pull requests on GitHub:* + :issue:`8501`. + + + +- When using Python 3.12 or later, the writer is no longer scheduled on the event loop if it can finish synchronously. Avoiding event loop scheduling reduces latency and improves performance. -- by :user:`bdraco`. + + + *Related issues and pull requests on GitHub:* + :issue:`8510`. + + + +- Restored :py:class:`~aiohttp.resolver.AsyncResolver` to be the default resolver. -- by :user:`bdraco`. + + :py:class:`~aiohttp.resolver.AsyncResolver` was disabled by default because + of IPv6 compatibility issues. These issues have been resolved and + :py:class:`~aiohttp.resolver.AsyncResolver` is again now the default resolver. + + + *Related issues and pull requests on GitHub:* + :issue:`8522`. + + + + +---- + + 3.9.5 (2024-04-16) ================== diff --git a/CONTRIBUTORS.txt b/CONTRIBUTORS.txt index 277171a239e..202193375dd 100644 --- a/CONTRIBUTORS.txt +++ b/CONTRIBUTORS.txt @@ -46,6 +46,7 @@ Anes Abismail Antoine Pietri Anton Kasyanov Anton Zhdan-Pushkin +Arcadiy Ivanov Arseny Timoniq Artem Yushkovskiy Arthur Darcet @@ -351,6 +352,7 @@ William Grzybowski William S. Wilson Ong wouter bolsterlee +Xiang Li Yang Zhou Yannick Koechlin Yannick Péroux @@ -367,5 +369,6 @@ Yuvi Panda Zainab Lawal Zeal Wierslee Zlatan Sičanica +Łukasz Setla Марк Коренберг Семён Марьясин diff --git a/README.rst b/README.rst index 90b7f713577..45b647437e3 100644 --- a/README.rst +++ b/README.rst @@ -148,7 +148,7 @@ Communication channels *aio-libs Discussions*: https://github.com/aio-libs/aiohttp/discussions -*gitter chat* https://gitter.im/aio-libs/Lobby +*Matrix*: `#aio-libs:matrix.org `_ We support `Stack Overflow `_. 
diff --git a/aiohttp/__init__.py b/aiohttp/__init__.py index e82e790b46a..f050229f008 100644 --- a/aiohttp/__init__.py +++ b/aiohttp/__init__.py @@ -1,40 +1,47 @@ -__version__ = "3.9.5" +__version__ = "3.10.2" from typing import TYPE_CHECKING, Tuple from . import hdrs as hdrs from .client import ( - BaseConnector as BaseConnector, - ClientConnectionError as ClientConnectionError, - ClientConnectorCertificateError as ClientConnectorCertificateError, - ClientConnectorError as ClientConnectorError, - ClientConnectorSSLError as ClientConnectorSSLError, - ClientError as ClientError, - ClientHttpProxyError as ClientHttpProxyError, - ClientOSError as ClientOSError, - ClientPayloadError as ClientPayloadError, - ClientProxyConnectionError as ClientProxyConnectionError, - ClientRequest as ClientRequest, - ClientResponse as ClientResponse, - ClientResponseError as ClientResponseError, - ClientSession as ClientSession, - ClientSSLError as ClientSSLError, - ClientTimeout as ClientTimeout, - ClientWebSocketResponse as ClientWebSocketResponse, - ContentTypeError as ContentTypeError, - Fingerprint as Fingerprint, - InvalidURL as InvalidURL, - NamedPipeConnector as NamedPipeConnector, - RequestInfo as RequestInfo, - ServerConnectionError as ServerConnectionError, - ServerDisconnectedError as ServerDisconnectedError, - ServerFingerprintMismatch as ServerFingerprintMismatch, - ServerTimeoutError as ServerTimeoutError, - TCPConnector as TCPConnector, - TooManyRedirects as TooManyRedirects, - UnixConnector as UnixConnector, - WSServerHandshakeError as WSServerHandshakeError, - request as request, + BaseConnector, + ClientConnectionError, + ClientConnectorCertificateError, + ClientConnectorError, + ClientConnectorSSLError, + ClientError, + ClientHttpProxyError, + ClientOSError, + ClientPayloadError, + ClientProxyConnectionError, + ClientRequest, + ClientResponse, + ClientResponseError, + ClientSession, + ClientSSLError, + ClientTimeout, + ClientWebSocketResponse, + 
ConnectionTimeoutError, + ContentTypeError, + Fingerprint, + InvalidURL, + InvalidUrlClientError, + InvalidUrlRedirectClientError, + NamedPipeConnector, + NonHttpUrlClientError, + NonHttpUrlRedirectClientError, + RedirectClientError, + RequestInfo, + ServerConnectionError, + ServerDisconnectedError, + ServerFingerprintMismatch, + ServerTimeoutError, + SocketTimeoutError, + TCPConnector, + TooManyRedirects, + UnixConnector, + WSServerHandshakeError, + request, ) from .cookiejar import CookieJar as CookieJar, DummyCookieJar as DummyCookieJar from .formdata import FormData as FormData @@ -131,14 +138,21 @@ "ClientSession", "ClientTimeout", "ClientWebSocketResponse", + "ConnectionTimeoutError", "ContentTypeError", "Fingerprint", "InvalidURL", + "InvalidUrlClientError", + "InvalidUrlRedirectClientError", + "NonHttpUrlClientError", + "NonHttpUrlRedirectClientError", + "RedirectClientError", "RequestInfo", "ServerConnectionError", "ServerDisconnectedError", "ServerFingerprintMismatch", "ServerTimeoutError", + "SocketTimeoutError", "TCPConnector", "TooManyRedirects", "UnixConnector", diff --git a/aiohttp/_http_parser.pyx b/aiohttp/_http_parser.pyx index 7ea9b32ca55..dd317edaf79 100644 --- a/aiohttp/_http_parser.pyx +++ b/aiohttp/_http_parser.pyx @@ -47,6 +47,7 @@ include "_headers.pxi" from aiohttp cimport _find_header +ALLOWED_UPGRADES = frozenset({"websocket"}) DEF DEFAULT_FREELIST_SIZE = 250 cdef extern from "Python.h": @@ -417,7 +418,6 @@ cdef class HttpParser: cdef _on_headers_complete(self): self._process_header() - method = http_method_str(self._cparser.method) should_close = not cparser.llhttp_should_keep_alive(self._cparser) upgrade = self._cparser.upgrade chunked = self._cparser.flags & cparser.F_CHUNKED @@ -425,8 +425,13 @@ cdef class HttpParser: raw_headers = tuple(self._raw_headers) headers = CIMultiDictProxy(self._headers) - if upgrade or self._cparser.method == cparser.HTTP_CONNECT: - self._upgraded = True + if self._cparser.type == cparser.HTTP_REQUEST: + 
allowed = upgrade and headers.get("upgrade", "").lower() in ALLOWED_UPGRADES + if allowed or self._cparser.method == cparser.HTTP_CONNECT: + self._upgraded = True + else: + if upgrade and self._cparser.status_code == 101: + self._upgraded = True # do not support old websocket spec if SEC_WEBSOCKET_KEY1 in headers: @@ -441,6 +446,7 @@ cdef class HttpParser: encoding = enc if self._cparser.type == cparser.HTTP_REQUEST: + method = http_method_str(self._cparser.method) msg = _new_request_message( method, self._path, self.http_version(), headers, raw_headers, @@ -565,7 +571,7 @@ cdef class HttpParser: if self._upgraded: return messages, True, data[nb:] else: - return messages, False, b'' + return messages, False, b"" def set_upgraded(self, val): self._upgraded = val @@ -748,10 +754,7 @@ cdef int cb_on_headers_complete(cparser.llhttp_t* parser) except -1: pyparser._last_error = exc return -1 else: - if ( - pyparser._cparser.upgrade or - pyparser._cparser.method == cparser.HTTP_CONNECT - ): + if pyparser._upgraded or pyparser._cparser.method == cparser.HTTP_CONNECT: return 2 else: return 0 diff --git a/aiohttp/abc.py b/aiohttp/abc.py index ee838998997..3fb024048a4 100644 --- a/aiohttp/abc.py +++ b/aiohttp/abc.py @@ -1,5 +1,6 @@ import asyncio import logging +import socket from abc import ABC, abstractmethod from collections.abc import Sized from http.cookies import BaseCookie, Morsel @@ -14,12 +15,12 @@ List, Optional, Tuple, + TypedDict, ) from multidict import CIMultiDict from yarl import URL -from .helpers import get_running_loop from .typedefs import LooseCookies if TYPE_CHECKING: @@ -119,11 +120,35 @@ def __await__(self) -> Generator[Any, None, StreamResponse]: """Execute the view handler.""" +class ResolveResult(TypedDict): + """Resolve result. + + This is the result returned from an AbstractResolver's + resolve method. + + :param hostname: The hostname that was provided. + :param host: The IP address that was resolved. + :param port: The port that was resolved. 
+ :param family: The address family that was resolved. + :param proto: The protocol that was resolved. + :param flags: The flags that were resolved. + """ + + hostname: str + host: str + port: int + family: int + proto: int + flags: int + + class AbstractResolver(ABC): """Abstract DNS resolver.""" @abstractmethod - async def resolve(self, host: str, port: int, family: int) -> List[Dict[str, Any]]: + async def resolve( + self, host: str, port: int = 0, family: socket.AddressFamily = socket.AF_INET + ) -> List[ResolveResult]: """Return IP address for given hostname""" @abstractmethod @@ -144,7 +169,7 @@ class AbstractCookieJar(Sized, IterableBase): """Abstract Cookie Jar.""" def __init__(self, *, loop: Optional[asyncio.AbstractEventLoop] = None) -> None: - self._loop = get_running_loop(loop) + self._loop = loop or asyncio.get_running_loop() @abstractmethod def clear(self, predicate: Optional[ClearCookiePredicate] = None) -> None: diff --git a/aiohttp/client.py b/aiohttp/client.py index 32d2c3b7119..3d1045f355a 100644 --- a/aiohttp/client.py +++ b/aiohttp/client.py @@ -9,7 +9,7 @@ import traceback import warnings from contextlib import suppress -from types import SimpleNamespace, TracebackType +from types import TracebackType from typing import ( TYPE_CHECKING, Any, @@ -27,6 +27,7 @@ Set, Tuple, Type, + TypedDict, TypeVar, Union, ) @@ -38,25 +39,32 @@ from . 
import hdrs, http, payload from .abc import AbstractCookieJar from .client_exceptions import ( - ClientConnectionError as ClientConnectionError, - ClientConnectorCertificateError as ClientConnectorCertificateError, - ClientConnectorError as ClientConnectorError, - ClientConnectorSSLError as ClientConnectorSSLError, - ClientError as ClientError, - ClientHttpProxyError as ClientHttpProxyError, - ClientOSError as ClientOSError, - ClientPayloadError as ClientPayloadError, - ClientProxyConnectionError as ClientProxyConnectionError, - ClientResponseError as ClientResponseError, - ClientSSLError as ClientSSLError, - ContentTypeError as ContentTypeError, - InvalidURL as InvalidURL, - ServerConnectionError as ServerConnectionError, - ServerDisconnectedError as ServerDisconnectedError, - ServerFingerprintMismatch as ServerFingerprintMismatch, - ServerTimeoutError as ServerTimeoutError, - TooManyRedirects as TooManyRedirects, - WSServerHandshakeError as WSServerHandshakeError, + ClientConnectionError, + ClientConnectorCertificateError, + ClientConnectorError, + ClientConnectorSSLError, + ClientError, + ClientHttpProxyError, + ClientOSError, + ClientPayloadError, + ClientProxyConnectionError, + ClientResponseError, + ClientSSLError, + ConnectionTimeoutError, + ContentTypeError, + InvalidURL, + InvalidUrlClientError, + InvalidUrlRedirectClientError, + NonHttpUrlClientError, + NonHttpUrlRedirectClientError, + RedirectClientError, + ServerConnectionError, + ServerDisconnectedError, + ServerFingerprintMismatch, + ServerTimeoutError, + SocketTimeoutError, + TooManyRedirects, + WSServerHandshakeError, ) from .client_reqrep import ( ClientRequest as ClientRequest, @@ -67,6 +75,7 @@ ) from .client_ws import ClientWebSocketResponse as ClientWebSocketResponse from .connector import ( + HTTP_AND_EMPTY_SCHEMA_SET, BaseConnector as BaseConnector, NamedPipeConnector as NamedPipeConnector, TCPConnector as TCPConnector, @@ -80,7 +89,6 @@ TimeoutHandle, ceil_timeout, get_env_proxy_for_url, - 
get_running_loop, method_must_be_empty_body, sentinel, strip_auth_from_url, @@ -104,12 +112,19 @@ "ClientProxyConnectionError", "ClientResponseError", "ClientSSLError", + "ConnectionTimeoutError", "ContentTypeError", "InvalidURL", + "InvalidUrlClientError", + "RedirectClientError", + "NonHttpUrlClientError", + "InvalidUrlRedirectClientError", + "NonHttpUrlRedirectClientError", "ServerConnectionError", "ServerDisconnectedError", "ServerFingerprintMismatch", "ServerTimeoutError", + "SocketTimeoutError", "TooManyRedirects", "WSServerHandshakeError", # client_reqrep @@ -136,6 +151,37 @@ else: SSLContext = None +if sys.version_info >= (3, 11) and TYPE_CHECKING: + from typing import Unpack + + +class _RequestOptions(TypedDict, total=False): + params: Union[Mapping[str, Union[str, int]], str, None] + data: Any + json: Any + cookies: Union[LooseCookies, None] + headers: Union[LooseHeaders, None] + skip_auto_headers: Union[Iterable[str], None] + auth: Union[BasicAuth, None] + allow_redirects: bool + max_redirects: int + compress: Union[str, None] + chunked: Union[bool, None] + expect100: bool + raise_for_status: Union[None, bool, Callable[[ClientResponse], Awaitable[None]]] + read_until_eof: bool + proxy: Union[StrOrURL, None] + proxy_auth: Union[BasicAuth, None] + timeout: "Union[ClientTimeout, _SENTINEL, None]" + ssl: Union[SSLContext, bool, Fingerprint] + server_hostname: Union[str, None] + proxy_headers: Union[LooseHeaders, None] + trace_request_ctx: Union[Mapping[str, str], None] + read_bufsize: Union[int, None] + auto_decompress: Union[bool, None] + max_line_size: Union[int, None] + max_field_size: Union[int, None] + @attr.s(auto_attribs=True, frozen=True, slots=True) class ClientTimeout: @@ -162,6 +208,9 @@ class ClientTimeout: # 5 Minute default read timeout DEFAULT_TIMEOUT: Final[ClientTimeout] = ClientTimeout(total=5 * 60) +# https://www.rfc-editor.org/rfc/rfc9110#section-9.2.2 +IDEMPOTENT_METHODS = frozenset({"GET", "HEAD", "OPTIONS", "TRACE", "PUT", "DELETE"}) + 
_RetType = TypeVar("_RetType") _CharsetResolver = Callable[[ClientResponse, bytes], str] @@ -237,6 +286,21 @@ def __init__( # We initialise _connector to None immediately, as it's referenced in __del__() # and could cause issues if an exception occurs during initialisation. self._connector: Optional[BaseConnector] = None + + if loop is None: + if connector is not None: + loop = connector._loop + + loop = loop or asyncio.get_running_loop() + + if base_url is None or isinstance(base_url, URL): + self._base_url: Optional[URL] = base_url + else: + self._base_url = URL(base_url) + assert ( + self._base_url.origin() == self._base_url + ), "Only absolute URLs without path part are supported" + if timeout is sentinel or timeout is None: self._timeout = DEFAULT_TIMEOUT if read_timeout is not sentinel: @@ -272,19 +336,6 @@ def __init__( "conflict, please setup " "timeout.connect" ) - if loop is None: - if connector is not None: - loop = connector._loop - - loop = get_running_loop(loop) - - if base_url is None or isinstance(base_url, URL): - self._base_url: Optional[URL] = base_url - else: - self._base_url = URL(base_url) - assert ( - self._base_url.origin() == self._base_url - ), "Only absolute URLs without path part are supported" if connector is None: connector = TCPConnector(loop=loop) @@ -369,11 +420,22 @@ def __del__(self, _warnings: Any = warnings) -> None: context["source_traceback"] = self._source_traceback self._loop.call_exception_handler(context) - def request( - self, method: str, url: StrOrURL, **kwargs: Any - ) -> "_RequestContextManager": - """Perform HTTP request.""" - return _RequestContextManager(self._request(method, url, **kwargs)) + if sys.version_info >= (3, 11) and TYPE_CHECKING: + + def request( + self, + method: str, + url: StrOrURL, + **kwargs: Unpack[_RequestOptions], + ) -> "_RequestContextManager": ... 
+ + else: + + def request( + self, method: str, url: StrOrURL, **kwargs: Any + ) -> "_RequestContextManager": + """Perform HTTP request.""" + return _RequestContextManager(self._request(method, url, **kwargs)) def _build_url(self, str_or_url: StrOrURL) -> URL: url = URL(str_or_url) @@ -413,7 +475,7 @@ async def _request( ssl: Union[SSLContext, bool, Fingerprint] = True, server_hostname: Optional[str] = None, proxy_headers: Optional[LooseHeaders] = None, - trace_request_ctx: Optional[SimpleNamespace] = None, + trace_request_ctx: Optional[Mapping[str, str]] = None, read_bufsize: Optional[int] = None, auto_decompress: Optional[bool] = None, max_line_size: Optional[int] = None, @@ -451,7 +513,11 @@ async def _request( try: url = self._build_url(str_or_url) except ValueError as e: - raise InvalidURL(str_or_url) from e + raise InvalidUrlClientError(str_or_url) from e + + assert self._connector is not None + if url.scheme not in self._connector.allowed_protocol_schema_set: + raise NonHttpUrlClientError(url) skip_headers = set(self._skip_auto_headers) if skip_auto_headers is not None: @@ -505,8 +571,19 @@ async def _request( timer = tm.timer() try: with timer: + # https://www.rfc-editor.org/rfc/rfc9112.html#name-retrying-requests + retry_persistent_connection = method in IDEMPOTENT_METHODS while True: url, auth_from_url = strip_auth_from_url(url) + if not url.raw_host: + # NOTE: Bail early, otherwise, causes `InvalidURL` through + # NOTE: `self._request_class()` below. 
+ err_exc_cls = ( + InvalidUrlRedirectClientError + if redirects + else InvalidUrlClientError + ) + raise err_exc_cls(url) if auth and auth_from_url: raise ValueError( "Cannot combine AUTH argument with " @@ -577,13 +654,12 @@ async def _request( real_timeout.connect, ceil_threshold=real_timeout.ceil_threshold, ): - assert self._connector is not None conn = await self._connector.connect( req, traces=traces, timeout=real_timeout ) except asyncio.TimeoutError as exc: - raise ServerTimeoutError( - "Connection timeout " "to host {}".format(url) + raise ConnectionTimeoutError( + f"Connection timeout to host {url}" ) from exc assert conn.transport is not None @@ -612,6 +688,11 @@ async def _request( except BaseException: conn.close() raise + except (ClientOSError, ServerDisconnectedError): + if retry_persistent_connection: + retry_persistent_connection = False + continue + raise except ClientError: raise except OSError as exc: @@ -659,25 +740,35 @@ async def _request( resp.release() try: - parsed_url = URL( + parsed_redirect_url = URL( r_url, encoded=not self._requote_redirect_url ) - except ValueError as e: - raise InvalidURL(r_url) from e + raise InvalidUrlRedirectClientError( + r_url, + "Server attempted redirecting to a location that does not look like a URL", + ) from e - scheme = parsed_url.scheme - if scheme not in ("http", "https", ""): + scheme = parsed_redirect_url.scheme + if scheme not in HTTP_AND_EMPTY_SCHEMA_SET: resp.close() - raise ValueError("Can redirect only to http or https") + raise NonHttpUrlRedirectClientError(r_url) elif not scheme: - parsed_url = url.join(parsed_url) + parsed_redirect_url = url.join(parsed_redirect_url) - if url.origin() != parsed_url.origin(): + try: + redirect_origin = parsed_redirect_url.origin() + except ValueError as origin_val_err: + raise InvalidUrlRedirectClientError( + parsed_redirect_url, + "Invalid redirect URL origin", + ) from origin_val_err + + if url.origin() != redirect_origin: auth = None 
headers.pop(hdrs.AUTHORIZATION, None) - url = parsed_url + url = parsed_redirect_url params = {} resp.release() continue @@ -740,7 +831,7 @@ def ws_connect( headers: Optional[LooseHeaders] = None, proxy: Optional[StrOrURL] = None, proxy_auth: Optional[BasicAuth] = None, - ssl: Union[SSLContext, bool, None, Fingerprint] = True, + ssl: Union[SSLContext, bool, Fingerprint] = True, verify_ssl: Optional[bool] = None, fingerprint: Optional[bytes] = None, ssl_context: Optional[SSLContext] = None, @@ -792,7 +883,7 @@ async def _ws_connect( headers: Optional[LooseHeaders] = None, proxy: Optional[StrOrURL] = None, proxy_auth: Optional[BasicAuth] = None, - ssl: Optional[Union[SSLContext, bool, Fingerprint]] = True, + ssl: Union[SSLContext, bool, Fingerprint] = True, verify_ssl: Optional[bool] = None, fingerprint: Optional[bytes] = None, ssl_context: Optional[SSLContext] = None, @@ -828,6 +919,11 @@ async def _ws_connect( # For the sake of backward compatibility, if user passes in None, convert it to True if ssl is None: + warnings.warn( + "ssl=None is deprecated, please use ssl=True", + DeprecationWarning, + stacklevel=2, + ) ssl = True ssl = _merge_ssl_params(ssl, verify_ssl, ssl_context, fingerprint) @@ -922,6 +1018,16 @@ async def _ws_connect( assert conn is not None conn_proto = conn.protocol assert conn_proto is not None + + # For WS connection the read_timeout must be either receive_timeout or greater + # None == no timeout, i.e. 
infinite timeout, so None is the max timeout possible + if receive_timeout is None: + # Reset regardless + conn_proto.read_timeout = receive_timeout + elif conn_proto.read_timeout is not None: + # If read_timeout was set check which wins + conn_proto.read_timeout = max(receive_timeout, conn_proto.read_timeout) + transport = conn.transport assert transport is not None reader: FlowControlDataQueue[WSMessage] = FlowControlDataQueue( @@ -970,61 +1076,111 @@ def _prepare_headers(self, headers: Optional[LooseHeaders]) -> "CIMultiDict[str] added_names.add(key) return result - def get( - self, url: StrOrURL, *, allow_redirects: bool = True, **kwargs: Any - ) -> "_RequestContextManager": - """Perform HTTP GET request.""" - return _RequestContextManager( - self._request(hdrs.METH_GET, url, allow_redirects=allow_redirects, **kwargs) - ) + if sys.version_info >= (3, 11) and TYPE_CHECKING: + + def get( + self, + url: StrOrURL, + **kwargs: Unpack[_RequestOptions], + ) -> "_RequestContextManager": ... + + def options( + self, + url: StrOrURL, + **kwargs: Unpack[_RequestOptions], + ) -> "_RequestContextManager": ... + + def head( + self, + url: StrOrURL, + **kwargs: Unpack[_RequestOptions], + ) -> "_RequestContextManager": ... + + def post( + self, + url: StrOrURL, + **kwargs: Unpack[_RequestOptions], + ) -> "_RequestContextManager": ... + + def put( + self, + url: StrOrURL, + **kwargs: Unpack[_RequestOptions], + ) -> "_RequestContextManager": ... + + def patch( + self, + url: StrOrURL, + **kwargs: Unpack[_RequestOptions], + ) -> "_RequestContextManager": ... + + def delete( + self, + url: StrOrURL, + **kwargs: Unpack[_RequestOptions], + ) -> "_RequestContextManager": ... 
+ + else: + + def get( + self, url: StrOrURL, *, allow_redirects: bool = True, **kwargs: Any + ) -> "_RequestContextManager": + """Perform HTTP GET request.""" + return _RequestContextManager( + self._request( + hdrs.METH_GET, url, allow_redirects=allow_redirects, **kwargs + ) + ) - def options( - self, url: StrOrURL, *, allow_redirects: bool = True, **kwargs: Any - ) -> "_RequestContextManager": - """Perform HTTP OPTIONS request.""" - return _RequestContextManager( - self._request( - hdrs.METH_OPTIONS, url, allow_redirects=allow_redirects, **kwargs + def options( + self, url: StrOrURL, *, allow_redirects: bool = True, **kwargs: Any + ) -> "_RequestContextManager": + """Perform HTTP OPTIONS request.""" + return _RequestContextManager( + self._request( + hdrs.METH_OPTIONS, url, allow_redirects=allow_redirects, **kwargs + ) ) - ) - def head( - self, url: StrOrURL, *, allow_redirects: bool = False, **kwargs: Any - ) -> "_RequestContextManager": - """Perform HTTP HEAD request.""" - return _RequestContextManager( - self._request( - hdrs.METH_HEAD, url, allow_redirects=allow_redirects, **kwargs + def head( + self, url: StrOrURL, *, allow_redirects: bool = False, **kwargs: Any + ) -> "_RequestContextManager": + """Perform HTTP HEAD request.""" + return _RequestContextManager( + self._request( + hdrs.METH_HEAD, url, allow_redirects=allow_redirects, **kwargs + ) ) - ) - def post( - self, url: StrOrURL, *, data: Any = None, **kwargs: Any - ) -> "_RequestContextManager": - """Perform HTTP POST request.""" - return _RequestContextManager( - self._request(hdrs.METH_POST, url, data=data, **kwargs) - ) + def post( + self, url: StrOrURL, *, data: Any = None, **kwargs: Any + ) -> "_RequestContextManager": + """Perform HTTP POST request.""" + return _RequestContextManager( + self._request(hdrs.METH_POST, url, data=data, **kwargs) + ) - def put( - self, url: StrOrURL, *, data: Any = None, **kwargs: Any - ) -> "_RequestContextManager": - """Perform HTTP PUT request.""" - return 
_RequestContextManager( - self._request(hdrs.METH_PUT, url, data=data, **kwargs) - ) + def put( + self, url: StrOrURL, *, data: Any = None, **kwargs: Any + ) -> "_RequestContextManager": + """Perform HTTP PUT request.""" + return _RequestContextManager( + self._request(hdrs.METH_PUT, url, data=data, **kwargs) + ) - def patch( - self, url: StrOrURL, *, data: Any = None, **kwargs: Any - ) -> "_RequestContextManager": - """Perform HTTP PATCH request.""" - return _RequestContextManager( - self._request(hdrs.METH_PATCH, url, data=data, **kwargs) - ) + def patch( + self, url: StrOrURL, *, data: Any = None, **kwargs: Any + ) -> "_RequestContextManager": + """Perform HTTP PATCH request.""" + return _RequestContextManager( + self._request(hdrs.METH_PATCH, url, data=data, **kwargs) + ) - def delete(self, url: StrOrURL, **kwargs: Any) -> "_RequestContextManager": - """Perform HTTP DELETE request.""" - return _RequestContextManager(self._request(hdrs.METH_DELETE, url, **kwargs)) + def delete(self, url: StrOrURL, **kwargs: Any) -> "_RequestContextManager": + """Perform HTTP DELETE request.""" + return _RequestContextManager( + self._request(hdrs.METH_DELETE, url, **kwargs) + ) async def close(self) -> None: """Close underlying connector. 
diff --git a/aiohttp/client_exceptions.py b/aiohttp/client_exceptions.py index 9b6e44203c8..ff29b3d3ca9 100644 --- a/aiohttp/client_exceptions.py +++ b/aiohttp/client_exceptions.py @@ -2,10 +2,10 @@ import asyncio import warnings -from typing import TYPE_CHECKING, Any, Optional, Tuple, Union +from typing import TYPE_CHECKING, Optional, Tuple, Union from .http_parser import RawResponseMessage -from .typedefs import LooseHeaders +from .typedefs import LooseHeaders, StrOrURL try: import ssl @@ -29,6 +29,8 @@ "ClientSSLError", "ClientConnectorSSLError", "ClientConnectorCertificateError", + "ConnectionTimeoutError", + "SocketTimeoutError", "ServerConnectionError", "ServerTimeoutError", "ServerDisconnectedError", @@ -39,6 +41,11 @@ "ContentTypeError", "ClientPayloadError", "InvalidURL", + "InvalidUrlClientError", + "RedirectClientError", + "NonHttpUrlClientError", + "InvalidUrlRedirectClientError", + "NonHttpUrlRedirectClientError", ) @@ -93,7 +100,7 @@ def __str__(self) -> str: return "{}, message={!r}, url={!r}".format( self.status, self.message, - self.request_info.real_url, + str(self.request_info.real_url), ) def __repr__(self) -> str: @@ -242,6 +249,14 @@ class ServerTimeoutError(ServerConnectionError, asyncio.TimeoutError): """Server timeout error.""" +class ConnectionTimeoutError(ServerTimeoutError): + """Connection timeout error.""" + + +class SocketTimeoutError(ServerTimeoutError): + """Socket timeout error.""" + + class ServerFingerprintMismatch(ServerConnectionError): """SSL certificate does not match expected fingerprint.""" @@ -271,17 +286,52 @@ class InvalidURL(ClientError, ValueError): # Derive from ValueError for backward compatibility - def __init__(self, url: Any) -> None: + def __init__(self, url: StrOrURL, description: Union[str, None] = None) -> None: # The type of url is not yarl.URL because the exception can be raised # on URL(url) call - super().__init__(url) + self._url = url + self._description = description + + if description: + 
super().__init__(url, description) + else: + super().__init__(url) + + @property + def url(self) -> StrOrURL: + return self._url @property - def url(self) -> Any: - return self.args[0] + def description(self) -> "str | None": + return self._description def __repr__(self) -> str: - return f"<{self.__class__.__name__} {self.url}>" + return f"<{self.__class__.__name__} {self}>" + + def __str__(self) -> str: + if self._description: + return f"{self._url} - {self._description}" + return str(self._url) + + +class InvalidUrlClientError(InvalidURL): + """Invalid URL client error.""" + + +class RedirectClientError(ClientError): + """Client redirect error.""" + + +class NonHttpUrlClientError(ClientError): + """Non http URL client error.""" + + +class InvalidUrlRedirectClientError(InvalidUrlClientError, RedirectClientError): + """Invalid URL redirect client error.""" + + +class NonHttpUrlRedirectClientError(NonHttpUrlClientError, RedirectClientError): + """Non http URL redirect client error.""" class ClientSSLError(ClientConnectorError): diff --git a/aiohttp/client_proto.py b/aiohttp/client_proto.py index 723f5aae5f4..f8c83240209 100644 --- a/aiohttp/client_proto.py +++ b/aiohttp/client_proto.py @@ -7,7 +7,7 @@ ClientOSError, ClientPayloadError, ServerDisconnectedError, - ServerTimeoutError, + SocketTimeoutError, ) from .helpers import ( _EXC_SENTINEL, @@ -224,8 +224,16 @@ def _reschedule_timeout(self) -> None: def start_timeout(self) -> None: self._reschedule_timeout() + @property + def read_timeout(self) -> Optional[float]: + return self._read_timeout + + @read_timeout.setter + def read_timeout(self, read_timeout: Optional[float]) -> None: + self._read_timeout = read_timeout + def _on_read_timeout(self) -> None: - exc = ServerTimeoutError("Timeout on reading data from socket") + exc = SocketTimeoutError("Timeout on reading data from socket") self.set_exception(exc) if self._payload is not None: set_exception(self._payload, exc) diff --git a/aiohttp/client_reqrep.py 
b/aiohttp/client_reqrep.py index afe719da16e..2c10da4ff81 100644 --- a/aiohttp/client_reqrep.py +++ b/aiohttp/client_reqrep.py @@ -245,7 +245,8 @@ class ClientRequest: hdrs.ACCEPT_ENCODING: _gen_default_accept_encoding(), } - body = b"" + # Type of body depends on PAYLOAD_REGISTRY, which is dynamic. + body: Any = b"" auth = None response = None @@ -352,7 +353,12 @@ def _writer(self, writer: Optional["asyncio.Task[None]"]) -> None: if self.__writer is not None: self.__writer.remove_done_callback(self.__reset_writer) self.__writer = writer - if writer is not None: + if writer is None: + return + if writer.done(): + # The writer is already done, so we can reset it immediately. + self.__reset_writer() + else: writer.add_done_callback(self.__reset_writer) def is_ssl(self) -> bool: @@ -436,7 +442,7 @@ def update_headers(self, headers: Optional[LooseHeaders]) -> None: if headers: if isinstance(headers, (dict, MultiDictProxy, MultiDict)): - headers = headers.items() # type: ignore[assignment] + headers = headers.items() for key, value in headers: # type: ignore[misc] # A special case for Host header @@ -566,7 +572,7 @@ def update_body_from_data(self, body: Any) -> None: # copy payload headers assert body.headers - for (key, value) in body.headers.items(): + for key, value in body.headers.items(): if key in self.headers: continue if key in self.skip_auto_headers: @@ -592,6 +598,10 @@ def update_proxy( raise ValueError("proxy_auth must be None or BasicAuth() tuple") self.proxy = proxy self.proxy_auth = proxy_auth + if proxy_headers is not None and not isinstance( + proxy_headers, (MultiDict, MultiDictProxy) + ): + proxy_headers = CIMultiDict(proxy_headers) self.proxy_headers = proxy_headers def keep_alive(self) -> bool: @@ -627,10 +637,10 @@ async def write_bytes( await self.body.write(writer) else: if isinstance(self.body, (bytes, bytearray)): - self.body = (self.body,) # type: ignore[assignment] + self.body = (self.body,) for chunk in self.body: - await writer.write(chunk) 
# type: ignore[arg-type] + await writer.write(chunk) except OSError as underlying_exc: reraised_exc = underlying_exc @@ -721,9 +731,17 @@ async def send(self, conn: "Connection") -> "ClientResponse": self.method, path, v=self.version ) await writer.write_headers(status_line, self.headers) + coro = self.write_bytes(writer, conn) - self._writer = self.loop.create_task(self.write_bytes(writer, conn)) + if sys.version_info >= (3, 12): + # Optimization for Python 3.12, try to write + # bytes immediately to avoid having to schedule + # the task on the event loop. + task = asyncio.Task(coro, loop=self.loop, eager_start=True) + else: + task = self.loop.create_task(coro) + self._writer = task response_class = self.response_class assert response_class is not None self.response = response_class( @@ -820,9 +838,9 @@ def __init__( # work after the response has finished reading the body. if session is None: # TODO: Fix session=None in tests (see ClientRequest.__init__). - self._resolve_charset: Callable[ - ["ClientResponse", bytes], str - ] = lambda *_: "utf-8" + self._resolve_charset: Callable[["ClientResponse", bytes], str] = ( + lambda *_: "utf-8" + ) else: self._resolve_charset = session._resolve_charset if loop.get_debug(): @@ -840,7 +858,12 @@ def _writer(self, writer: Optional["asyncio.Task[None]"]) -> None: if self.__writer is not None: self.__writer.remove_done_callback(self.__reset_writer) self.__writer = writer - if writer is not None: + if writer is None: + return + if writer.done(): + # The writer is already done, so we can reset it immediately. 
+ self.__reset_writer() + else: writer.add_done_callback(self.__reset_writer) @reify diff --git a/aiohttp/client_ws.py b/aiohttp/client_ws.py index d9c74a30f52..247f62c758e 100644 --- a/aiohttp/client_ws.py +++ b/aiohttp/client_ws.py @@ -4,9 +4,9 @@ import sys from typing import Any, Optional, cast -from .client_exceptions import ClientError +from .client_exceptions import ClientError, ServerTimeoutError from .client_reqrep import ClientResponse -from .helpers import call_later, set_result +from .helpers import calculate_timeout_when, set_result from .http import ( WS_CLOSED_MESSAGE, WS_CLOSING_MESSAGE, @@ -62,63 +62,116 @@ def __init__( self._autoping = autoping self._heartbeat = heartbeat self._heartbeat_cb: Optional[asyncio.TimerHandle] = None + self._heartbeat_when: float = 0.0 if heartbeat is not None: self._pong_heartbeat = heartbeat / 2.0 self._pong_response_cb: Optional[asyncio.TimerHandle] = None self._loop = loop - self._waiting: Optional[asyncio.Future[bool]] = None + self._waiting: bool = False + self._close_wait: Optional[asyncio.Future[None]] = None self._exception: Optional[BaseException] = None self._compress = compress self._client_notakeover = client_notakeover + self._ping_task: Optional[asyncio.Task[None]] = None self._reset_heartbeat() def _cancel_heartbeat(self) -> None: - if self._pong_response_cb is not None: - self._pong_response_cb.cancel() - self._pong_response_cb = None - + self._cancel_pong_response_cb() if self._heartbeat_cb is not None: self._heartbeat_cb.cancel() self._heartbeat_cb = None + if self._ping_task is not None: + self._ping_task.cancel() + self._ping_task = None - def _reset_heartbeat(self) -> None: - self._cancel_heartbeat() + def _cancel_pong_response_cb(self) -> None: + if self._pong_response_cb is not None: + self._pong_response_cb.cancel() + self._pong_response_cb = None - if self._heartbeat is not None: - self._heartbeat_cb = call_later( - self._send_heartbeat, - self._heartbeat, - self._loop, - 
timeout_ceil_threshold=self._conn._connector._timeout_ceil_threshold - if self._conn is not None - else 5, - ) + def _reset_heartbeat(self) -> None: + if self._heartbeat is None: + return + self._cancel_pong_response_cb() + loop = self._loop + assert loop is not None + conn = self._conn + timeout_ceil_threshold = ( + conn._connector._timeout_ceil_threshold if conn is not None else 5 + ) + now = loop.time() + when = calculate_timeout_when(now, self._heartbeat, timeout_ceil_threshold) + self._heartbeat_when = when + if self._heartbeat_cb is None: + # We do not cancel the previous heartbeat_cb here because + # it generates a significant amount of TimerHandle churn + # which causes asyncio to rebuild the heap frequently. + # Instead _send_heartbeat() will reschedule the next + # heartbeat if it fires too early. + self._heartbeat_cb = loop.call_at(when, self._send_heartbeat) def _send_heartbeat(self) -> None: - if self._heartbeat is not None and not self._closed: - # fire-and-forget a task is not perfect but maybe ok for - # sending ping. Otherwise we need a long-living heartbeat - # task in the class. 
- self._loop.create_task(self._writer.ping()) - - if self._pong_response_cb is not None: - self._pong_response_cb.cancel() - self._pong_response_cb = call_later( - self._pong_not_received, - self._pong_heartbeat, - self._loop, - timeout_ceil_threshold=self._conn._connector._timeout_ceil_threshold - if self._conn is not None - else 5, + self._heartbeat_cb = None + loop = self._loop + now = loop.time() + if now < self._heartbeat_when: + # Heartbeat fired too early, reschedule + self._heartbeat_cb = loop.call_at( + self._heartbeat_when, self._send_heartbeat ) + return + + conn = self._conn + timeout_ceil_threshold = ( + conn._connector._timeout_ceil_threshold if conn is not None else 5 + ) + when = calculate_timeout_when(now, self._pong_heartbeat, timeout_ceil_threshold) + self._cancel_pong_response_cb() + self._pong_response_cb = loop.call_at(when, self._pong_not_received) + + if sys.version_info >= (3, 12): + # Optimization for Python 3.12, try to send the ping + # immediately to avoid having to schedule + # the task on the event loop. + ping_task = asyncio.Task(self._writer.ping(), loop=loop, eager_start=True) + else: + ping_task = loop.create_task(self._writer.ping()) + + if not ping_task.done(): + self._ping_task = ping_task + ping_task.add_done_callback(self._ping_task_done) + + def _ping_task_done(self, task: "asyncio.Task[None]") -> None: + """Callback for when the ping task completes.""" + self._ping_task = None def _pong_not_received(self) -> None: if not self._closed: - self._closed = True + self._set_closed() self._close_code = WSCloseCode.ABNORMAL_CLOSURE - self._exception = asyncio.TimeoutError() + self._exception = ServerTimeoutError() self._response.close() + if self._waiting and not self._closing: + self._reader.feed_data( + WSMessage(WSMsgType.ERROR, self._exception, None) + ) + + def _set_closed(self) -> None: + """Set the connection to closed. + + Cancel any heartbeat timers and set the closed flag. 
+ """ + self._closed = True + self._cancel_heartbeat() + + def _set_closing(self) -> None: + """Set the connection to closing. + + Cancel any heartbeat timers and set the closing flag. + """ + self._closing = True + self._cancel_heartbeat() @property def closed(self) -> bool: @@ -181,14 +234,15 @@ async def send_json( async def close(self, *, code: int = WSCloseCode.OK, message: bytes = b"") -> bool: # we need to break `receive()` cycle first, # `close()` may be called from different task - if self._waiting is not None and not self._closing: - self._closing = True + if self._waiting and not self._closing: + assert self._loop is not None + self._close_wait = self._loop.create_future() + self._set_closing() self._reader.feed_data(WS_CLOSING_MESSAGE, 0) - await self._waiting + await self._close_wait if not self._closed: - self._cancel_heartbeat() - self._closed = True + self._set_closed() try: await self._writer.close(code, message) except asyncio.CancelledError: @@ -219,7 +273,7 @@ async def close(self, *, code: int = WSCloseCode.OK, message: bytes = b"") -> bo self._response.close() return True - if msg.type == WSMsgType.CLOSE: + if msg.type is WSMsgType.CLOSE: self._close_code = msg.data self._response.close() return True @@ -228,7 +282,7 @@ async def close(self, *, code: int = WSCloseCode.OK, message: bytes = b"") -> bo async def receive(self, timeout: Optional[float] = None) -> WSMessage: while True: - if self._waiting is not None: + if self._waiting: raise RuntimeError("Concurrent call to receive() is not allowed") if self._closed: @@ -238,15 +292,15 @@ async def receive(self, timeout: Optional[float] = None) -> WSMessage: return WS_CLOSED_MESSAGE try: - self._waiting = self._loop.create_future() + self._waiting = True try: async with async_timeout.timeout(timeout or self._receive_timeout): msg = await self._reader.read() self._reset_heartbeat() finally: - waiter = self._waiting - self._waiting = None - set_result(waiter, True) + self._waiting = False + if 
self._close_wait: + set_result(self._close_wait, None) except (asyncio.CancelledError, asyncio.TimeoutError): self._close_code = WSCloseCode.ABNORMAL_CLOSURE raise @@ -255,7 +309,8 @@ async def receive(self, timeout: Optional[float] = None) -> WSMessage: await self.close() return WSMessage(WSMsgType.CLOSED, None, None) except ClientError: - self._closed = True + # Likely ServerDisconnectedError when connection is lost + self._set_closed() self._close_code = WSCloseCode.ABNORMAL_CLOSURE return WS_CLOSED_MESSAGE except WebSocketError as exc: @@ -264,35 +319,35 @@ async def receive(self, timeout: Optional[float] = None) -> WSMessage: return WSMessage(WSMsgType.ERROR, exc, None) except Exception as exc: self._exception = exc - self._closing = True + self._set_closing() self._close_code = WSCloseCode.ABNORMAL_CLOSURE await self.close() return WSMessage(WSMsgType.ERROR, exc, None) - if msg.type == WSMsgType.CLOSE: - self._closing = True + if msg.type is WSMsgType.CLOSE: + self._set_closing() self._close_code = msg.data if not self._closed and self._autoclose: await self.close() - elif msg.type == WSMsgType.CLOSING: - self._closing = True - elif msg.type == WSMsgType.PING and self._autoping: + elif msg.type is WSMsgType.CLOSING: + self._set_closing() + elif msg.type is WSMsgType.PING and self._autoping: await self.pong(msg.data) continue - elif msg.type == WSMsgType.PONG and self._autoping: + elif msg.type is WSMsgType.PONG and self._autoping: continue return msg async def receive_str(self, *, timeout: Optional[float] = None) -> str: msg = await self.receive(timeout) - if msg.type != WSMsgType.TEXT: + if msg.type is not WSMsgType.TEXT: raise TypeError(f"Received message {msg.type}:{msg.data!r} is not str") return cast(str, msg.data) async def receive_bytes(self, *, timeout: Optional[float] = None) -> bytes: msg = await self.receive(timeout) - if msg.type != WSMsgType.BINARY: + if msg.type is not WSMsgType.BINARY: raise TypeError(f"Received message {msg.type}:{msg.data!r} 
is not bytes") return cast(bytes, msg.data) diff --git a/aiohttp/compression_utils.py b/aiohttp/compression_utils.py index 9631d377e9a..ab4a2f1cc84 100644 --- a/aiohttp/compression_utils.py +++ b/aiohttp/compression_utils.py @@ -50,9 +50,11 @@ def __init__( max_sync_chunk_size: Optional[int] = MAX_SYNC_CHUNK_SIZE, ): super().__init__( - mode=encoding_to_mode(encoding, suppress_deflate_header) - if wbits is None - else wbits, + mode=( + encoding_to_mode(encoding, suppress_deflate_header) + if wbits is None + else wbits + ), executor=executor, max_sync_chunk_size=max_sync_chunk_size, ) diff --git a/aiohttp/connector.py b/aiohttp/connector.py index f95ebe84c66..d4691b10e6e 100644 --- a/aiohttp/connector.py +++ b/aiohttp/connector.py @@ -1,6 +1,7 @@ import asyncio import functools import random +import socket import sys import traceback import warnings @@ -22,6 +23,7 @@ List, Literal, Optional, + Sequence, Set, Tuple, Type, @@ -29,10 +31,11 @@ cast, ) +import aiohappyeyeballs import attr from . 
import hdrs, helpers -from .abc import AbstractResolver +from .abc import AbstractResolver, ResolveResult from .client_exceptions import ( ClientConnectionError, ClientConnectorCertificateError, @@ -47,7 +50,7 @@ ) from .client_proto import ResponseHandler from .client_reqrep import ClientRequest, Fingerprint, _merge_ssl_params -from .helpers import ceil_timeout, get_running_loop, is_ip_address, noop, sentinel +from .helpers import ceil_timeout, is_ip_address, noop, sentinel from .locks import EventResultOrError from .resolver import DefaultResolver @@ -60,6 +63,14 @@ SSLContext = object # type: ignore[misc,assignment] +EMPTY_SCHEMA_SET = frozenset({""}) +HTTP_SCHEMA_SET = frozenset({"http", "https"}) +WS_SCHEMA_SET = frozenset({"ws", "wss"}) + +HTTP_AND_EMPTY_SCHEMA_SET = HTTP_SCHEMA_SET | EMPTY_SCHEMA_SET +HIGH_LEVEL_SCHEMA_SET = HTTP_AND_EMPTY_SCHEMA_SET | WS_SCHEMA_SET + + __all__ = ("BaseConnector", "TCPConnector", "UnixConnector", "NamedPipeConnector") @@ -208,6 +219,8 @@ class BaseConnector: # abort transport after 2 seconds (cleanup broken connections) _cleanup_closed_period = 2.0 + allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET + def __init__( self, *, @@ -229,7 +242,7 @@ def __init__( if keepalive_timeout is sentinel: keepalive_timeout = 15.0 - loop = get_running_loop(loop) + loop = loop or asyncio.get_running_loop() self._timeout_ceil_threshold = timeout_ceil_threshold self._closed = False @@ -240,9 +253,9 @@ def __init__( self._limit = limit self._limit_per_host = limit_per_host self._acquired: Set[ResponseHandler] = set() - self._acquired_per_host: DefaultDict[ - ConnectionKey, Set[ResponseHandler] - ] = defaultdict(set) + self._acquired_per_host: DefaultDict[ConnectionKey, Set[ResponseHandler]] = ( + defaultdict(set) + ) self._keepalive_timeout = cast(float, keepalive_timeout) self._force_close = force_close @@ -691,14 +704,14 @@ async def _create_connection( class _DNSCacheTable: def __init__(self, ttl: Optional[float] = None) -> None: - 
self._addrs_rr: Dict[Tuple[str, int], Tuple[Iterator[Dict[str, Any]], int]] = {} + self._addrs_rr: Dict[Tuple[str, int], Tuple[Iterator[ResolveResult], int]] = {} self._timestamps: Dict[Tuple[str, int], float] = {} self._ttl = ttl def __contains__(self, host: object) -> bool: return host in self._addrs_rr - def add(self, key: Tuple[str, int], addrs: List[Dict[str, Any]]) -> None: + def add(self, key: Tuple[str, int], addrs: List[ResolveResult]) -> None: self._addrs_rr[key] = (cycle(addrs), len(addrs)) if self._ttl is not None: @@ -714,7 +727,7 @@ def clear(self) -> None: self._addrs_rr.clear() self._timestamps.clear() - def next_addrs(self, key: Tuple[str, int]) -> List[Dict[str, Any]]: + def next_addrs(self, key: Tuple[str, int]) -> List[ResolveResult]: loop, length = self._addrs_rr[key] addrs = list(islice(loop, length)) # Consume one more element to shift internal state of `cycle` @@ -735,7 +748,7 @@ class TCPConnector(BaseConnector): fingerprint - Pass the binary sha256 digest of the expected certificate in DER format to verify that the certificate the server presents matches. See also - https://en.wikipedia.org/wiki/Transport_Layer_Security#Certificate_pinning + https://en.wikipedia.org/wiki/HTTP_Public_Key_Pinning resolver - Enable DNS lookups and use this resolver use_dns_cache - Use memory cache for DNS lookups. @@ -750,9 +763,15 @@ class TCPConnector(BaseConnector): limit_per_host - Number of simultaneous connections to one host. enable_cleanup_closed - Enables clean-up closed ssl transports. Disabled by default. + happy_eyeballs_delay - This is the “Connection Attempt Delay” + as defined in RFC 8305. To disable + the happy eyeballs algorithm, set to None. + interleave - “First Address Family Count” as defined in RFC 8305 loop - Optional event loop. 
""" + allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET | frozenset({"tcp"}) + def __init__( self, *, @@ -760,7 +779,7 @@ def __init__( fingerprint: Optional[bytes] = None, use_dns_cache: bool = True, ttl_dns_cache: Optional[int] = 10, - family: int = 0, + family: socket.AddressFamily = socket.AddressFamily.AF_UNSPEC, ssl_context: Optional[SSLContext] = None, ssl: Union[bool, Fingerprint, SSLContext] = True, local_addr: Optional[Tuple[str, int]] = None, @@ -772,6 +791,8 @@ def __init__( enable_cleanup_closed: bool = False, loop: Optional[asyncio.AbstractEventLoop] = None, timeout_ceil_threshold: float = 5, + happy_eyeballs_delay: Optional[float] = 0.25, + interleave: Optional[int] = None, ): super().__init__( keepalive_timeout=keepalive_timeout, @@ -792,7 +813,9 @@ def __init__( self._cached_hosts = _DNSCacheTable(ttl=ttl_dns_cache) self._throttle_dns_events: Dict[Tuple[str, int], EventResultOrError] = {} self._family = family - self._local_addr = local_addr + self._local_addr_infos = aiohappyeyeballs.addr_to_addr_infos(local_addr) + self._happy_eyeballs_delay = happy_eyeballs_delay + self._interleave = interleave def close(self) -> Awaitable[None]: """Close all ongoing DNS calls.""" @@ -823,8 +846,8 @@ def clear_dns_cache( self._cached_hosts.clear() async def _resolve_host( - self, host: str, port: int, traces: Optional[List["Trace"]] = None - ) -> List[Dict[str, Any]]: + self, host: str, port: int, traces: Optional[Sequence["Trace"]] = None + ) -> List[ResolveResult]: """Resolve host and return list of addresses.""" if is_ip_address(host): return [ @@ -880,7 +903,7 @@ async def _resolve_host( return await asyncio.shield(resolved_host_task) except asyncio.CancelledError: - def drop_exception(fut: "asyncio.Future[List[Dict[str, Any]]]") -> None: + def drop_exception(fut: "asyncio.Future[List[ResolveResult]]") -> None: with suppress(Exception, asyncio.CancelledError): fut.result() @@ -892,8 +915,8 @@ async def _resolve_host_with_throttle( key: Tuple[str, int], 
host: str, port: int, - traces: Optional[List["Trace"]], - ) -> List[Dict[str, Any]]: + traces: Optional[Sequence["Trace"]], + ) -> List[ResolveResult]: """Resolve host with a dns events throttle.""" if key in self._throttle_dns_events: # get event early, before any await (#4014) @@ -1011,6 +1034,36 @@ def _get_fingerprint(self, req: ClientRequest) -> Optional["Fingerprint"]: return None async def _wrap_create_connection( + self, + *args: Any, + addr_infos: List[aiohappyeyeballs.AddrInfoType], + req: ClientRequest, + timeout: "ClientTimeout", + client_error: Type[Exception] = ClientConnectorError, + **kwargs: Any, + ) -> Tuple[asyncio.Transport, ResponseHandler]: + try: + async with ceil_timeout( + timeout.sock_connect, ceil_threshold=timeout.ceil_threshold + ): + sock = await aiohappyeyeballs.start_connection( + addr_infos=addr_infos, + local_addr_infos=self._local_addr_infos, + happy_eyeballs_delay=self._happy_eyeballs_delay, + interleave=self._interleave, + loop=self._loop, + ) + return await self._loop.create_connection(*args, **kwargs, sock=sock) + except cert_errors as exc: + raise ClientConnectorCertificateError(req.connection_key, exc) from exc + except ssl_errors as exc: + raise ClientConnectorSSLError(req.connection_key, exc) from exc + except OSError as exc: + if exc.errno is None and isinstance(exc, asyncio.TimeoutError): + raise + raise client_error(req.connection_key, exc) from exc + + async def _wrap_existing_connection( self, *args: Any, req: ClientRequest, @@ -1176,6 +1229,27 @@ async def _start_tls_connection( return tls_transport, tls_proto + def _convert_hosts_to_addr_infos( + self, hosts: List[ResolveResult] + ) -> List[aiohappyeyeballs.AddrInfoType]: + """Converts the list of hosts to a list of addr_infos. + + The list of hosts is the result of a DNS lookup. The list of + addr_infos is the result of a call to `socket.getaddrinfo()`. 
+ """ + addr_infos: List[aiohappyeyeballs.AddrInfoType] = [] + for hinfo in hosts: + host = hinfo["host"] + is_ipv6 = ":" in host + family = socket.AF_INET6 if is_ipv6 else socket.AF_INET + if self._family and self._family != family: + continue + addr = (host, hinfo["port"], 0, 0) if is_ipv6 else (host, hinfo["port"]) + addr_infos.append( + (family, socket.SOCK_STREAM, socket.IPPROTO_TCP, "", addr) + ) + return addr_infos + async def _create_direct_connection( self, req: ClientRequest, @@ -1209,36 +1283,27 @@ async def _create_direct_connection( raise ClientConnectorError(req.connection_key, exc) from exc last_exc: Optional[Exception] = None - - for hinfo in hosts: - host = hinfo["host"] - port = hinfo["port"] - + addr_infos = self._convert_hosts_to_addr_infos(hosts) + while addr_infos: # Strip trailing dots, certificates contain FQDN without dots. # See https://github.com/aio-libs/aiohttp/issues/3636 server_hostname = ( - (req.server_hostname or hinfo["hostname"]).rstrip(".") - if sslcontext - else None + (req.server_hostname or host).rstrip(".") if sslcontext else None ) try: transp, proto = await self._wrap_create_connection( self._factory, - host, - port, timeout=timeout, ssl=sslcontext, - family=hinfo["family"], - proto=hinfo["proto"], - flags=hinfo["flags"], + addr_infos=addr_infos, server_hostname=server_hostname, - local_addr=self._local_addr, req=req, client_error=client_error, ) except ClientConnectorError as exc: last_exc = exc + aiohappyeyeballs.pop_addr_infos_interleave(addr_infos, self._interleave) continue if req.is_ssl() and fingerprint: @@ -1249,6 +1314,10 @@ async def _create_direct_connection( if not self._cleanup_closed_disabled: self._cleanup_closed_transports.append(transp) last_exc = exc + # Remove the bad peer from the list of addr_infos + sock: socket.socket = transp.get_extra_info("socket") + bad_peer = sock.getpeername() + aiohappyeyeballs.remove_addr_infos(addr_infos, bad_peer) continue return transp, proto @@ -1367,7 +1436,7 @@ async 
def _create_proxy_connection( if not runtime_has_start_tls: # HTTP proxy with support for upgrade to HTTPS sslcontext = self._get_ssl_context(req) - return await self._wrap_create_connection( + return await self._wrap_existing_connection( self._factory, timeout=timeout, ssl=sslcontext, @@ -1401,6 +1470,8 @@ class UnixConnector(BaseConnector): loop - Optional event loop. """ + allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET | frozenset({"unix"}) + def __init__( self, path: str, @@ -1457,6 +1528,8 @@ class NamedPipeConnector(BaseConnector): loop - Optional event loop. """ + allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET | frozenset({"npipe"}) + def __init__( self, path: str, diff --git a/aiohttp/cookiejar.py b/aiohttp/cookiejar.py index a348f112cb5..e9997ce2935 100644 --- a/aiohttp/cookiejar.py +++ b/aiohttp/cookiejar.py @@ -2,6 +2,7 @@ import calendar import contextlib import datetime +import itertools import os # noqa import pathlib import pickle @@ -10,7 +11,7 @@ from collections import defaultdict from http.cookies import BaseCookie, Morsel, SimpleCookie from math import ceil -from typing import ( # noqa +from typing import ( DefaultDict, Dict, Iterable, @@ -35,6 +36,10 @@ CookieItem = Union[str, "Morsel[str]"] +# We cache these string methods here as their use is in performance critical code. +_FORMAT_PATH = "{}/{}".format +_FORMAT_DOMAIN_REVERSED = "{1}.{0}".format + class CookieJar(AbstractCookieJar): """Implements cookie storage adhering to RFC 6265.""" @@ -153,7 +158,12 @@ def __iter__(self) -> "Iterator[Morsel[str]]": yield from val.values() def __len__(self) -> int: - return sum(1 for i in self) + """Return number of cookies. + + This function does not iterate self to avoid unnecessary expiration + checks. 
+ """ + return sum(len(cookie.values()) for cookie in self._cookies.values()) def _do_expiration(self) -> None: self.clear(lambda x: False) @@ -211,6 +221,7 @@ def update_cookies(self, cookies: LooseCookies, response_url: URL = URL()) -> No # Cut everything from the last slash to the end path = "/" + path[1 : path.rfind("/")] cookie["path"] = path + path = path.rstrip("/") max_age = cookie["max-age"] if max_age: @@ -256,26 +267,40 @@ def filter_cookies(self, request_url: URL = URL()) -> "BaseCookie[str]": request_origin = request_url.origin() is_not_secure = request_origin not in self._treat_as_secure_origin + # Send shared cookie + for c in self._cookies[("", "")].values(): + filtered[c.key] = c.value + + if is_ip_address(hostname): + if not self._unsafe: + return filtered + domains: Iterable[str] = (hostname,) + else: + # Get all the subdomains that might match a cookie (e.g. "foo.bar.com", "bar.com", "com") + domains = itertools.accumulate( + reversed(hostname.split(".")), _FORMAT_DOMAIN_REVERSED + ) + + # Get all the path prefixes that might match a cookie (e.g. "", "/foo", "/foo/bar") + paths = itertools.accumulate(request_url.path.split("/"), _FORMAT_PATH) + # Create every combination of (domain, path) pairs. 
+ pairs = itertools.product(domains, paths) + # Point 2: https://www.rfc-editor.org/rfc/rfc6265.html#section-5.4 - for cookie in sorted(self, key=lambda c: len(c["path"])): + cookies = itertools.chain.from_iterable( + self._cookies[p].values() for p in pairs + ) + path_len = len(request_url.path) + for cookie in cookies: name = cookie.key domain = cookie["domain"] - # Send shared cookies - if not domain: - filtered[name] = cookie.value - continue - - if not self._unsafe and is_ip_address(hostname): - continue - if (domain, name) in self._host_only_cookies: if domain != hostname: continue - elif not self._is_domain_match(domain, hostname): - continue - if not self._is_path_match(request_url.path, cookie["path"]): + # Skip edge case when the cookie has a trailing slash but request doesn't. + if len(cookie["path"]) > path_len: continue if is_not_secure and cookie["secure"]: @@ -305,25 +330,6 @@ def _is_domain_match(domain: str, hostname: str) -> bool: return not is_ip_address(hostname) - @staticmethod - def _is_path_match(req_path: str, cookie_path: str) -> bool: - """Implements path matching adhering to RFC 6265.""" - if not req_path.startswith("/"): - req_path = "/" - - if req_path == cookie_path: - return True - - if not req_path.startswith(cookie_path): - return False - - if cookie_path.endswith("/"): - return True - - non_matching = req_path[len(cookie_path) :] - - return non_matching.startswith("/") - @classmethod def _parse_date(cls, date_str: str) -> Optional[int]: """Implements date string parsing adhering to RFC 6265.""" diff --git a/aiohttp/helpers.py b/aiohttp/helpers.py index 284033b7a04..437c871e8f7 100644 --- a/aiohttp/helpers.py +++ b/aiohttp/helpers.py @@ -14,7 +14,6 @@ import re import sys import time -import warnings import weakref from collections import namedtuple from contextlib import suppress @@ -52,7 +51,7 @@ from yarl import URL from . 
import hdrs -from .log import client_logger, internal_logger +from .log import client_logger if sys.version_info >= (3, 11): import asyncio as async_timeout @@ -287,38 +286,6 @@ def proxies_from_env() -> Dict[str, ProxyInfo]: return ret -def current_task( - loop: Optional[asyncio.AbstractEventLoop] = None, -) -> "Optional[asyncio.Task[Any]]": - return asyncio.current_task(loop=loop) - - -def get_running_loop( - loop: Optional[asyncio.AbstractEventLoop] = None, -) -> asyncio.AbstractEventLoop: - if loop is None: - loop = asyncio.get_event_loop() - if not loop.is_running(): - warnings.warn( - "The object should be created within an async function", - DeprecationWarning, - stacklevel=3, - ) - if loop.get_debug(): - internal_logger.warning( - "The object should be created within an async function", stack_info=True - ) - return loop - - -def isasyncgenfunction(obj: Any) -> bool: - func = getattr(inspect, "isasyncgenfunction", None) - if func is not None: - return func(obj) # type: ignore[no-any-return] - else: - return False - - def get_env_proxy_for_url(url: URL) -> Tuple[URL, Optional[BasicAuth]]: """Get a permitted proxy for the given URL from the env.""" if url.host is not None and proxy_bypass(url.host): @@ -619,12 +586,23 @@ def call_later( loop: asyncio.AbstractEventLoop, timeout_ceil_threshold: float = 5, ) -> Optional[asyncio.TimerHandle]: - if timeout is not None and timeout > 0: - when = loop.time() + timeout - if timeout > timeout_ceil_threshold: - when = ceil(when) - return loop.call_at(when, cb) - return None + if timeout is None or timeout <= 0: + return None + now = loop.time() + when = calculate_timeout_when(now, timeout, timeout_ceil_threshold) + return loop.call_at(when, cb) + + +def calculate_timeout_when( + loop_time: float, + timeout: float, + timeout_ceiling_threshold: float, +) -> float: + """Calculate when to execute a timeout.""" + when = loop_time + timeout + if timeout > timeout_ceiling_threshold: + return ceil(when) + return when class 
TimeoutHandle: @@ -709,7 +687,7 @@ def assert_timeout(self) -> None: raise asyncio.TimeoutError from None def __enter__(self) -> BaseTimerContext: - task = current_task(loop=self._loop) + task = asyncio.current_task(loop=self._loop) if task is None: raise RuntimeError( @@ -749,7 +727,7 @@ def ceil_timeout( if delay is None or delay <= 0: return async_timeout.timeout(None) - loop = get_running_loop() + loop = asyncio.get_running_loop() now = loop.time() when = now + delay if delay > ceil_threshold: @@ -818,8 +796,7 @@ def set_exception( self, exc: BaseException, exc_cause: BaseException = ..., - ) -> None: - ... # pragma: no cover + ) -> None: ... # pragma: no cover def set_exception( @@ -905,12 +882,10 @@ def __init_subclass__(cls) -> None: ) @overload # type: ignore[override] - def __getitem__(self, key: AppKey[_T]) -> _T: - ... + def __getitem__(self, key: AppKey[_T]) -> _T: ... @overload - def __getitem__(self, key: str) -> Any: - ... + def __getitem__(self, key: str) -> Any: ... def __getitem__(self, key: Union[str, AppKey[_T]]) -> Any: for mapping in self._maps: @@ -921,16 +896,13 @@ def __getitem__(self, key: Union[str, AppKey[_T]]) -> Any: raise KeyError(key) @overload # type: ignore[override] - def get(self, key: AppKey[_T], default: _S) -> Union[_T, _S]: - ... + def get(self, key: AppKey[_T], default: _S) -> Union[_T, _S]: ... @overload - def get(self, key: AppKey[_T], default: None = ...) -> Optional[_T]: - ... + def get(self, key: AppKey[_T], default: None = ...) -> Optional[_T]: ... @overload - def get(self, key: str, default: Any = ...) -> Any: - ... + def get(self, key: str, default: Any = ...) -> Any: ... 
def get(self, key: Union[str, AppKey[_T]], default: Any = None) -> Any: try: diff --git a/aiohttp/http_exceptions.py b/aiohttp/http_exceptions.py index 72eac3a3cac..c43ee0d9659 100644 --- a/aiohttp/http_exceptions.py +++ b/aiohttp/http_exceptions.py @@ -1,6 +1,5 @@ """Low-level http related exceptions.""" - from textwrap import indent from typing import Optional, Union diff --git a/aiohttp/http_parser.py b/aiohttp/http_parser.py index 013511917e8..751a7e1bb73 100644 --- a/aiohttp/http_parser.py +++ b/aiohttp/http_parser.py @@ -47,7 +47,6 @@ TransferEncodingError, ) from .http_writer import HttpVersion, HttpVersion10 -from .log import internal_logger from .streams import EMPTY_PAYLOAD, StreamReader from .typedefs import RawHeaders @@ -249,7 +248,6 @@ def __init__( timer: Optional[BaseTimerContext] = None, code: Optional[int] = None, method: Optional[str] = None, - readall: bool = False, payload_exception: Optional[Type[BaseException]] = None, response_with_body: bool = True, read_until_eof: bool = False, @@ -263,7 +261,6 @@ def __init__( self.timer = timer self.code = code self.method = method - self.readall = readall self.payload_exception = payload_exception self.response_with_body = response_with_body self.read_until_eof = read_until_eof @@ -393,7 +390,6 @@ def get_content_length() -> Optional[int]: method=method, compression=msg.compression, code=self.code, - readall=self.readall, response_with_body=self.response_with_body, auto_decompress=self._auto_decompress, lax=self.lax, @@ -413,7 +409,6 @@ def get_content_length() -> Optional[int]: payload, method=msg.method, compression=msg.compression, - readall=True, auto_decompress=self._auto_decompress, lax=self.lax, ) @@ -431,7 +426,6 @@ def get_content_length() -> Optional[int]: method=method, compression=msg.compression, code=self.code, - readall=True, response_with_body=self.response_with_body, auto_decompress=self._auto_decompress, lax=self.lax, @@ -751,13 +745,12 @@ def __init__( compression: Optional[str] = 
None, code: Optional[int] = None, method: Optional[str] = None, - readall: bool = False, response_with_body: bool = True, auto_decompress: bool = True, lax: bool = False, ) -> None: self._length = 0 - self._type = ParseState.PARSE_NONE + self._type = ParseState.PARSE_UNTIL_EOF self._chunk = ChunkState.PARSE_CHUNKED_SIZE self._chunk_size = 0 self._chunk_tail = b"" @@ -779,7 +772,6 @@ def __init__( self._type = ParseState.PARSE_NONE real_payload.feed_eof() self.done = True - elif chunked: self._type = ParseState.PARSE_CHUNKED elif length is not None: @@ -788,16 +780,6 @@ def __init__( if self._length == 0: real_payload.feed_eof() self.done = True - else: - if readall and code != 204: - self._type = ParseState.PARSE_UNTIL_EOF - elif method in ("PUT", "POST"): - internal_logger.warning( # pragma: no cover - "Content-Length or Transfer-Encoding header is required" - ) - self._type = ParseState.PARSE_NONE - real_payload.feed_eof() - self.done = True self.payload = real_payload diff --git a/aiohttp/payload.py b/aiohttp/payload.py index 6593b05c6f7..5271393612a 100644 --- a/aiohttp/payload.py +++ b/aiohttp/payload.py @@ -11,7 +11,6 @@ IO, TYPE_CHECKING, Any, - ByteString, Dict, Final, Iterable, @@ -217,7 +216,9 @@ async def write(self, writer: AbstractStreamWriter) -> None: class BytesPayload(Payload): - def __init__(self, value: ByteString, *args: Any, **kwargs: Any) -> None: + def __init__( + self, value: Union[bytes, bytearray, memoryview], *args: Any, **kwargs: Any + ) -> None: if not isinstance(value, (bytes, bytearray, memoryview)): raise TypeError(f"value argument must be byte-ish, not {type(value)!r}") diff --git a/aiohttp/pytest_plugin.py b/aiohttp/pytest_plugin.py index 5754747bf48..c862b409566 100644 --- a/aiohttp/pytest_plugin.py +++ b/aiohttp/pytest_plugin.py @@ -1,11 +1,21 @@ import asyncio import contextlib +import inspect import warnings -from typing import Any, Awaitable, Callable, Dict, Iterator, Optional, Type, Union +from typing import ( + Any, + 
Awaitable, + Callable, + Dict, + Iterator, + Optional, + Protocol, + Type, + Union, +) import pytest -from aiohttp.helpers import isasyncgenfunction from aiohttp.web import Application from .test_utils import ( @@ -24,9 +34,23 @@ except ImportError: # pragma: no cover uvloop = None # type: ignore[assignment] -AiohttpClient = Callable[[Union[Application, BaseTestServer]], Awaitable[TestClient]] AiohttpRawServer = Callable[[Application], Awaitable[RawTestServer]] -AiohttpServer = Callable[[Application], Awaitable[TestServer]] + + +class AiohttpClient(Protocol): + def __call__( + self, + __param: Union[Application, BaseTestServer], + *, + server_kwargs: Optional[Dict[str, Any]] = None, + **kwargs: Any + ) -> Awaitable[TestClient]: ... + + +class AiohttpServer(Protocol): + def __call__( + self, app: Application, *, port: Optional[int] = None, **kwargs: Any + ) -> Awaitable[TestServer]: ... def pytest_addoption(parser): # type: ignore[no-untyped-def] @@ -57,7 +81,7 @@ def pytest_fixture_setup(fixturedef): # type: ignore[no-untyped-def] """ func = fixturedef.func - if isasyncgenfunction(func): + if inspect.isasyncgenfunction(func): # async generator fixture is_async_gen = True elif asyncio.iscoroutinefunction(func): @@ -262,7 +286,9 @@ def aiohttp_server(loop: asyncio.AbstractEventLoop) -> Iterator[AiohttpServer]: """ servers = [] - async def go(app, *, port=None, **kwargs): # type: ignore[no-untyped-def] + async def go( + app: Application, *, port: Optional[int] = None, **kwargs: Any + ) -> TestServer: server = TestServer(app, port=port) await server.start_server(loop=loop, **kwargs) servers.append(server) diff --git a/aiohttp/resolver.py b/aiohttp/resolver.py index 6c17b1e7e89..10e36266abe 100644 --- a/aiohttp/resolver.py +++ b/aiohttp/resolver.py @@ -1,20 +1,24 @@ import asyncio import socket -from typing import Any, Dict, List, Optional, Type, Union +import sys +from typing import Any, Dict, List, Optional, Tuple, Type, Union -from .abc import AbstractResolver -from 
.helpers import get_running_loop +from .abc import AbstractResolver, ResolveResult __all__ = ("ThreadedResolver", "AsyncResolver", "DefaultResolver") + try: import aiodns - # aiodns_default = hasattr(aiodns.DNSResolver, 'gethostbyname') + aiodns_default = hasattr(aiodns.DNSResolver, "getaddrinfo") except ImportError: # pragma: no cover - aiodns = None + aiodns = None # type: ignore[assignment] + aiodns_default = False + -aiodns_default = False +_NUMERIC_SOCKET_FLAGS = socket.AI_NUMERICHOST | socket.AI_NUMERICSERV +_SUPPORTS_SCOPE_ID = sys.version_info >= (3, 9, 0) class ThreadedResolver(AbstractResolver): @@ -25,48 +29,48 @@ class ThreadedResolver(AbstractResolver): """ def __init__(self, loop: Optional[asyncio.AbstractEventLoop] = None) -> None: - self._loop = get_running_loop(loop) + self._loop = loop or asyncio.get_running_loop() async def resolve( - self, hostname: str, port: int = 0, family: int = socket.AF_INET - ) -> List[Dict[str, Any]]: + self, host: str, port: int = 0, family: socket.AddressFamily = socket.AF_INET + ) -> List[ResolveResult]: infos = await self._loop.getaddrinfo( - hostname, + host, port, type=socket.SOCK_STREAM, family=family, flags=socket.AI_ADDRCONFIG, ) - hosts = [] + hosts: List[ResolveResult] = [] for family, _, proto, _, address in infos: if family == socket.AF_INET6: if len(address) < 3: # IPv6 is not supported by Python build, # or IPv6 is not enabled in the host continue - if address[3]: + if address[3] and _SUPPORTS_SCOPE_ID: # This is essential for link-local IPv6 addresses. # LL IPv6 is a VERY rare case. Strictly speaking, we should use # getnameinfo() unconditionally, but performance makes sense. 
- host, _port = socket.getnameinfo( - address, socket.NI_NUMERICHOST | socket.NI_NUMERICSERV + resolved_host, _port = await self._loop.getnameinfo( + address, _NUMERIC_SOCKET_FLAGS ) port = int(_port) else: - host, port = address[:2] + resolved_host, port = address[:2] else: # IPv4 assert family == socket.AF_INET - host, port = address # type: ignore[misc] + resolved_host, port = address # type: ignore[misc] hosts.append( - { - "hostname": hostname, - "host": host, - "port": port, - "family": family, - "proto": proto, - "flags": socket.AI_NUMERICHOST | socket.AI_NUMERICSERV, - } + ResolveResult( + hostname=host, + host=resolved_host, + port=port, + family=family, + proto=proto, + flags=_NUMERIC_SOCKET_FLAGS, + ) ) return hosts @@ -87,32 +91,56 @@ def __init__( if aiodns is None: raise RuntimeError("Resolver requires aiodns library") - self._loop = get_running_loop(loop) - self._resolver = aiodns.DNSResolver(*args, loop=loop, **kwargs) + self._resolver = aiodns.DNSResolver(*args, **kwargs) if not hasattr(self._resolver, "gethostbyname"): # aiodns 1.1 is not available, fallback to DNSResolver.query self.resolve = self._resolve_with_query # type: ignore async def resolve( - self, host: str, port: int = 0, family: int = socket.AF_INET - ) -> List[Dict[str, Any]]: + self, host: str, port: int = 0, family: socket.AddressFamily = socket.AF_INET + ) -> List[ResolveResult]: try: - resp = await self._resolver.gethostbyname(host, family) + resp = await self._resolver.getaddrinfo( + host, + port=port, + type=socket.SOCK_STREAM, + family=family, + flags=socket.AI_ADDRCONFIG, + ) except aiodns.error.DNSError as exc: msg = exc.args[1] if len(exc.args) >= 1 else "DNS lookup failed" raise OSError(msg) from exc - hosts = [] - for address in resp.addresses: + hosts: List[ResolveResult] = [] + for node in resp.nodes: + address: Union[Tuple[bytes, int], Tuple[bytes, int, int, int]] = node.addr + family = node.family + if family == socket.AF_INET6: + if len(address) > 3 and address[3] 
and _SUPPORTS_SCOPE_ID: + # This is essential for link-local IPv6 addresses. + # LL IPv6 is a VERY rare case. Strictly speaking, we should use + # getnameinfo() unconditionally, but performance makes sense. + result = await self._resolver.getnameinfo( + (address[0].decode("ascii"), *address[1:]), + _NUMERIC_SOCKET_FLAGS, + ) + resolved_host = result.node + else: + resolved_host = address[0].decode("ascii") + port = address[1] + else: # IPv4 + assert family == socket.AF_INET + resolved_host = address[0].decode("ascii") + port = address[1] hosts.append( - { - "hostname": host, - "host": address, - "port": port, - "family": family, - "proto": 0, - "flags": socket.AI_NUMERICHOST | socket.AI_NUMERICSERV, - } + ResolveResult( + hostname=host, + host=resolved_host, + port=port, + family=family, + proto=0, + flags=_NUMERIC_SOCKET_FLAGS, + ) ) if not hosts: diff --git a/aiohttp/test_utils.py b/aiohttp/test_utils.py index a36e8599689..97c1469dd2a 100644 --- a/aiohttp/test_utils.py +++ b/aiohttp/test_utils.py @@ -11,17 +11,7 @@ import warnings from abc import ABC, abstractmethod from types import TracebackType -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Iterator, - List, - Optional, - Type, - Union, - cast, -) +from typing import TYPE_CHECKING, Any, Callable, Iterator, List, Optional, Type, cast from unittest import IsolatedAsyncioTestCase, mock from aiosignal import Signal @@ -29,7 +19,11 @@ from yarl import URL import aiohttp -from aiohttp.client import _RequestContextManager, _WSRequestContextManager +from aiohttp.client import ( + _RequestContextManager, + _RequestOptions, + _WSRequestContextManager, +) from . 
import ClientSession, hdrs from .abc import AbstractCookieJar @@ -55,6 +49,9 @@ else: SSLContext = None +if sys.version_info >= (3, 11) and TYPE_CHECKING: + from typing import Unpack + REUSE_ADDRESS = os.name == "posix" and sys.platform != "cygwin" @@ -90,7 +87,7 @@ class BaseTestServer(ABC): def __init__( self, *, - scheme: Union[str, object] = sentinel, + scheme: str = "", loop: Optional[asyncio.AbstractEventLoop] = None, host: str = "127.0.0.1", port: Optional[int] = None, @@ -135,12 +132,8 @@ async def start_server( sockets = server.sockets # type: ignore[attr-defined] assert sockets is not None self.port = sockets[0].getsockname()[1] - if self.scheme is sentinel: - if self._ssl: - scheme = "https" - else: - scheme = "http" - self.scheme = scheme + if not self.scheme: + self.scheme = "https" if self._ssl else "http" self._root = URL(f"{self.scheme}://{self.host}:{self.port}") @abstractmethod # pragma: no cover @@ -222,7 +215,7 @@ def __init__( self, app: Application, *, - scheme: Union[str, object] = sentinel, + scheme: str = "", host: str = "127.0.0.1", port: Optional[int] = None, **kwargs: Any, @@ -239,7 +232,7 @@ def __init__( self, handler: _RequestHandler, *, - scheme: Union[str, object] = sentinel, + scheme: str = "", host: str = "127.0.0.1", port: Optional[int] = None, **kwargs: Any, @@ -324,45 +317,101 @@ async def _request( self._responses.append(resp) return resp - def request( - self, method: str, path: StrOrURL, **kwargs: Any - ) -> _RequestContextManager: - """Routes a request to tested http server. + if sys.version_info >= (3, 11) and TYPE_CHECKING: + + def request( + self, method: str, path: StrOrURL, **kwargs: Unpack[_RequestOptions] + ) -> _RequestContextManager: ... + + def get( + self, + path: StrOrURL, + **kwargs: Unpack[_RequestOptions], + ) -> _RequestContextManager: ... + + def options( + self, + path: StrOrURL, + **kwargs: Unpack[_RequestOptions], + ) -> _RequestContextManager: ... 
+ + def head( + self, + path: StrOrURL, + **kwargs: Unpack[_RequestOptions], + ) -> _RequestContextManager: ... + + def post( + self, + path: StrOrURL, + **kwargs: Unpack[_RequestOptions], + ) -> _RequestContextManager: ... + + def put( + self, + path: StrOrURL, + **kwargs: Unpack[_RequestOptions], + ) -> _RequestContextManager: ... + + def patch( + self, + path: StrOrURL, + **kwargs: Unpack[_RequestOptions], + ) -> _RequestContextManager: ... + + def delete( + self, + path: StrOrURL, + **kwargs: Unpack[_RequestOptions], + ) -> _RequestContextManager: ... - The interface is identical to aiohttp.ClientSession.request, - except the loop kwarg is overridden by the instance used by the - test server. + else: - """ - return _RequestContextManager(self._request(method, path, **kwargs)) + def request( + self, method: str, path: StrOrURL, **kwargs: Any + ) -> _RequestContextManager: + """Routes a request to tested http server. - def get(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager: - """Perform an HTTP GET request.""" - return _RequestContextManager(self._request(hdrs.METH_GET, path, **kwargs)) + The interface is identical to aiohttp.ClientSession.request, + except the loop kwarg is overridden by the instance used by the + test server. 
- def post(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager: - """Perform an HTTP POST request.""" - return _RequestContextManager(self._request(hdrs.METH_POST, path, **kwargs)) + """ + return _RequestContextManager(self._request(method, path, **kwargs)) - def options(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager: - """Perform an HTTP OPTIONS request.""" - return _RequestContextManager(self._request(hdrs.METH_OPTIONS, path, **kwargs)) + def get(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager: + """Perform an HTTP GET request.""" + return _RequestContextManager(self._request(hdrs.METH_GET, path, **kwargs)) - def head(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager: - """Perform an HTTP HEAD request.""" - return _RequestContextManager(self._request(hdrs.METH_HEAD, path, **kwargs)) + def post(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager: + """Perform an HTTP POST request.""" + return _RequestContextManager(self._request(hdrs.METH_POST, path, **kwargs)) - def put(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager: - """Perform an HTTP PUT request.""" - return _RequestContextManager(self._request(hdrs.METH_PUT, path, **kwargs)) + def options(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager: + """Perform an HTTP OPTIONS request.""" + return _RequestContextManager( + self._request(hdrs.METH_OPTIONS, path, **kwargs) + ) + + def head(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager: + """Perform an HTTP HEAD request.""" + return _RequestContextManager(self._request(hdrs.METH_HEAD, path, **kwargs)) - def patch(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager: - """Perform an HTTP PATCH request.""" - return _RequestContextManager(self._request(hdrs.METH_PATCH, path, **kwargs)) + def put(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager: + """Perform an HTTP PUT request.""" + return _RequestContextManager(self._request(hdrs.METH_PUT, 
path, **kwargs)) - def delete(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager: - """Perform an HTTP PATCH request.""" - return _RequestContextManager(self._request(hdrs.METH_DELETE, path, **kwargs)) + def patch(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager: + """Perform an HTTP PATCH request.""" + return _RequestContextManager( + self._request(hdrs.METH_PATCH, path, **kwargs) + ) + + def delete(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager: + """Perform an HTTP PATCH request.""" + return _RequestContextManager( + self._request(hdrs.METH_DELETE, path, **kwargs) + ) def ws_connect(self, path: StrOrURL, **kwargs: Any) -> _WSRequestContextManager: """Initiate websocket connection. diff --git a/aiohttp/tracing.py b/aiohttp/tracing.py index 62847a0bf7c..012ed7bdaf6 100644 --- a/aiohttp/tracing.py +++ b/aiohttp/tracing.py @@ -1,5 +1,5 @@ from types import SimpleNamespace -from typing import TYPE_CHECKING, Awaitable, Optional, Protocol, Type, TypeVar +from typing import TYPE_CHECKING, Awaitable, Mapping, Optional, Protocol, Type, TypeVar import attr from aiosignal import Signal @@ -19,8 +19,7 @@ def __call__( __client_session: ClientSession, __trace_config_ctx: SimpleNamespace, __params: _ParamT_contra, - ) -> Awaitable[None]: - ... + ) -> Awaitable[None]: ... 
__all__ = ( @@ -50,9 +49,9 @@ class TraceConfig: def __init__( self, trace_config_ctx_factory: Type[SimpleNamespace] = SimpleNamespace ) -> None: - self._on_request_start: Signal[ - _SignalCallback[TraceRequestStartParams] - ] = Signal(self) + self._on_request_start: Signal[_SignalCallback[TraceRequestStartParams]] = ( + Signal(self) + ) self._on_request_chunk_sent: Signal[ _SignalCallback[TraceRequestChunkSentParams] ] = Signal(self) @@ -89,12 +88,12 @@ def __init__( self._on_dns_resolvehost_end: Signal[ _SignalCallback[TraceDnsResolveHostEndParams] ] = Signal(self) - self._on_dns_cache_hit: Signal[ - _SignalCallback[TraceDnsCacheHitParams] - ] = Signal(self) - self._on_dns_cache_miss: Signal[ - _SignalCallback[TraceDnsCacheMissParams] - ] = Signal(self) + self._on_dns_cache_hit: Signal[_SignalCallback[TraceDnsCacheHitParams]] = ( + Signal(self) + ) + self._on_dns_cache_miss: Signal[_SignalCallback[TraceDnsCacheMissParams]] = ( + Signal(self) + ) self._on_request_headers_sent: Signal[ _SignalCallback[TraceRequestHeadersSentParams] ] = Signal(self) @@ -102,7 +101,7 @@ def __init__( self._trace_config_ctx_factory = trace_config_ctx_factory def trace_config_ctx( - self, trace_request_ctx: Optional[SimpleNamespace] = None + self, trace_request_ctx: Optional[Mapping[str, str]] = None ) -> SimpleNamespace: """Return a new trace_config_ctx instance""" return self._trace_config_ctx_factory(trace_request_ctx=trace_request_ctx) diff --git a/aiohttp/typedefs.py b/aiohttp/typedefs.py index 5e963e1a10e..9fb21c15f83 100644 --- a/aiohttp/typedefs.py +++ b/aiohttp/typedefs.py @@ -7,6 +7,7 @@ Callable, Iterable, Mapping, + Protocol, Tuple, Union, ) @@ -34,7 +35,13 @@ Byteish = Union[bytes, bytearray, memoryview] JSONEncoder = Callable[[Any], str] JSONDecoder = Callable[[str], Any] -LooseHeaders = Union[Mapping[Union[str, istr], str], _CIMultiDict, _CIMultiDictProxy] +LooseHeaders = Union[ + Mapping[str, str], + Mapping[istr, str], + _CIMultiDict, + _CIMultiDictProxy, + 
Iterable[Tuple[Union[str, istr], str]], +] RawHeaders = Tuple[Tuple[bytes, bytes], ...] StrOrURL = Union[str, URL] @@ -49,6 +56,12 @@ ] Handler = Callable[["Request"], Awaitable["StreamResponse"]] -Middleware = Callable[["Request", Handler], Awaitable["StreamResponse"]] + + +class Middleware(Protocol): + def __call__( + self, request: "Request", handler: Handler + ) -> Awaitable["StreamResponse"]: ... + PathLike = Union[str, "os.PathLike[str]"] diff --git a/aiohttp/web.py b/aiohttp/web.py index e9116507f4e..8708f1fcbec 100644 --- a/aiohttp/web.py +++ b/aiohttp/web.py @@ -6,8 +6,6 @@ import warnings from argparse import ArgumentParser from collections.abc import Iterable -from contextlib import suppress -from functools import partial from importlib import import_module from typing import ( Any, @@ -21,7 +19,6 @@ Union, cast, ) -from weakref import WeakSet from .abc import AbstractAccessLogger from .helpers import AppKey as AppKey @@ -320,23 +317,6 @@ async def _run_app( reuse_port: Optional[bool] = None, handler_cancellation: bool = False, ) -> None: - async def wait( - starting_tasks: "WeakSet[asyncio.Task[object]]", shutdown_timeout: float - ) -> None: - # Wait for pending tasks for a given time limit. - t = asyncio.current_task() - assert t is not None - starting_tasks.add(t) - with suppress(asyncio.TimeoutError): - await asyncio.wait_for(_wait(starting_tasks), timeout=shutdown_timeout) - - async def _wait(exclude: "WeakSet[asyncio.Task[object]]") -> None: - t = asyncio.current_task() - assert t is not None - exclude.add(t) - while tasks := asyncio.all_tasks().difference(exclude): - await asyncio.wait(tasks) - # An internal function to actually do all dirty job for application running if asyncio.iscoroutine(app): app = await app @@ -355,12 +335,6 @@ async def _wait(exclude: "WeakSet[asyncio.Task[object]]") -> None: ) await runner.setup() - # On shutdown we want to avoid waiting on tasks which run forever. 
- # It's very likely that all tasks which run forever will have been created by - # the time we have completed the application startup (in runner.setup()), - # so we just record all running tasks here and exclude them later. - starting_tasks: "WeakSet[asyncio.Task[object]]" = WeakSet(asyncio.all_tasks()) - runner.shutdown_callback = partial(wait, starting_tasks, shutdown_timeout) sites: List[BaseSite] = [] diff --git a/aiohttp/web_app.py b/aiohttp/web_app.py index 91bf5fdac61..3b4b6489e60 100644 --- a/aiohttp/web_app.py +++ b/aiohttp/web_app.py @@ -76,6 +76,7 @@ _T = TypeVar("_T") _U = TypeVar("_U") +_Resource = TypeVar("_Resource", bound=AbstractResource) class Application(MutableMapping[Union[str, AppKey[Any]], Any]): @@ -183,12 +184,10 @@ def __eq__(self, other: object) -> bool: return self is other @overload # type: ignore[override] - def __getitem__(self, key: AppKey[_T]) -> _T: - ... + def __getitem__(self, key: AppKey[_T]) -> _T: ... @overload - def __getitem__(self, key: str) -> Any: - ... + def __getitem__(self, key: str) -> Any: ... def __getitem__(self, key: Union[str, AppKey[_T]]) -> Any: return self._state[key] @@ -202,12 +201,10 @@ def _check_frozen(self) -> None: ) @overload # type: ignore[override] - def __setitem__(self, key: AppKey[_T], value: _T) -> None: - ... + def __setitem__(self, key: AppKey[_T], value: _T) -> None: ... @overload - def __setitem__(self, key: str, value: Any) -> None: - ... + def __setitem__(self, key: str, value: Any) -> None: ... def __setitem__(self, key: Union[str, AppKey[_T]], value: Any) -> None: self._check_frozen() @@ -232,16 +229,13 @@ def __iter__(self) -> Iterator[Union[str, AppKey[Any]]]: return iter(self._state) @overload # type: ignore[override] - def get(self, key: AppKey[_T], default: None = ...) -> Optional[_T]: - ... + def get(self, key: AppKey[_T], default: None = ...) -> Optional[_T]: ... @overload - def get(self, key: AppKey[_T], default: _U) -> Union[_T, _U]: - ... 
+ def get(self, key: AppKey[_T], default: _U) -> Union[_T, _U]: ... @overload - def get(self, key: str, default: Any = ...) -> Any: - ... + def get(self, key: str, default: Any = ...) -> Any: ... def get(self, key: Union[str, AppKey[_T]], default: Any = None) -> Any: return self._state.get(key, default) @@ -334,7 +328,7 @@ async def handler(app: "Application") -> None: reg_handler("on_shutdown") reg_handler("on_cleanup") - def add_subapp(self, prefix: str, subapp: "Application") -> AbstractResource: + def add_subapp(self, prefix: str, subapp: "Application") -> PrefixedSubAppResource: if not isinstance(prefix, str): raise TypeError("Prefix must be str") prefix = prefix.rstrip("/") @@ -344,8 +338,8 @@ def add_subapp(self, prefix: str, subapp: "Application") -> AbstractResource: return self._add_subapp(factory, subapp) def _add_subapp( - self, resource_factory: Callable[[], AbstractResource], subapp: "Application" - ) -> AbstractResource: + self, resource_factory: Callable[[], _Resource], subapp: "Application" + ) -> _Resource: if self.frozen: raise RuntimeError("Cannot add sub application to frozen application") if subapp.frozen: @@ -359,7 +353,7 @@ def _add_subapp( subapp._set_loop(self._loop) return resource - def add_domain(self, domain: str, subapp: "Application") -> AbstractResource: + def add_domain(self, domain: str, subapp: "Application") -> MatchedSubAppResource: if not isinstance(domain, str): raise TypeError("Domain must be str") elif "*" in domain: @@ -535,7 +529,7 @@ async def _handle(self, request: Request) -> StreamResponse: for m, new_style in app._middlewares_handlers: # type: ignore[union-attr] if new_style: handler = update_wrapper( - partial(m, handler=handler), handler + partial(m, handler=handler), handler # type: ignore[misc] ) else: handler = await m(app, handler) # type: ignore[arg-type,assignment] diff --git a/aiohttp/web_fileresponse.py b/aiohttp/web_fileresponse.py index 7dbe50f0a5a..0c23e375d25 100644 --- a/aiohttp/web_fileresponse.py +++ 
b/aiohttp/web_fileresponse.py @@ -1,7 +1,11 @@ import asyncio -import mimetypes import os import pathlib +import sys +from contextlib import suppress +from mimetypes import MimeTypes +from stat import S_ISREG +from types import MappingProxyType from typing import ( # noqa IO, TYPE_CHECKING, @@ -22,6 +26,8 @@ from .helpers import ETAG_ANY, ETag, must_be_empty_body from .typedefs import LooseHeaders, PathLike from .web_exceptions import ( + HTTPForbidden, + HTTPNotFound, HTTPNotModified, HTTPPartialContent, HTTPPreconditionFailed, @@ -40,6 +46,35 @@ NOSENDFILE: Final[bool] = bool(os.environ.get("AIOHTTP_NOSENDFILE")) +CONTENT_TYPES: Final[MimeTypes] = MimeTypes() + +if sys.version_info < (3, 9): + CONTENT_TYPES.encodings_map[".br"] = "br" + +# File extension to IANA encodings map that will be checked in the order defined. +ENCODING_EXTENSIONS = MappingProxyType( + {ext: CONTENT_TYPES.encodings_map[ext] for ext in (".br", ".gz")} +) + +FALLBACK_CONTENT_TYPE = "application/octet-stream" + +# Provide additional MIME type/extension pairs to be recognized. +# https://en.wikipedia.org/wiki/List_of_archive_formats#Compression_only +ADDITIONAL_CONTENT_TYPES = MappingProxyType( + { + "application/gzip": ".gz", + "application/x-brotli": ".br", + "application/x-bzip2": ".bz2", + "application/x-compress": ".Z", + "application/x-xz": ".xz", + } +) + +# Add custom pairs and clear the encodings map so guess_type ignores them. +CONTENT_TYPES.encodings_map.clear() +for content_type, extension in ADDITIONAL_CONTENT_TYPES.items(): + CONTENT_TYPES.add_type(content_type, extension) # type: ignore[attr-defined] + class FileResponse(StreamResponse): """A response object can be used to send files.""" @@ -124,35 +159,51 @@ async def _precondition_failed( self.content_length = 0 return await super().prepare(request) - def _get_file_path_stat_and_gzip( - self, check_for_gzipped_file: bool - ) -> Tuple[pathlib.Path, os.stat_result, bool]: - """Return the file path, stat result, and gzip status. 
+ def _get_file_path_stat_encoding( + self, accept_encoding: str + ) -> Tuple[pathlib.Path, os.stat_result, Optional[str]]: + """Return the file path, stat result, and encoding. + + If an uncompressed file is returned, the encoding is set to + :py:data:`None`. This method should be called from a thread executor since it calls os.stat which may block. """ - filepath = self._path - if check_for_gzipped_file: - gzip_path = filepath.with_name(filepath.name + ".gz") - try: - return gzip_path, gzip_path.stat(), True - except OSError: - # Fall through and try the non-gzipped file - pass + file_path = self._path + for file_extension, file_encoding in ENCODING_EXTENSIONS.items(): + if file_encoding not in accept_encoding: + continue + + compressed_path = file_path.with_suffix(file_path.suffix + file_extension) + with suppress(OSError): + # Do not follow symlinks and ignore any non-regular files. + st = compressed_path.lstat() + if S_ISREG(st.st_mode): + return compressed_path, st, file_encoding - return filepath, filepath.stat(), False + # Fallback to the uncompressed file + return file_path, file_path.stat(), None async def prepare(self, request: "BaseRequest") -> Optional[AbstractStreamWriter]: - loop = asyncio.get_event_loop() + loop = asyncio.get_running_loop() # Encoding comparisons should be case-insensitive # https://www.rfc-editor.org/rfc/rfc9110#section-8.4.1 - check_for_gzipped_file = ( - "gzip" in request.headers.get(hdrs.ACCEPT_ENCODING, "").lower() - ) - filepath, st, gzip = await loop.run_in_executor( - None, self._get_file_path_stat_and_gzip, check_for_gzipped_file - ) + accept_encoding = request.headers.get(hdrs.ACCEPT_ENCODING, "").lower() + try: + file_path, st, file_encoding = await loop.run_in_executor( + None, self._get_file_path_stat_encoding, accept_encoding + ) + except OSError: + # Most likely to be FileNotFoundError or OSError for circular + # symlinks in python >= 3.13, so respond with 404. 
+ self.set_status(HTTPNotFound.status_code) + return await super().prepare(request) + + # Forbid special files like sockets, pipes, devices, etc. + if not S_ISREG(st.st_mode): + self.set_status(HTTPForbidden.status_code) + return await super().prepare(request) etag_value = f"{st.st_mtime_ns:x}-{st.st_size:x}" last_modified = st.st_mtime @@ -182,15 +233,6 @@ async def prepare(self, request: "BaseRequest") -> Optional[AbstractStreamWriter ): return await self._not_modified(request, etag_value, last_modified) - if hdrs.CONTENT_TYPE not in self.headers: - ct, encoding = mimetypes.guess_type(str(filepath)) - if not ct: - ct = "application/octet-stream" - should_set_ct = True - else: - encoding = "gzip" if gzip else None - should_set_ct = False - status = self._status file_size = st.st_size count = file_size @@ -265,11 +307,16 @@ async def prepare(self, request: "BaseRequest") -> Optional[AbstractStreamWriter # return a HTTP 206 for a Range request. self.set_status(status) - if should_set_ct: - self.content_type = ct # type: ignore[assignment] - if encoding: - self.headers[hdrs.CONTENT_ENCODING] = encoding - if gzip: + # If the Content-Type header is not already set, guess it based on the + # extension of the request path. The encoding returned by guess_type + # can be ignored since the map was cleared above. 
+ if hdrs.CONTENT_TYPE not in self.headers: + self.content_type = ( + CONTENT_TYPES.guess_type(self._path)[0] or FALLBACK_CONTENT_TYPE + ) + + if file_encoding: + self.headers[hdrs.CONTENT_ENCODING] = file_encoding self.headers[hdrs.VARY] = hdrs.ACCEPT_ENCODING # Disable compression if we are already sending # a compressed file since we don't want to double @@ -293,7 +340,12 @@ async def prepare(self, request: "BaseRequest") -> Optional[AbstractStreamWriter if count == 0 or must_be_empty_body(request.method, self.status): return await super().prepare(request) - fobj = await loop.run_in_executor(None, filepath.open, "rb") + try: + fobj = await loop.run_in_executor(None, file_path.open, "rb") + except PermissionError: + self.set_status(HTTPForbidden.status_code) + return await super().prepare(request) + if start: # be aware that start could be None or int=0 here. offset = start else: diff --git a/aiohttp/web_protocol.py b/aiohttp/web_protocol.py index f083b13eb0f..9ba05a08e75 100644 --- a/aiohttp/web_protocol.py +++ b/aiohttp/web_protocol.py @@ -148,6 +148,7 @@ class RequestHandler(BaseProtocol): "_lingering_time", "_messages", "_message_tail", + "_handler_waiter", "_waiter", "_task_handler", "_upgrade", @@ -204,6 +205,7 @@ def __init__( self._message_tail = b"" self._waiter: Optional[asyncio.Future[None]] = None + self._handler_waiter: Optional[asyncio.Future[None]] = None self._task_handler: Optional[asyncio.Task[None]] = None self._upgrade = False @@ -262,7 +264,12 @@ async def shutdown(self, timeout: Optional[float] = 15.0) -> None: if self._waiter: self._waiter.cancel() - # wait for handlers + # Wait for graceful handler completion + if self._handler_waiter is not None: + with suppress(asyncio.CancelledError, asyncio.TimeoutError): + async with ceil_timeout(timeout): + await self._handler_waiter + # Then cancel handler and wait with suppress(asyncio.CancelledError, asyncio.TimeoutError): async with ceil_timeout(timeout): if self._current_request is not None: @@ 
-445,7 +452,7 @@ async def _handle_request( start_time: float, request_handler: Callable[[BaseRequest], Awaitable[StreamResponse]], ) -> Tuple[StreamResponse, bool]: - assert self._request_handler is not None + self._handler_waiter = self._loop.create_future() try: try: self._current_request = request @@ -475,6 +482,8 @@ async def _handle_request( ) reset = await self.finish_response(request, resp, start_time) + finally: + self._handler_waiter.set_result(None) return resp, reset @@ -609,6 +618,7 @@ async def finish_response( can get exception information. Returns True if the client disconnects prematurely. """ + request._finish() if self._request_parser is not None: self._request_parser.set_upgraded(False) self._upgrade = False diff --git a/aiohttp/web_request.py b/aiohttp/web_request.py index 4bc670a798c..a485f0dcea6 100644 --- a/aiohttp/web_request.py +++ b/aiohttp/web_request.py @@ -99,10 +99,10 @@ class FileField: qdtext=_QDTEXT, quoted_pair=_QUOTED_PAIR ) -_FORWARDED_PAIR: Final[ - str -] = r"({token})=({token}|{quoted_string})(:\d{{1,4}})?".format( - token=_TOKEN, quoted_string=_QUOTED_STRING +_FORWARDED_PAIR: Final[str] = ( + r"({token})=({token}|{quoted_string})(:\d{{1,4}})?".format( + token=_TOKEN, quoted_string=_QUOTED_STRING + ) ) _QUOTED_PAIR_REPLACE_RE: Final[Pattern[str]] = re.compile(r"\\([\t !-~])") @@ -235,7 +235,8 @@ def clone( # a copy semantic dct["headers"] = CIMultiDictProxy(CIMultiDict(headers)) dct["raw_headers"] = tuple( - (k.encode("utf-8"), v.encode("utf-8")) for k, v in headers.items() + (k.encode("utf-8"), v.encode("utf-8")) + for k, v in dct["headers"].items() ) message = self._message._replace(**dct) @@ -819,6 +820,18 @@ async def _prepare_hook(self, response: StreamResponse) -> None: def _cancel(self, exc: BaseException) -> None: set_exception(self._payload, exc) + def _finish(self) -> None: + if self._post is None or self.content_type != "multipart/form-data": + return + + # NOTE: Release file descriptors for the + # NOTE: 
`tempfile.Temporaryfile`-created `_io.BufferedRandom` + # NOTE: instances of files sent within multipart request body + # NOTE: via HTTP POST request. + for file_name, file_field_object in self._post.items(): + if isinstance(file_field_object, FileField): + file_field_object.file.close() + class Request(BaseRequest): diff --git a/aiohttp/web_response.py b/aiohttp/web_response.py index 40d6f01ecaa..78d3fe32949 100644 --- a/aiohttp/web_response.py +++ b/aiohttp/web_response.py @@ -52,6 +52,7 @@ BaseClass = collections.abc.MutableMapping +# TODO(py311): Convert to StrEnum for wider use class ContentCoding(enum.Enum): # The content codings that we have support for. # @@ -175,7 +176,7 @@ def enable_compression( ) -> None: """Enables response compression encoding.""" # Backwards compatibility for when force was a bool <0.17. - if type(force) == bool: + if isinstance(force, bool): force = ContentCoding.deflate if force else ContentCoding.identity warnings.warn( "Using boolean for force is deprecated #3318", DeprecationWarning @@ -673,7 +674,7 @@ def body(self, body: bytes) -> None: # copy payload headers if body.headers: - for (key, value) in body.headers.items(): + for key, value in body.headers.items(): if key not in headers: headers[key] = value diff --git a/aiohttp/web_routedef.py b/aiohttp/web_routedef.py index d79cd32a14a..93802141c56 100644 --- a/aiohttp/web_routedef.py +++ b/aiohttp/web_routedef.py @@ -162,12 +162,10 @@ def __repr__(self) -> str: return f"" @overload - def __getitem__(self, index: int) -> AbstractRouteDef: - ... + def __getitem__(self, index: int) -> AbstractRouteDef: ... @overload - def __getitem__(self, index: slice) -> List[AbstractRouteDef]: - ... + def __getitem__(self, index: slice) -> List[AbstractRouteDef]: ... 
def __getitem__(self, index): # type: ignore[no-untyped-def] return self._items[index] diff --git a/aiohttp/web_runner.py b/aiohttp/web_runner.py index 19a4441658f..2fe229c4e50 100644 --- a/aiohttp/web_runner.py +++ b/aiohttp/web_runner.py @@ -3,7 +3,7 @@ import socket import warnings from abc import ABC, abstractmethod -from typing import Any, Awaitable, Callable, List, Optional, Set +from typing import Any, List, Optional, Set from yarl import URL @@ -238,14 +238,7 @@ async def start(self) -> None: class BaseRunner(ABC): - __slots__ = ( - "shutdown_callback", - "_handle_signals", - "_kwargs", - "_server", - "_sites", - "_shutdown_timeout", - ) + __slots__ = ("_handle_signals", "_kwargs", "_server", "_sites", "_shutdown_timeout") def __init__( self, @@ -254,7 +247,6 @@ def __init__( shutdown_timeout: float = 60.0, **kwargs: Any, ) -> None: - self.shutdown_callback: Optional[Callable[[], Awaitable[None]]] = None self._handle_signals = handle_signals self._kwargs = kwargs self._server: Optional[Server] = None @@ -312,10 +304,6 @@ async def cleanup(self) -> None: await asyncio.sleep(0) self._server.pre_shutdown() await self.shutdown() - - if self.shutdown_callback: - await self.shutdown_callback() - await self._server.shutdown(self._shutdown_timeout) await self._cleanup_server() diff --git a/aiohttp/web_server.py b/aiohttp/web_server.py index 52faacb164a..ffc198d5780 100644 --- a/aiohttp/web_server.py +++ b/aiohttp/web_server.py @@ -1,9 +1,9 @@ """Low level HTTP server.""" + import asyncio from typing import Any, Awaitable, Callable, Dict, List, Optional # noqa from .abc import AbstractStreamWriter -from .helpers import get_running_loop from .http_parser import RawRequestMessage from .streams import StreamReader from .web_protocol import RequestHandler, _RequestFactory, _RequestHandler @@ -22,7 +22,7 @@ def __init__( loop: Optional[asyncio.AbstractEventLoop] = None, **kwargs: Any ) -> None: - self._loop = get_running_loop(loop) + self._loop = loop or 
asyncio.get_running_loop() self._connections: Dict[RequestHandler, asyncio.Transport] = {} self._kwargs = kwargs self.requests_count = 0 @@ -43,7 +43,12 @@ def connection_lost( self, handler: RequestHandler, exc: Optional[BaseException] = None ) -> None: if handler in self._connections: - del self._connections[handler] + if handler._task_handler: + handler._task_handler.add_done_callback( + lambda f: self._connections.pop(handler, None) + ) + else: + del self._connections[handler] def _make_request( self, diff --git a/aiohttp/web_urldispatcher.py b/aiohttp/web_urldispatcher.py index 954291f6449..558fb7d0c9b 100644 --- a/aiohttp/web_urldispatcher.py +++ b/aiohttp/web_urldispatcher.py @@ -8,6 +8,7 @@ import keyword import os import re +import sys import warnings from contextlib import contextmanager from functools import wraps @@ -78,6 +79,12 @@ else: BaseDict = dict +CIRCULAR_SYMLINK_ERROR = ( + (OSError,) + if sys.version_info < (3, 10) and sys.platform.startswith("win32") + else (RuntimeError,) if sys.version_info < (3, 13) else () +) + YARL_VERSION: Final[Tuple[int, ...]] = tuple(map(int, yarl_version.split(".")[:2])) HTTP_METHOD_RE: Final[Pattern[str]] = re.compile( @@ -199,7 +206,7 @@ def __init__( @wraps(handler) async def handler_wrapper(request: Request) -> StreamResponse: - result = old_handler(request) + result = old_handler(request) # type: ignore[call-arg] if asyncio.iscoroutine(result): result = await result assert isinstance(result, StreamResponse) @@ -557,14 +564,11 @@ def __init__( ) -> None: super().__init__(prefix, name=name) try: - directory = Path(directory) - if str(directory).startswith("~"): - directory = Path(os.path.expanduser(str(directory))) - directory = directory.resolve() - if not directory.is_dir(): - raise ValueError("Not a directory") - except (FileNotFoundError, ValueError) as error: - raise ValueError(f"No directory exists at '{directory}'") from error + directory = Path(directory).expanduser().resolve(strict=True) + except 
FileNotFoundError as error: + raise ValueError(f"'{directory}' does not exist") from error + if not directory.is_dir(): + raise ValueError(f"'{directory}' is not a directory") self._directory = directory self._show_index = show_index self._chunk_size = chunk_size @@ -664,59 +668,64 @@ def __iter__(self) -> Iterator[AbstractRoute]: async def _handle(self, request: Request) -> StreamResponse: rel_url = request.match_info["filename"] + filename = Path(rel_url) + if filename.anchor: + # rel_url is an absolute name like + # /static/\\machine_name\c$ or /static/D:\path + # where the static dir is totally different + raise HTTPForbidden() + + unresolved_path = self._directory.joinpath(filename) + loop = asyncio.get_running_loop() + return await loop.run_in_executor( + None, self._resolve_path_to_response, unresolved_path + ) + + def _resolve_path_to_response(self, unresolved_path: Path) -> StreamResponse: + """Take the unresolved path and query the file system to form a response.""" + # Check for access outside the root directory. For follow symlinks, URI + # cannot traverse out, but symlinks can. Otherwise, no access outside + # root is permitted. try: - filename = Path(rel_url) - if filename.anchor: - # rel_url is an absolute name like - # /static/\\machine_name\c$ or /static/D:\path - # where the static dir is totally different - raise HTTPForbidden() - unresolved_path = self._directory.joinpath(filename) if self._follow_symlinks: normalized_path = Path(os.path.normpath(unresolved_path)) normalized_path.relative_to(self._directory) - filepath = normalized_path.resolve() + file_path = normalized_path.resolve() else: - filepath = unresolved_path.resolve() - filepath.relative_to(self._directory) - except (ValueError, FileNotFoundError) as error: - # relatively safe - raise HTTPNotFound() from error - except HTTPForbidden: - raise - except Exception as error: - # perm error or other kind! 
- request.app.logger.exception(error) + file_path = unresolved_path.resolve() + file_path.relative_to(self._directory) + except (ValueError, *CIRCULAR_SYMLINK_ERROR) as error: + # ValueError is raised for the relative check. Circular symlinks + # raise here on resolving for python < 3.13. raise HTTPNotFound() from error - # on opening a dir, load its contents if allowed - if filepath.is_dir(): - if self._show_index: - try: + # if path is a directory, return the contents if permitted. Note the + # directory check will raise if a segment is not readable. + try: + if file_path.is_dir(): + if self._show_index: return Response( - text=self._directory_as_html(filepath), content_type="text/html" + text=self._directory_as_html(file_path), + content_type="text/html", ) - except PermissionError: + else: raise HTTPForbidden() - else: - raise HTTPForbidden() - elif filepath.is_file(): - return FileResponse(filepath, chunk_size=self._chunk_size) - else: - raise HTTPNotFound + except PermissionError as error: + raise HTTPForbidden() from error - def _directory_as_html(self, filepath: Path) -> str: - # returns directory's index as html + # Return the file response, which handles all other checks. + return FileResponse(file_path, chunk_size=self._chunk_size) - # sanity check - assert filepath.is_dir() + def _directory_as_html(self, dir_path: Path) -> str: + """returns directory's index as html.""" + assert dir_path.is_dir() - relative_path_to_dir = filepath.relative_to(self._directory).as_posix() + relative_path_to_dir = dir_path.relative_to(self._directory).as_posix() index_of = f"Index of /{html_escape(relative_path_to_dir)}" h1 = f"

{index_of}

" index_list = [] - dir_index = filepath.iterdir() + dir_index = dir_path.iterdir() for _file in sorted(dir_index): # show file url as relative to static path rel_path = _file.relative_to(self._directory).as_posix() @@ -750,13 +759,20 @@ class PrefixedSubAppResource(PrefixResource): def __init__(self, prefix: str, app: "Application") -> None: super().__init__(prefix) self._app = app - for resource in app.router.resources(): - resource.add_prefix(prefix) + self._add_prefix_to_resources(prefix) def add_prefix(self, prefix: str) -> None: super().add_prefix(prefix) - for resource in self._app.router.resources(): + self._add_prefix_to_resources(prefix) + + def _add_prefix_to_resources(self, prefix: str) -> None: + router = self._app.router + for resource in router.resources(): + # Since the canonical path of a resource is about + # to change, we need to unindex it and then reindex + router.unindex_resource(resource) resource.add_prefix(prefix) + router.index_resource(resource) def url_for(self, *args: str, **kwargs: str) -> URL: raise RuntimeError(".url_for() is not supported " "by sub-application root") @@ -765,11 +781,6 @@ def get_info(self) -> _InfoDict: return {"app": self._app, "prefix": self._prefix} async def resolve(self, request: Request) -> _Resolve: - if ( - not request.url.raw_path.startswith(self._prefix2) - and request.url.raw_path != self._prefix - ): - return None, set() match_info = await self._app.router.resolve(request) match_info.add_app(self._app) if isinstance(match_info.http_exception, HTTPMethodNotAllowed): @@ -1015,12 +1026,39 @@ def __init__(self) -> None: super().__init__() self._resources: List[AbstractResource] = [] self._named_resources: Dict[str, AbstractResource] = {} + self._resource_index: dict[str, list[AbstractResource]] = {} + self._matched_sub_app_resources: List[MatchedSubAppResource] = [] async def resolve(self, request: Request) -> UrlMappingMatchInfo: - method = request.method + resource_index = self._resource_index 
allowed_methods: Set[str] = set() - for resource in self._resources: + # Walk the url parts looking for candidates. We walk the url backwards + # to ensure the most explicit match is found first. If there are multiple + # candidates for a given url part because there are multiple resources + # registered for the same canonical path, we resolve them in a linear + # fashion to ensure registration order is respected. + url_part = request.rel_url.raw_path + while url_part: + for candidate in resource_index.get(url_part, ()): + match_dict, allowed = await candidate.resolve(request) + if match_dict is not None: + return match_dict + else: + allowed_methods |= allowed + if url_part == "/": + break + url_part = url_part.rpartition("/")[0] or "/" + + # + # We didn't find any candidates, so we'll try the matched sub-app + # resources which we have to walk in a linear fashion because they + # have regex/wildcard match rules and we cannot index them. + # + # For most cases we do not expect there to be many of these since + # currently they are only added by `add_domain` + # + for resource in self._matched_sub_app_resources: match_dict, allowed = await resource.resolve(request) if match_dict is not None: return match_dict @@ -1028,9 +1066,9 @@ async def resolve(self, request: Request) -> UrlMappingMatchInfo: allowed_methods |= allowed if allowed_methods: - return MatchInfoError(HTTPMethodNotAllowed(method, allowed_methods)) - else: - return MatchInfoError(HTTPNotFound()) + return MatchInfoError(HTTPMethodNotAllowed(request.method, allowed_methods)) + + return MatchInfoError(HTTPNotFound()) def __iter__(self) -> Iterator[str]: return iter(self._named_resources) @@ -1086,6 +1124,36 @@ def register_resource(self, resource: AbstractResource) -> None: self._named_resources[name] = resource self._resources.append(resource) + if isinstance(resource, MatchedSubAppResource): + # We cannot index match sub-app resources because they have match rules + 
self._matched_sub_app_resources.append(resource) + else: + self.index_resource(resource) + + def _get_resource_index_key(self, resource: AbstractResource) -> str: + """Return a key to index the resource in the resource index.""" + if "{" in (index_key := resource.canonical): + # strip at the first { to allow for variables, and than + # rpartition at / to allow for variable parts in the path + # For example if the canonical path is `/core/locations{tail:.*}` + # the index key will be `/core` since index is based on the + # url parts split by `/` + index_key = index_key.partition("{")[0].rpartition("/")[0] + return index_key.rstrip("/") or "/" + + def index_resource(self, resource: AbstractResource) -> None: + """Add a resource to the resource index.""" + resource_key = self._get_resource_index_key(resource) + # There may be multiple resources for a canonical path + # so we keep them in a list to ensure that registration + # order is respected. + self._resource_index.setdefault(resource_key, []).append(resource) + + def unindex_resource(self, resource: AbstractResource) -> None: + """Remove a resource from the resource index.""" + resource_key = self._get_resource_index_key(resource) + self._resource_index[resource_key].remove(resource) + def add_resource(self, path: str, *, name: Optional[str] = None) -> Resource: if path and not path.startswith("/"): raise ValueError("path should be started with / or be empty") diff --git a/aiohttp/web_ws.py b/aiohttp/web_ws.py index 9fe66527539..ba3332715a6 100644 --- a/aiohttp/web_ws.py +++ b/aiohttp/web_ws.py @@ -11,7 +11,7 @@ from . 
import hdrs from .abc import AbstractStreamWriter -from .helpers import call_later, set_exception, set_result +from .helpers import calculate_timeout_when, set_exception, set_result from .http import ( WS_CLOSED_MESSAGE, WS_CLOSING_MESSAGE, @@ -81,68 +81,108 @@ def __init__( self._conn_lost = 0 self._close_code: Optional[int] = None self._loop: Optional[asyncio.AbstractEventLoop] = None - self._waiting: Optional[asyncio.Future[bool]] = None + self._waiting: bool = False + self._close_wait: Optional[asyncio.Future[None]] = None self._exception: Optional[BaseException] = None self._timeout = timeout self._receive_timeout = receive_timeout self._autoclose = autoclose self._autoping = autoping self._heartbeat = heartbeat + self._heartbeat_when = 0.0 self._heartbeat_cb: Optional[asyncio.TimerHandle] = None if heartbeat is not None: self._pong_heartbeat = heartbeat / 2.0 self._pong_response_cb: Optional[asyncio.TimerHandle] = None self._compress = compress self._max_msg_size = max_msg_size + self._ping_task: Optional[asyncio.Task[None]] = None def _cancel_heartbeat(self) -> None: - if self._pong_response_cb is not None: - self._pong_response_cb.cancel() - self._pong_response_cb = None - + self._cancel_pong_response_cb() if self._heartbeat_cb is not None: self._heartbeat_cb.cancel() self._heartbeat_cb = None + if self._ping_task is not None: + self._ping_task.cancel() + self._ping_task = None - def _reset_heartbeat(self) -> None: - self._cancel_heartbeat() + def _cancel_pong_response_cb(self) -> None: + if self._pong_response_cb is not None: + self._pong_response_cb.cancel() + self._pong_response_cb = None - if self._heartbeat is not None: - assert self._loop is not None - self._heartbeat_cb = call_later( - self._send_heartbeat, - self._heartbeat, - self._loop, - timeout_ceil_threshold=self._req._protocol._timeout_ceil_threshold - if self._req is not None - else 5, - ) + def _reset_heartbeat(self) -> None: + if self._heartbeat is None: + return + 
self._cancel_pong_response_cb() + req = self._req + timeout_ceil_threshold = ( + req._protocol._timeout_ceil_threshold if req is not None else 5 + ) + loop = self._loop + assert loop is not None + now = loop.time() + when = calculate_timeout_when(now, self._heartbeat, timeout_ceil_threshold) + self._heartbeat_when = when + if self._heartbeat_cb is None: + # We do not cancel the previous heartbeat_cb here because + # it generates a significant amount of TimerHandle churn + # which causes asyncio to rebuild the heap frequently. + # Instead _send_heartbeat() will reschedule the next + # heartbeat if it fires too early. + self._heartbeat_cb = loop.call_at(when, self._send_heartbeat) def _send_heartbeat(self) -> None: - if self._heartbeat is not None and not self._closed: - assert self._loop is not None - # fire-and-forget a task is not perfect but maybe ok for - # sending ping. Otherwise we need a long-living heartbeat - # task in the class. - self._loop.create_task(self._writer.ping()) # type: ignore[union-attr] - - if self._pong_response_cb is not None: - self._pong_response_cb.cancel() - self._pong_response_cb = call_later( - self._pong_not_received, - self._pong_heartbeat, - self._loop, - timeout_ceil_threshold=self._req._protocol._timeout_ceil_threshold - if self._req is not None - else 5, + self._heartbeat_cb = None + loop = self._loop + assert loop is not None and self._writer is not None + now = loop.time() + if now < self._heartbeat_when: + # Heartbeat fired too early, reschedule + self._heartbeat_cb = loop.call_at( + self._heartbeat_when, self._send_heartbeat ) + return + + req = self._req + timeout_ceil_threshold = ( + req._protocol._timeout_ceil_threshold if req is not None else 5 + ) + when = calculate_timeout_when(now, self._pong_heartbeat, timeout_ceil_threshold) + self._cancel_pong_response_cb() + self._pong_response_cb = loop.call_at(when, self._pong_not_received) + + if sys.version_info >= (3, 12): + # Optimization for Python 3.12, try to send the 
ping + # immediately to avoid having to schedule + # the task on the event loop. + ping_task = asyncio.Task(self._writer.ping(), loop=loop, eager_start=True) + else: + ping_task = loop.create_task(self._writer.ping()) + + if not ping_task.done(): + self._ping_task = ping_task + ping_task.add_done_callback(self._ping_task_done) + + def _ping_task_done(self, task: "asyncio.Task[None]") -> None: + """Callback for when the ping task completes.""" + self._ping_task = None def _pong_not_received(self) -> None: if self._req is not None and self._req.transport is not None: - self._closed = True + self._set_closed() self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE) self._exception = asyncio.TimeoutError() + def _set_closed(self) -> None: + """Set the connection to closed. + + Cancel any heartbeat timers and set the closed flag. + """ + self._closed = True + self._cancel_heartbeat() + async def prepare(self, request: BaseRequest) -> AbstractStreamWriter: # make pre-check to don't hide it by do_handshake() exceptions if self._payload_writer is not None: @@ -372,14 +412,17 @@ async def close( # we need to break `receive()` cycle first, # `close()` may be called from different task - if self._waiting is not None and not self._closed: + if self._waiting and not self._closed: + if not self._close_wait: + assert self._loop is not None + self._close_wait = self._loop.create_future() reader.feed_data(WS_CLOSING_MESSAGE, 0) - await self._waiting + await self._close_wait if self._closed: return False - self._closed = True + self._set_closed() try: await self._writer.close(code, message) writer = self._payload_writer @@ -411,7 +454,7 @@ async def close( self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE) return True - if msg.type == WSMsgType.CLOSE: + if msg.type is WSMsgType.CLOSE: self._set_code_close_transport(msg.data) return True @@ -423,6 +466,7 @@ def _set_closing(self, code: WSCloseCode) -> None: """Set the close code and mark the connection as closing.""" 
self._closing = True self._close_code = code + self._cancel_heartbeat() def _set_code_close_transport(self, code: WSCloseCode) -> None: """Set the close code and close the transport.""" @@ -441,7 +485,7 @@ async def receive(self, timeout: Optional[float] = None) -> WSMessage: loop = self._loop assert loop is not None while True: - if self._waiting is not None: + if self._waiting: raise RuntimeError("Concurrent call to receive() is not allowed") if self._closed: @@ -453,15 +497,15 @@ async def receive(self, timeout: Optional[float] = None) -> WSMessage: return WS_CLOSING_MESSAGE try: - self._waiting = loop.create_future() + self._waiting = True try: async with async_timeout.timeout(timeout or self._receive_timeout): msg = await self._reader.read() self._reset_heartbeat() finally: - waiter = self._waiting - set_result(waiter, True) - self._waiting = None + self._waiting = False + if self._close_wait: + set_result(self._close_wait, None) except asyncio.TimeoutError: raise except EofStream: @@ -478,7 +522,7 @@ async def receive(self, timeout: Optional[float] = None) -> WSMessage: await self.close() return WSMessage(WSMsgType.ERROR, exc, None) - if msg.type == WSMsgType.CLOSE: + if msg.type is WSMsgType.CLOSE: self._set_closing(msg.data) # Could be closed while awaiting reader. if not self._closed and self._autoclose: @@ -487,19 +531,19 @@ async def receive(self, timeout: Optional[float] = None) -> WSMessage: # want to drain any pending writes as it will # likely result writing to a broken pipe. 
await self.close(drain=False) - elif msg.type == WSMsgType.CLOSING: + elif msg.type is WSMsgType.CLOSING: self._set_closing(WSCloseCode.OK) - elif msg.type == WSMsgType.PING and self._autoping: + elif msg.type is WSMsgType.PING and self._autoping: await self.pong(msg.data) continue - elif msg.type == WSMsgType.PONG and self._autoping: + elif msg.type is WSMsgType.PONG and self._autoping: continue return msg async def receive_str(self, *, timeout: Optional[float] = None) -> str: msg = await self.receive(timeout) - if msg.type != WSMsgType.TEXT: + if msg.type is not WSMsgType.TEXT: raise TypeError( "Received message {}:{!r} is not WSMsgType.TEXT".format( msg.type, msg.data @@ -509,7 +553,7 @@ async def receive_str(self, *, timeout: Optional[float] = None) -> str: async def receive_bytes(self, *, timeout: Optional[float] = None) -> bytes: msg = await self.receive(timeout) - if msg.type != WSMsgType.BINARY: + if msg.type is not WSMsgType.BINARY: raise TypeError(f"Received message {msg.type}:{msg.data!r} is not bytes") return cast(bytes, msg.data) @@ -535,5 +579,6 @@ def _cancel(self, exc: BaseException) -> None: # web_protocol calls this from connection_lost # or when the server is shutting down. self._closing = True + self._cancel_heartbeat() if self._reader is not None: set_exception(self._reader, exc) diff --git a/docs/abc.rst b/docs/abc.rst index d2695673fcf..4eea6715991 100644 --- a/docs/abc.rst +++ b/docs/abc.rst @@ -181,3 +181,57 @@ Abstract Access Logger :param response: :class:`aiohttp.web.Response` object. :param float time: Time taken to serve the request. + + +Abstract Resolver +------------------------------- + +.. class:: AbstractResolver + + An abstract class, base for all resolver implementations. + + Method ``resolve`` should be overridden. + + .. method:: resolve(host, port, family) + + Resolve host name to IP address. + + :param str host: host name to resolve. + + :param int port: port number. + + :param int family: socket family. 
+ + :return: list of :class:`aiohttp.abc.ResolveResult` instances. + + .. method:: close() + + Release resolver. + +.. class:: ResolveResult + + Result of host name resolution. + + .. attribute:: hostname + + The host name that was provided. + + .. attribute:: host + + The IP address that was resolved. + + .. attribute:: port + + The port that was resolved. + + .. attribute:: family + + The address family that was resolved. + + .. attribute:: proto + + The protocol that was resolved. + + .. attribute:: flags + + The flags that were resolved. diff --git a/docs/client_reference.rst b/docs/client_reference.rst index fdf66e1bef0..738892c6cc6 100644 --- a/docs/client_reference.rst +++ b/docs/client_reference.rst @@ -373,7 +373,7 @@ The client session supports the context manager protocol for self closing. read_until_eof=True, \ read_bufsize=None, \ proxy=None, proxy_auth=None,\ - timeout=sentinel, ssl=None, \ + timeout=sentinel, ssl=True, \ verify_ssl=None, fingerprint=None, \ ssl_context=None, proxy_headers=None, \ server_hostname=None, auto_decompress=None) @@ -491,7 +491,7 @@ The client session supports the context manager protocol for self closing. If :class:`float` is passed it is a *total* timeout (in seconds). - :param ssl: SSL validation mode. ``None`` for default SSL check + :param ssl: SSL validation mode. ``True`` for default SSL check (:func:`ssl.create_default_context` is used), ``False`` for skip SSL certificate validation, :class:`aiohttp.Fingerprint` for fingerprint @@ -516,7 +516,7 @@ The client session supports the context manager protocol for self closing. :param bytes fingerprint: Pass the SHA256 digest of the expected certificate in DER format to verify that the certificate the server presents matches. Useful for `certificate pinning - `_. + `_. Warning: use of MD5 or SHA1 digests is insecure and removed. @@ -696,7 +696,7 @@ The client session supports the context manager protocol for self closing. 
origin=None, \ params=None, \ headers=None, \ - proxy=None, proxy_auth=None, ssl=None, \ + proxy=None, proxy_auth=None, ssl=True, \ verify_ssl=None, fingerprint=None, \ ssl_context=None, proxy_headers=None, \ compress=0, max_msg_size=4194304) @@ -760,7 +760,7 @@ The client session supports the context manager protocol for self closing. :param aiohttp.BasicAuth proxy_auth: an object that represents proxy HTTP Basic Authorization (optional) - :param ssl: SSL validation mode. ``None`` for default SSL check + :param ssl: SSL validation mode. ``True`` for default SSL check (:func:`ssl.create_default_context` is used), ``False`` for skip SSL certificate validation, :class:`aiohttp.Fingerprint` for fingerprint @@ -785,7 +785,7 @@ The client session supports the context manager protocol for self closing. :param bytes fingerprint: Pass the SHA256 digest of the expected certificate in DER format to verify that the certificate the server presents matches. Useful for `certificate pinning - `_. + `_. Note: use of MD5 or SHA1 digests is insecure and deprecated. @@ -1066,12 +1066,13 @@ is controlled by *force_close* constructor's parameter). overridden in subclasses. -.. class:: TCPConnector(*, ssl=None, verify_ssl=True, fingerprint=None, \ +.. class:: TCPConnector(*, ssl=True, verify_ssl=True, fingerprint=None, \ use_dns_cache=True, ttl_dns_cache=10, \ family=0, ssl_context=None, local_addr=None, \ resolver=None, keepalive_timeout=sentinel, \ force_close=False, limit=100, limit_per_host=0, \ - enable_cleanup_closed=False, loop=None) + enable_cleanup_closed=False, timeout_ceil_threshold=5, \ + happy_eyeballs_delay=0.25, interleave=None, loop=None) Connector for working with *HTTP* and *HTTPS* via *TCP* sockets. @@ -1083,7 +1084,7 @@ is controlled by *force_close* constructor's parameter). Constructor accepts all parameters suitable for :class:`BaseConnector` plus several TCP-specific ones: - :param ssl: SSL validation mode. 
``None`` for default SSL check + :param ssl: SSL validation mode. ``True`` for default SSL check (:func:`ssl.create_default_context` is used), ``False`` for skip SSL certificate validation, :class:`aiohttp.Fingerprint` for fingerprint @@ -1106,7 +1107,7 @@ is controlled by *force_close* constructor's parameter). :param bytes fingerprint: pass the SHA256 digest of the expected certificate in DER format to verify that the certificate the server presents matches. Useful for `certificate pinning - `_. + `_. Note: use of MD5 or SHA1 digests is insecure and deprecated. @@ -1174,6 +1175,24 @@ is controlled by *force_close* constructor's parameter). If this parameter is set to True, aiohttp additionally aborts underlining transport after 2 seconds. It is off by default. + :param float happy_eyeballs_delay: The amount of time in seconds to wait for a + connection attempt to complete, before starting the next attempt in parallel. + This is the “Connection Attempt Delay” as defined in RFC 8305. To disable + Happy Eyeballs, set this to ``None``. The default value recommended by the + RFC is 0.25 (250 milliseconds). + + .. versionadded:: 3.10 + + :param int interleave: controls address reordering when a host name resolves + to multiple IP addresses. If ``0`` or unspecified, no reordering is done, and + addresses are tried in the order returned by the resolver. If a positive + integer is specified, the addresses are interleaved by address family, and + the given integer is interpreted as “First Address Family Count” as defined + in RFC 8305. The default is ``0`` if happy_eyeballs_delay is not specified, and + ``1`` if it is. + + .. versionadded:: 3.10 + .. attribute:: family *TCP* socket family e.g. :data:`socket.AF_INET` or @@ -2096,6 +2115,41 @@ All exceptions are available as members of *aiohttp* module. Invalid URL, :class:`yarl.URL` instance. + .. attribute:: description + + Invalid URL description, :class:`str` instance or :data:`None`. + +.. 
exception:: InvalidUrlClientError + + Base class for all errors related to client url. + + Derived from :exc:`InvalidURL` + +.. exception:: RedirectClientError + + Base class for all errors related to client redirects. + + Derived from :exc:`ClientError` + +.. exception:: NonHttpUrlClientError + + Base class for all errors related to non http client urls. + + Derived from :exc:`ClientError` + +.. exception:: InvalidUrlRedirectClientError + + Redirect URL is malformed, e.g. it does not contain host part. + + Derived from :exc:`InvalidUrlClientError` and :exc:`RedirectClientError` + +.. exception:: NonHttpUrlRedirectClientError + + Redirect URL does not contain http schema. + + Derived from :exc:`RedirectClientError` and :exc:`NonHttpUrlClientError` + + .. class:: ContentDisposition Represent Content-Disposition header @@ -2254,6 +2308,17 @@ Connection errors Derived from :exc:`ServerConnectionError` and :exc:`asyncio.TimeoutError` +.. class:: ConnectionTimeoutError + + Connection timeout on request: e.g. read timeout. + + Derived from :exc:`ServerTimeoutError` + +.. class:: SocketTimeoutError + + Reading from socket timeout. 
+ + Derived from :exc:`ServerTimeoutError` Hierarchy of exceptions ^^^^^^^^^^^^^^^^^^^^^^^ @@ -2284,6 +2349,10 @@ Hierarchy of exceptions * :exc:`ServerTimeoutError` + * :exc:`ConnectionTimeoutError` + + * :exc:`SocketTimeoutError` + * :exc:`ClientPayloadError` * :exc:`ClientResponseError` @@ -2297,3 +2366,17 @@ Hierarchy of exceptions * :exc:`WSServerHandshakeError` * :exc:`InvalidURL` + + * :exc:`InvalidUrlClientError` + + * :exc:`InvalidUrlRedirectClientError` + + * :exc:`NonHttpUrlClientError` + + * :exc:`NonHttpUrlRedirectClientError` + + * :exc:`RedirectClientError` + + * :exc:`InvalidUrlRedirectClientError` + + * :exc:`NonHttpUrlRedirectClientError` diff --git a/docs/conf.py b/docs/conf.py index f21366fb488..23ac3e426ec 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -59,7 +59,7 @@ "sphinx.ext.viewcode", # Third-party extensions: "sphinxcontrib.blockdiag", - "sphinxcontrib.towncrier", # provides `towncrier-draft-entries` directive + "sphinxcontrib.towncrier.ext", # provides `towncrier-draft-entries` directive ] @@ -394,7 +394,8 @@ ("py:class", "aiohttp.protocol.HttpVersion"), # undocumented ("py:class", "aiohttp.ClientRequest"), # undocumented ("py:class", "aiohttp.payload.Payload"), # undocumented - ("py:class", "aiohttp.abc.AbstractResolver"), # undocumented + ("py:class", "aiohttp.resolver.AsyncResolver"), # undocumented + ("py:class", "aiohttp.resolver.ThreadedResolver"), # undocumented ("py:func", "aiohttp.ws_connect"), # undocumented ("py:meth", "start"), # undocumented ("py:exc", "aiohttp.ClientHttpProxyError"), # undocumented diff --git a/docs/contributing-admins.rst b/docs/contributing-admins.rst index 9444f8ac5c4..acfaebc0e97 100644 --- a/docs/contributing-admins.rst +++ b/docs/contributing-admins.rst @@ -52,6 +52,6 @@ Back on the original release branch, bump the version number and append ``.dev0` If doing a minor release: #. Create a new release branch for future features to go to: e.g. ``git checkout -b 3.10 3.9 && git push`` -#. 
Update ``target-branch`` for Dependabot to reference the new branch name in ``.github/dependabot.yml``. +#. Update both ``target-branch`` backports for Dependabot to reference the new branch name in ``.github/dependabot.yml``. #. Delete the older backport label (e.g. backport-3.8): https://github.com/aio-libs/aiohttp/labels #. Add a new backport label (e.g. backport-3.10). diff --git a/docs/index.rst b/docs/index.rst index 4f55c5ddf09..9692152cb99 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -174,7 +174,7 @@ Communication channels Feel free to post your questions and ideas here. -*gitter chat* https://gitter.im/aio-libs/Lobby +*Matrix*: `#aio-libs:matrix.org `_ We support `Stack Overflow `_. diff --git a/docs/testing.rst b/docs/testing.rst index 027ba63a039..c2937b82282 100644 --- a/docs/testing.rst +++ b/docs/testing.rst @@ -111,11 +111,11 @@ app test client:: body='value: {}'.format(request.app[value]).encode('utf-8')) @pytest.fixture - def cli(loop, aiohttp_client): + async def cli(aiohttp_client): app = web.Application() app.router.add_get('/', previous) app.router.add_post('/', previous) - return loop.run_until_complete(aiohttp_client(app)) + return await aiohttp_client(app) async def test_set_value(cli): resp = await cli.post('/', data={'value': 'foo'}) diff --git a/docs/third_party.rst b/docs/third_party.rst index 5c354f1e6c6..797f9f011ec 100644 --- a/docs/third_party.rst +++ b/docs/third_party.rst @@ -295,3 +295,9 @@ ask to raise the status. - `rsocket `_ Python implementation of `RSocket protocol `_. + +- `nacl_middleware `_ + An aiohttp middleware library for asymmetric encryption of data transmitted via http and/or websocket connections. + +- `aiohttp-asgi-connector `_ + An aiohttp connector for using a ``ClientSession`` to interface directly with separate ASGI applications. 
diff --git a/docs/web_advanced.rst b/docs/web_advanced.rst index d2ba3013e30..dc94bea33bf 100644 --- a/docs/web_advanced.rst +++ b/docs/web_advanced.rst @@ -48,6 +48,8 @@ socket closing on the peer side without reading the full server response. except OSError: # disconnected +.. _web-handler-cancellation: + Web handler cancellation ^^^^^^^^^^^^^^^^^^^^^^^^ @@ -68,38 +70,48 @@ needed to deal with them. .. warning:: - :term:`web-handler` execution could be canceled on every ``await`` - if client drops connection without reading entire response's BODY. + :term:`web-handler` execution could be canceled on every ``await`` or + ``async with`` if client drops connection without reading entire response's BODY. Sometimes it is a desirable behavior: on processing ``GET`` request the code might fetch data from a database or other web resource, the fetching is potentially slow. -Canceling this fetch is a good idea: the peer dropped connection +Canceling this fetch is a good idea: the client dropped the connection already, so there is no reason to waste time and resources (memory etc) -by getting data from a DB without any chance to send it back to peer. +by getting data from a DB without any chance to send it back to the client. -But sometimes the cancellation is bad: on ``POST`` request very often -it is needed to save data to a DB regardless of peer closing. +But sometimes the cancellation is bad: on ``POST`` requests very often +it is needed to save data to a DB regardless of connection closing. Cancellation prevention could be implemented in several ways: -* Applying :func:`asyncio.shield` to a coroutine that saves data. -* Using aiojobs_ or another third party library. +* Applying :func:`aiojobs.aiohttp.shield` to a coroutine that saves data. +* Using aiojobs_ or another third party library to run a task in the background. + +:func:`aiojobs.aiohttp.shield` can work well. 
The only disadvantage is you +need to split the web handler into two async functions: one for the handler +itself and another for protected code. + +.. warning:: -:func:`asyncio.shield` can work well. The only disadvantage is you -need to split web handler into exactly two async functions: one -for handler itself and other for protected code. + We don't recommend using :func:`asyncio.shield` for this because the shielded + task cannot be tracked by the application and therefore there is a risk that + the task will get cancelled during application shutdown. The function provided + by aiojobs_ operates in the same way except the inner task will be tracked + by the Scheduler and will get waited on during the cleanup phase. For example the following snippet is not safe:: + from aiojobs.aiohttp import shield + async def handler(request): - await asyncio.shield(write_to_redis(request)) - await asyncio.shield(write_to_postgres(request)) + await shield(request, write_to_redis(request)) + await shield(request, write_to_postgres(request)) return web.Response(text="OK") -Cancellation might occur while saving data in REDIS, so -``write_to_postgres`` will not be called, potentially +Cancellation might occur while saving data in REDIS, so the +``write_to_postgres`` function will not be called, potentially leaving your data in an inconsistent state. Instead, you would need to write something like:: @@ -109,7 +121,7 @@ Instead, you would need to write something like:: await write_to_postgres(request) async def handler(request): - await asyncio.shield(write_data(request)) + await shield(request, write_data(request)) return web.Response(text="OK") Alternatively, if you want to spawn a task without waiting for @@ -160,7 +172,7 @@ restoring the default disconnection behavior only for specific handlers:: app.router.add_post("/", handler) It prevents all of the ``handler`` async function from cancellation, -so ``write_to_db`` will be never interrupted. 
+so ``write_to_db`` will never be interrupted. .. _aiojobs: http://aiojobs.readthedocs.io/en/latest/ @@ -936,30 +948,24 @@ always satisfactory. When aiohttp is run with :func:`run_app`, it will attempt a graceful shutdown by following these steps (if using a :ref:`runner `, then calling :meth:`AppRunner.cleanup` will perform these steps, excluding -steps 4 and 7). +step 7). 1. Stop each site listening on sockets, so new connections will be rejected. 2. Close idle keep-alive connections (and set active ones to close upon completion). 3. Call the :attr:`Application.on_shutdown` signal. This should be used to shutdown long-lived connections, such as websockets (see below). -4. Wait a short time for running tasks to complete. This allows any pending handlers - or background tasks to complete successfully. The timeout can be adjusted with - ``shutdown_timeout`` in :func:`run_app`. +4. Wait a short time for running handlers to complete. This allows any pending handlers + to complete successfully. The timeout can be adjusted with ``shutdown_timeout`` + in :func:`run_app`. 5. Close any remaining connections and cancel their handlers. It will wait on the canceling handlers for a short time, again adjustable with ``shutdown_timeout``. 6. Call the :attr:`Application.on_cleanup` signal. This should be used to cleanup any resources (such as DB connections). This includes completing the - :ref:`cleanup contexts`. + :ref:`cleanup contexts` which may be used to ensure + background tasks are completed successfully (see + :ref:`handler cancellation` or aiojobs_ for examples). 7. Cancel any remaining tasks and wait on them to complete. -.. note:: - - When creating new tasks in a handler which _should_ be cancelled on server shutdown, - then it is important to keep track of those tasks and explicitly cancel them in a - :attr:`Application.on_shutdown` callback. 
As we can see from the above steps, - without this the server will wait on those new tasks to complete before it continues - with server shutdown. - Websocket shutdown ^^^^^^^^^^^^^^^^^^ diff --git a/docs/web_reference.rst b/docs/web_reference.rst index aedac0e54d1..bb22cfd6369 100644 --- a/docs/web_reference.rst +++ b/docs/web_reference.rst @@ -510,7 +510,6 @@ and :ref:`aiohttp-web-signals` handlers. required work will be processed by :mod:`aiohttp.web` internal machinery. - .. class:: Request A request used for receiving request's information by *web handler*. @@ -925,6 +924,31 @@ and :ref:`aiohttp-web-signals` handlers:: :attr:`~aiohttp.StreamResponse.body`, represented as :class:`str`. +.. class:: FileResponse(*, path, chunk_size=256*1024, status=200, reason=None, headers=None) + + The response class used to send files, inherited from :class:`StreamResponse`. + + Supports the ``Content-Range`` and ``If-Range`` HTTP Headers in requests. + + The actual :attr:`body` sending happens in overridden :meth:`~StreamResponse.prepare`. + + :param path: Path to file. Accepts both :class:`str` and :class:`pathlib.Path`. + :param int chunk_size: Chunk size in bytes which will be passed into + :meth:`io.RawIOBase.read` in the event that the + ``sendfile`` system call is not supported. + + :param int status: HTTP status code, ``200`` by default. + + :param str reason: HTTP reason. If param is ``None`` reason will be + calculated basing on *status* + parameter. Otherwise pass :class:`str` with + arbitrary *status* explanation.. + + :param collections.abc.Mapping headers: HTTP headers that should be added to + response's ones. The ``Content-Type`` response header + will be overridden if provided. + + .. class:: WebSocketResponse(*, timeout=10.0, receive_timeout=None, \ autoclose=True, autoping=True, heartbeat=None, \ protocols=(), compress=True, max_msg_size=4194304) @@ -1846,8 +1870,9 @@ Application and Router system call even if the platform supports it. 
This can be accomplished by setting environment variable ``AIOHTTP_NOSENDFILE=1``. - If a gzip version of the static content exists at file path + ``.gz``, it - will be used for the response. + If a Brotli or gzip compressed version of the static content exists at + the requested path with the ``.br`` or ``.gz`` extension, it will be used + for the response. Brotli will be preferred over gzip if both files exist. .. warning:: @@ -1972,20 +1997,38 @@ unique *name* and at least one :term:`route`. :term:`web-handler` lookup is performed in the following way: -1. Router iterates over *resources* one-by-one. -2. If *resource* matches to requested URL the resource iterates over - own *routes*. -3. If route matches to requested HTTP method (or ``'*'`` wildcard) the - route's handler is used as found :term:`web-handler`. The lookup is - finished. -4. Otherwise router tries next resource from the *routing table*. -5. If the end of *routing table* is reached and no *resource* / - *route* pair found the *router* returns special :class:`~aiohttp.abc.AbstractMatchInfo` +1. The router splits the URL and checks the index from longest to shortest. + For example, '/one/two/three' will first check the index for + '/one/two/three', then '/one/two' and finally '/'. +2. If the URL part is found in the index, the list of routes for + that URL part is iterated over. If a route matches the requested HTTP + method (or ``'*'`` wildcard) the route's handler is used as the chosen + :term:`web-handler`. The lookup is finished. +3. If the route is not found in the index, the router tries to find + the route in the list of :class:`~aiohttp.web.MatchedSubAppResource`, + (currently only created from :meth:`~aiohttp.web.Application.add_domain`), + and will iterate over the list of + :class:`~aiohttp.web.MatchedSubAppResource` in a linear fashion + until a match is found. +4. 
If no *resource* / *route* pair was found, the *router* + returns the special :class:`~aiohttp.abc.AbstractMatchInfo` instance with :attr:`aiohttp.abc.AbstractMatchInfo.http_exception` is not ``None`` but :exc:`HTTPException` with either *HTTP 404 Not Found* or *HTTP 405 Method Not Allowed* status code. Registered :meth:`~aiohttp.abc.AbstractMatchInfo.handler` raises this exception on call. +Fixed paths are preferred over variable paths. For example, +if you have two routes ``/a/b`` and ``/a/{name}``, then the first +route will always be preferred over the second one. + +If there are multiple dynamic paths with the same fixed prefix, +they will be resolved in order of registration. + +For example, if you have two dynamic routes that are prefixed +with the fixed ``/users`` path such as ``/users/{x}/{y}/z`` and +``/users/{x}/y/z``, the first one will be preferred over the +second one. + User should never instantiate resource classes but give it by :meth:`UrlDispatcher.add_resource` call. @@ -2007,7 +2050,10 @@ Resource classes hierarchy:: Resource PlainResource DynamicResource + PrefixResource StaticResource + PrefixedSubAppResource + MatchedSubAppResource .. 
class:: AbstractResource diff --git a/examples/fake_server.py b/examples/fake_server.py index 3157bab658c..2cfe3ed710e 100755 --- a/examples/fake_server.py +++ b/examples/fake_server.py @@ -3,10 +3,11 @@ import pathlib import socket import ssl +from typing import List import aiohttp from aiohttp import web -from aiohttp.abc import AbstractResolver +from aiohttp.abc import AbstractResolver, ResolveResult from aiohttp.resolver import DefaultResolver from aiohttp.test_utils import unused_port @@ -19,7 +20,12 @@ def __init__(self, fakes, *, loop): self._fakes = fakes self._resolver = DefaultResolver(loop=loop) - async def resolve(self, host, port=0, family=socket.AF_INET): + async def resolve( + self, + host: str, + port: int = 0, + family: socket.AddressFamily = socket.AF_INET, + ) -> List[ResolveResult]: fake_port = self._fakes.get(host) if fake_port is not None: return [ diff --git a/requirements/base.in b/requirements/base.in index df67f78afde..70493b6c83a 100644 --- a/requirements/base.in +++ b/requirements/base.in @@ -1,4 +1,3 @@ --r typing-extensions.in -r runtime-deps.in gunicorn diff --git a/requirements/base.txt b/requirements/base.txt index 77943e4e44a..888f9a77899 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -4,27 +4,29 @@ # # pip-compile --allow-unsafe --output-file=requirements/base.txt --strip-extras requirements/base.in # -aiodns==3.1.1 ; sys_platform == "linux" or sys_platform == "darwin" +aiodns==3.2.0 ; sys_platform == "linux" or sys_platform == "darwin" + # via -r requirements/runtime-deps.in +aiohappyeyeballs==2.3.4 # via -r requirements/runtime-deps.in aiosignal==1.3.1 # via -r requirements/runtime-deps.in async-timeout==4.0.3 ; python_version < "3.11" # via -r requirements/runtime-deps.in -attrs==23.1.0 +attrs==23.2.0 # via -r requirements/runtime-deps.in brotli==1.1.0 ; platform_python_implementation == "CPython" # via -r requirements/runtime-deps.in cffi==1.15.1 # via pycares -frozenlist==1.4.0 +frozenlist==1.4.1 # via # -r 
requirements/runtime-deps.in # aiosignal -gunicorn==21.2.0 +gunicorn==22.0.0 # via -r requirements/base.in idna==3.4 # via yarl -multidict==6.0.4 +multidict==6.0.5 # via # -r requirements/runtime-deps.in # yarl @@ -34,9 +36,7 @@ pycares==4.3.0 # via aiodns pycparser==2.21 # via cffi -typing-extensions==4.7.1 - # via -r requirements/typing-extensions.in uvloop==0.19.0 ; platform_system != "Windows" and implementation_name == "cpython" # via -r requirements/base.in -yarl==1.9.3 +yarl==1.9.4 # via -r requirements/runtime-deps.in diff --git a/requirements/constraints.txt b/requirements/constraints.txt index adba72bb204..b40b4440ae0 100644 --- a/requirements/constraints.txt +++ b/requirements/constraints.txt @@ -4,7 +4,11 @@ # # pip-compile --allow-unsafe --output-file=requirements/constraints.txt --resolver=backtracking --strip-extras requirements/constraints.in # -aiodns==3.1.1 ; sys_platform == "linux" or sys_platform == "darwin" +aiodns==3.2.0 ; sys_platform == "linux" or sys_platform == "darwin" + # via + # -r requirements/lint.in + # -r requirements/runtime-deps.in +aiohappyeyeballs==2.3.4 # via -r requirements/runtime-deps.in aiohttp-theme==0.1.6 # via -r requirements/doc.in @@ -20,7 +24,7 @@ async-timeout==4.0.3 ; python_version < "3.11" # via # -r requirements/runtime-deps.in # aioredis -attrs==23.1.0 +attrs==23.2.0 # via -r requirements/runtime-deps.in babel==2.9.1 # via sphinx @@ -30,7 +34,7 @@ blockdiag==2.0.1 # via sphinxcontrib-blockdiag brotli==1.1.0 ; platform_python_implementation == "CPython" # via -r requirements/runtime-deps.in -build==0.9.0 +build==1.0.3 # via pip-tools certifi==2023.7.22 # via requests @@ -52,7 +56,7 @@ click==8.0.3 # towncrier # typer # wait-for-it -coverage==7.3.2 +coverage==7.6.0 # via # -r requirements/test.in # pytest-cov @@ -60,7 +64,7 @@ cryptography==41.0.2 # via # pyjwt # trustme -cython==3.0.5 +cython==3.0.10 # via -r requirements/cython.in distlib==0.3.3 # via virtualenv @@ -70,9 +74,9 @@ exceptiongroup==1.1.2 # via 
pytest filelock==3.3.2 # via virtualenv -freezegun==1.3.0 +freezegun==1.5.1 # via -r requirements/test.in -frozenlist==1.4.0 +frozenlist==1.4.1 # via # -r requirements/runtime-deps.in # aiosignal @@ -80,7 +84,7 @@ funcparserlib==1.0.1 # via blockdiag gidgethub==5.0.1 # via cherry-picker -gunicorn==21.2.0 +gunicorn==22.0.0 # via -r requirements/base.in identify==2.3.5 # via pre-commit @@ -91,6 +95,12 @@ idna==3.3 # yarl imagesize==1.3.0 # via sphinx +importlib-metadata==7.0.0 + # via + # build + # sphinx +importlib-resources==6.1.1 + # via towncrier incremental==22.10.0 # via towncrier iniconfig==1.1.1 @@ -101,12 +111,12 @@ jinja2==3.0.3 # towncrier markupsafe==2.0.1 # via jinja2 -multidict==6.0.4 +multidict==6.0.5 # via # -r requirements/multidict.in # -r requirements/runtime-deps.in # yarl -mypy==1.7.1 ; implementation_name == "cpython" +mypy==1.11.1 ; implementation_name == "cpython" # via # -r requirements/lint.in # -r requirements/test.in @@ -120,21 +130,19 @@ packaging==21.2 # gunicorn # pytest # sphinx -pep517==0.12.0 - # via build pillow==9.5.0 # via # -c requirements/broken-projects.in # blockdiag -pip-tools==7.3.0 +pip-tools==7.4.1 # via -r requirements/dev.in platformdirs==2.4.0 # via virtualenv -pluggy==1.0.0 +pluggy==1.5.0 # via pytest pre-commit==3.5.0 # via -r requirements/lint.in -proxy-py==2.4.4rc4 +proxy-py==2.4.4 # via -r requirements/test.in pycares==4.3.0 # via aiodns @@ -154,20 +162,28 @@ pyjwt==2.3.0 # pyjwt pyparsing==2.4.7 # via packaging -pytest==7.4.3 +pyproject-hooks==1.0.0 + # via + # build + # pip-tools +pytest==8.3.2 # via # -r requirements/lint.in # -r requirements/test.in # pytest-cov # pytest-mock -pytest-cov==4.1.0 +pytest-cov==5.0.0 # via -r requirements/test.in -pytest-mock==3.12.0 +pytest-mock==3.14.0 # via -r requirements/test.in python-dateutil==2.8.2 # via freezegun -python-on-whales==0.67.0 - # via -r requirements/test.in +python-on-whales==0.72.0 + # via + # -r requirements/lint.in + # -r requirements/test.in 
+pytz==2023.3.post1 + # via babel pyyaml==6.0.1 # via pre-commit re-assert==1.1.0 @@ -185,7 +201,7 @@ six==1.16.0 # via # python-dateutil # virtualenv -slotscheck==0.17.1 +slotscheck==0.19.0 # via -r requirements/lint.in snowballstemmer==2.1.0 # via sphinx @@ -219,8 +235,8 @@ tomli==2.0.1 # cherry-picker # coverage # mypy - # pep517 # pip-tools + # pyproject-hooks # pytest # slotscheck # towncrier @@ -234,9 +250,8 @@ trustme==1.1.0 ; platform_machine != "i686" # via -r requirements/test.in typer==0.6.1 # via python-on-whales -typing-extensions==4.7.1 +typing-extensions==4.11.0 # via - # -r requirements/typing-extensions.in # aioredis # annotated-types # mypy @@ -259,8 +274,12 @@ webcolors==1.11.1 # via blockdiag wheel==0.37.0 # via pip-tools -yarl==1.9.3 +yarl==1.9.4 # via -r requirements/runtime-deps.in +zipp==3.17.0 + # via + # importlib-metadata + # importlib-resources # The following packages are considered to be unsafe in a requirements file: pip==23.2.1 diff --git a/requirements/cython.in b/requirements/cython.in index ee07533e17c..6f0238f170d 100644 --- a/requirements/cython.in +++ b/requirements/cython.in @@ -1,4 +1,3 @@ -r multidict.in --r typing-extensions.in # required for parsing aiohttp/hdrs.py by tools/gen.py Cython diff --git a/requirements/cython.txt b/requirements/cython.txt index 5851f1d8b48..72b9a67af98 100644 --- a/requirements/cython.txt +++ b/requirements/cython.txt @@ -4,9 +4,9 @@ # # pip-compile --allow-unsafe --output-file=requirements/cython.txt --resolver=backtracking --strip-extras requirements/cython.in # -cython==3.0.5 +cython==3.0.10 # via -r requirements/cython.in -multidict==6.0.4 +multidict==6.0.5 # via -r requirements/multidict.in -typing-extensions==4.7.1 +typing-extensions==4.11.0 # via -r requirements/typing-extensions.in diff --git a/requirements/dev.txt b/requirements/dev.txt index 3d5926c12bd..3ad4f54b209 100644 --- a/requirements/dev.txt +++ b/requirements/dev.txt @@ -4,7 +4,11 @@ # # pip-compile --allow-unsafe 
--output-file=requirements/dev.txt --resolver=backtracking --strip-extras requirements/dev.in # -aiodns==3.1.1 ; sys_platform == "linux" or sys_platform == "darwin" +aiodns==3.2.0 ; sys_platform == "linux" or sys_platform == "darwin" + # via + # -r requirements/lint.in + # -r requirements/runtime-deps.in +aiohappyeyeballs==2.3.4 # via -r requirements/runtime-deps.in aiohttp-theme==0.1.6 # via -r requirements/doc.in @@ -20,7 +24,7 @@ async-timeout==4.0.3 ; python_version < "3.11" # via # -r requirements/runtime-deps.in # aioredis -attrs==23.1.0 +attrs==23.2.0 # via -r requirements/runtime-deps.in babel==2.12.1 # via sphinx @@ -28,7 +32,7 @@ blockdiag==3.0.0 # via sphinxcontrib-blockdiag brotli==1.1.0 ; platform_python_implementation == "CPython" # via -r requirements/runtime-deps.in -build==0.10.0 +build==1.0.3 # via pip-tools certifi==2023.7.22 # via requests @@ -50,7 +54,7 @@ click==8.1.6 # towncrier # typer # wait-for-it -coverage==7.3.2 +coverage==7.6.0 # via # -r requirements/test.in # pytest-cov @@ -66,9 +70,9 @@ exceptiongroup==1.1.2 # via pytest filelock==3.12.2 # via virtualenv -freezegun==1.3.0 +freezegun==1.5.1 # via -r requirements/test.in -frozenlist==1.4.0 +frozenlist==1.4.1 # via # -r requirements/runtime-deps.in # aiosignal @@ -76,7 +80,7 @@ funcparserlib==1.0.1 # via blockdiag gidgethub==5.3.0 # via cherry-picker -gunicorn==21.2.0 +gunicorn==22.0.0 # via -r requirements/base.in identify==2.5.26 # via pre-commit @@ -87,6 +91,12 @@ idna==3.4 # yarl imagesize==1.4.1 # via sphinx +importlib-metadata==7.0.0 + # via + # build + # sphinx +importlib-resources==6.1.1 + # via towncrier incremental==22.10.0 # via towncrier iniconfig==2.0.0 @@ -97,11 +107,11 @@ jinja2==3.1.2 # towncrier markupsafe==2.1.3 # via jinja2 -multidict==6.0.4 +multidict==6.0.5 # via # -r requirements/runtime-deps.in # yarl -mypy==1.7.1 ; implementation_name == "cpython" +mypy==1.11.1 ; implementation_name == "cpython" # via # -r requirements/lint.in # -r requirements/test.in @@ -119,15 
+129,15 @@ pillow==9.5.0 # via # -c requirements/broken-projects.in # blockdiag -pip-tools==7.3.0 +pip-tools==7.4.1 # via -r requirements/dev.in platformdirs==3.10.0 # via virtualenv -pluggy==1.2.0 +pluggy==1.5.0 # via pytest pre-commit==3.5.0 # via -r requirements/lint.in -proxy-py==2.4.4rc4 +proxy-py==2.4.4 # via -r requirements/test.in pycares==4.3.0 # via aiodns @@ -144,21 +154,27 @@ pyjwt==2.8.0 # gidgethub # pyjwt pyproject-hooks==1.0.0 - # via build -pytest==7.4.3 + # via + # build + # pip-tools +pytest==8.3.2 # via # -r requirements/lint.in # -r requirements/test.in # pytest-cov # pytest-mock -pytest-cov==4.1.0 +pytest-cov==5.0.0 # via -r requirements/test.in -pytest-mock==3.12.0 +pytest-mock==3.14.0 # via -r requirements/test.in python-dateutil==2.8.2 # via freezegun -python-on-whales==0.67.0 - # via -r requirements/test.in +python-on-whales==0.72.0 + # via + # -r requirements/lint.in + # -r requirements/test.in +pytz==2023.3.post1 + # via babel pyyaml==6.0.1 # via pre-commit re-assert==1.1.0 @@ -174,7 +190,7 @@ setuptools-git==1.2 # via -r requirements/test.in six==1.16.0 # via python-dateutil -slotscheck==0.17.1 +slotscheck==0.19.0 # via -r requirements/lint.in snowballstemmer==2.2.0 # via sphinx @@ -220,9 +236,8 @@ trustme==1.1.0 ; platform_machine != "i686" # via -r requirements/test.in typer==0.9.0 # via python-on-whales -typing-extensions==4.7.1 +typing-extensions==4.11.0 # via - # -r requirements/typing-extensions.in # aioredis # annotated-types # mypy @@ -246,8 +261,12 @@ webcolors==1.13 # via blockdiag wheel==0.41.0 # via pip-tools -yarl==1.9.3 +yarl==1.9.4 # via -r requirements/runtime-deps.in +zipp==3.17.0 + # via + # importlib-metadata + # importlib-resources # The following packages are considered to be unsafe in a requirements file: pip==23.2.1 diff --git a/requirements/lint.in b/requirements/lint.in index 34616155912..0d46809a083 100644 --- a/requirements/lint.in +++ b/requirements/lint.in @@ -1,8 +1,11 @@ --r typing-extensions.in - +aiodns 
aioredis +freezegun mypy; implementation_name == "cpython" pre-commit pytest +pytest-mock +python-on-whales slotscheck +trustme uvloop; platform_system != "Windows" diff --git a/requirements/lint.txt b/requirements/lint.txt index 28d0bf65778..97809fe3dde 100644 --- a/requirements/lint.txt +++ b/requirements/lint.txt @@ -4,25 +4,45 @@ # # pip-compile --allow-unsafe --output-file=requirements/lint.txt --resolver=backtracking --strip-extras requirements/lint.in # +aiodns==3.2.0 + # via -r requirements/lint.in aioredis==2.0.1 # via -r requirements/lint.in +annotated-types==0.6.0 + # via pydantic async-timeout==4.0.3 # via aioredis +certifi==2024.2.2 + # via requests +cffi==1.16.0 + # via pycares cfgv==3.3.1 # via pre-commit +charset-normalizer==3.3.2 + # via requests click==8.1.6 - # via slotscheck + # via + # slotscheck + # typer distlib==0.3.7 # via virtualenv exceptiongroup==1.1.2 # via pytest filelock==3.12.2 # via virtualenv +freezegun==1.5.1 + # via -r requirements/lint.in identify==2.5.26 # via pre-commit +idna==3.7 + # via requests iniconfig==2.0.0 # via pytest -mypy==1.7.1 ; implementation_name == "cpython" +markdown-it-py==3.0.0 + # via rich +mdurl==0.1.2 + # via markdown-it-py +mypy==1.11.1 ; implementation_name == "cpython" # via -r requirements/lint.in mypy-extensions==1.0.0 # via mypy @@ -32,26 +52,59 @@ packaging==23.1 # via pytest platformdirs==3.10.0 # via virtualenv -pluggy==1.2.0 +pluggy==1.5.0 # via pytest pre-commit==3.5.0 # via -r requirements/lint.in -pytest==7.4.3 +pycares==4.4.0 + # via aiodns +pycparser==2.22 + # via cffi +pydantic==2.7.1 + # via python-on-whales +pydantic-core==2.18.2 + # via pydantic +pygments==2.17.2 + # via rich +pytest==8.3.2 + # via -r requirements/lint.in +pytest-mock==3.14.0 + # via -r requirements/lint.in +python-on-whales==0.72.0 # via -r requirements/lint.in pyyaml==6.0.1 # via pre-commit -slotscheck==0.17.1 +requests==2.31.0 + # via python-on-whales +rich==13.7.1 + # via typer +shellingham==1.5.4 + # via typer 
+slotscheck==0.19.0 # via -r requirements/lint.in tomli==2.0.1 # via # mypy # pytest # slotscheck -typing-extensions==4.7.1 +tqdm==4.66.2 + # via python-on-whales +trustme==1.1.0 + # via -r requirements/lint.in +typer==0.12.3 + # via python-on-whales +typing-extensions==4.11.0 # via - # -r requirements/typing-extensions.in # aioredis + # annotated-types # mypy + # pydantic + # pydantic-core + # python-on-whales + # rich + # typer +urllib3==2.2.1 + # via requests uvloop==0.19.0 ; platform_system != "Windows" # via -r requirements/lint.in virtualenv==20.24.2 diff --git a/requirements/multidict.txt b/requirements/multidict.txt index 9c4f984cd75..915f9c24dcc 100644 --- a/requirements/multidict.txt +++ b/requirements/multidict.txt @@ -4,5 +4,5 @@ # # pip-compile --allow-unsafe --output-file=requirements/multidict.txt --resolver=backtracking --strip-extras requirements/multidict.in # -multidict==6.0.4 +multidict==6.0.5 # via -r requirements/multidict.in diff --git a/requirements/runtime-deps.in b/requirements/runtime-deps.in index b2df16f1680..2299584a463 100644 --- a/requirements/runtime-deps.in +++ b/requirements/runtime-deps.in @@ -1,6 +1,7 @@ # Extracted from `setup.cfg` via `make sync-direct-runtime-deps` -aiodns; sys_platform=="linux" or sys_platform=="darwin" +aiodns >= 3.2.0; sys_platform=="linux" or sys_platform=="darwin" +aiohappyeyeballs >= 2.3.0 aiosignal >= 1.1.2 async-timeout >= 4.0, < 5.0 ; python_version < "3.11" attrs >= 17.3.0 diff --git a/requirements/runtime-deps.txt b/requirements/runtime-deps.txt index a0f2aa861f7..5f98dceaf9c 100644 --- a/requirements/runtime-deps.txt +++ b/requirements/runtime-deps.txt @@ -4,25 +4,27 @@ # # pip-compile --allow-unsafe --output-file=requirements/runtime-deps.txt --strip-extras requirements/runtime-deps.in # -aiodns==3.1.1 ; sys_platform == "linux" or sys_platform == "darwin" +aiodns==3.2.0 ; sys_platform == "linux" or sys_platform == "darwin" + # via -r requirements/runtime-deps.in +aiohappyeyeballs==2.3.4 # via -r 
requirements/runtime-deps.in aiosignal==1.3.1 # via -r requirements/runtime-deps.in async-timeout==4.0.3 ; python_version < "3.11" # via -r requirements/runtime-deps.in -attrs==23.1.0 +attrs==23.2.0 # via -r requirements/runtime-deps.in brotli==1.1.0 ; platform_python_implementation == "CPython" # via -r requirements/runtime-deps.in cffi==1.15.1 # via pycares -frozenlist==1.4.0 +frozenlist==1.4.1 # via # -r requirements/runtime-deps.in # aiosignal idna==3.4 # via yarl -multidict==6.0.4 +multidict==6.0.5 # via # -r requirements/runtime-deps.in # yarl @@ -30,5 +32,5 @@ pycares==4.3.0 # via aiodns pycparser==2.21 # via cffi -yarl==1.9.3 +yarl==1.9.4 # via -r requirements/runtime-deps.in diff --git a/requirements/test.txt b/requirements/test.txt index 57c00fc2439..803705f6da0 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -4,7 +4,9 @@ # # pip-compile --allow-unsafe --output-file=requirements/test.txt --resolver=backtracking --strip-extras requirements/test.in # -aiodns==3.1.1 ; sys_platform == "linux" or sys_platform == "darwin" +aiodns==3.2.0 ; sys_platform == "linux" or sys_platform == "darwin" + # via -r requirements/runtime-deps.in +aiohappyeyeballs==2.3.4 # via -r requirements/runtime-deps.in aiosignal==1.3.1 # via -r requirements/runtime-deps.in @@ -12,7 +14,7 @@ annotated-types==0.5.0 # via pydantic async-timeout==4.0.3 ; python_version < "3.11" # via -r requirements/runtime-deps.in -attrs==23.1.0 +attrs==23.2.0 # via -r requirements/runtime-deps.in brotli==1.1.0 ; platform_python_implementation == "CPython" # via -r requirements/runtime-deps.in @@ -28,7 +30,7 @@ click==8.1.6 # via # typer # wait-for-it -coverage==7.3.2 +coverage==7.6.0 # via # -r requirements/test.in # pytest-cov @@ -36,13 +38,13 @@ cryptography==41.0.2 # via trustme exceptiongroup==1.1.2 # via pytest -freezegun==1.3.0 +freezegun==1.5.1 # via -r requirements/test.in -frozenlist==1.4.0 +frozenlist==1.4.1 # via # -r requirements/runtime-deps.in # aiosignal -gunicorn==21.2.0 
+gunicorn==22.0.0 # via -r requirements/base.in idna==3.4 # via @@ -51,11 +53,11 @@ idna==3.4 # yarl iniconfig==2.0.0 # via pytest -multidict==6.0.4 +multidict==6.0.5 # via # -r requirements/runtime-deps.in # yarl -mypy==1.7.1 ; implementation_name == "cpython" +mypy==1.11.1 ; implementation_name == "cpython" # via -r requirements/test.in mypy-extensions==1.0.0 # via mypy @@ -63,9 +65,9 @@ packaging==23.1 # via # gunicorn # pytest -pluggy==1.2.0 +pluggy==1.5.0 # via pytest -proxy-py==2.4.4rc4 +proxy-py==2.4.4 # via -r requirements/test.in pycares==4.3.0 # via aiodns @@ -75,18 +77,18 @@ pydantic==2.2.0 # via python-on-whales pydantic-core==2.6.0 # via pydantic -pytest==7.4.3 +pytest==8.3.2 # via # -r requirements/test.in # pytest-cov # pytest-mock -pytest-cov==4.1.0 +pytest-cov==5.0.0 # via -r requirements/test.in -pytest-mock==3.12.0 +pytest-mock==3.14.0 # via -r requirements/test.in python-dateutil==2.8.2 # via freezegun -python-on-whales==0.67.0 +python-on-whales==0.72.0 # via -r requirements/test.in re-assert==1.1.0 # via -r requirements/test.in @@ -109,9 +111,8 @@ trustme==1.1.0 ; platform_machine != "i686" # via -r requirements/test.in typer==0.9.0 # via python-on-whales -typing-extensions==4.7.1 +typing-extensions==4.11.0 # via - # -r requirements/typing-extensions.in # annotated-types # mypy # pydantic @@ -124,5 +125,5 @@ uvloop==0.19.0 ; platform_system != "Windows" and implementation_name == "cpytho # via -r requirements/base.in wait-for-it==2.2.2 # via -r requirements/test.in -yarl==1.9.3 +yarl==1.9.4 # via -r requirements/runtime-deps.in diff --git a/requirements/typing-extensions.in b/requirements/typing-extensions.in deleted file mode 100644 index 5fd4f05f341..00000000000 --- a/requirements/typing-extensions.in +++ /dev/null @@ -1 +0,0 @@ -typing_extensions diff --git a/requirements/typing-extensions.txt b/requirements/typing-extensions.txt deleted file mode 100644 index c45af7262f7..00000000000 --- a/requirements/typing-extensions.txt +++ /dev/null @@ 
-1,8 +0,0 @@ -# -# This file is autogenerated by pip-compile with python 3.8 -# by the following command: -# -# pip-compile --allow-unsafe --output-file=requirements/typing-extensions.txt --resolver=backtracking --strip-extras requirements/typing-extensions.in -# -typing-extensions==4.7.1 - # via -r requirements/typing-extensions.in diff --git a/setup.cfg b/setup.cfg index 15d22a2f5f7..cfd1be5610f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -49,9 +49,10 @@ zip_safe = False include_package_data = True install_requires = + aiohappyeyeballs >= 2.3.0 aiosignal >= 1.1.2 - attrs >= 17.3.0 async-timeout >= 4.0, < 5.0 ; python_version < "3.11" + attrs >= 17.3.0 frozenlist >= 1.1.1 multidict >=4.5, < 7.0 yarl >= 1.0, < 2.0 @@ -64,7 +65,7 @@ install_requires = [options.extras_require] speedups = # required c-ares (aiodns' backend) will not build on windows - aiodns; sys_platform=="linux" or sys_platform=="darwin" + aiodns >= 3.2.0; sys_platform=="linux" or sys_platform=="darwin" Brotli; platform_python_implementation == 'CPython' brotlicffi; platform_python_implementation != 'CPython' diff --git a/tests/autobahn/test_autobahn.py b/tests/autobahn/test_autobahn.py index f30f6afd693..651183d5f92 100644 --- a/tests/autobahn/test_autobahn.py +++ b/tests/autobahn/test_autobahn.py @@ -73,7 +73,8 @@ def test_client(report_dir: Path, request: Any) -> None: print("Stopping client and server") client.terminate() client.wait() - autobahn_container.stop() + # https://github.com/gabrieldemarmiesse/python-on-whales/pull/580 + autobahn_container.stop() # type: ignore[union-attr] failed_messages = get_failed_tests(f"{report_dir}/clients", "aiohttp") diff --git a/tests/conftest.py b/tests/conftest.py index 44e5fb7285c..1cb64b3a6f8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,15 +1,19 @@ import asyncio +import base64 import os import socket import ssl import sys -from hashlib import md5, sha256 +from hashlib import md5, sha1, sha256 from pathlib import Path from tempfile import 
TemporaryDirectory +from typing import Any +from unittest import mock from uuid import uuid4 import pytest +from aiohttp.http import WS_KEY from aiohttp.test_utils import loop_context try: @@ -167,6 +171,17 @@ def pipe_name(): return name +@pytest.fixture +def create_mocked_conn(loop: Any): + def _proto_factory(conn_closing_result=None, **kwargs): + proto = mock.Mock(**kwargs) + proto.closed = loop.create_future() + proto.closed.set_result(conn_closing_result) + return proto + + yield _proto_factory + + @pytest.fixture def selector_loop(): policy = asyncio.WindowsSelectorEventLoopPolicy() @@ -197,3 +212,28 @@ def netrc_contents( monkeypatch.setenv("NETRC", str(netrc_file_path)) return netrc_file_path + + +@pytest.fixture +def start_connection(): + with mock.patch( + "aiohttp.connector.aiohappyeyeballs.start_connection", + autospec=True, + spec_set=True, + ) as start_connection_mock: + yield start_connection_mock + + +@pytest.fixture +def key_data(): + return os.urandom(16) + + +@pytest.fixture +def key(key_data: Any): + return base64.b64encode(key_data) + + +@pytest.fixture +def ws_key(key: Any): + return base64.b64encode(sha1(key + WS_KEY).digest()).decode() diff --git a/tests/test_circular_imports.py b/tests/test_circular_imports.py index 516326444c5..d513e9bde8b 100644 --- a/tests/test_circular_imports.py +++ b/tests/test_circular_imports.py @@ -8,6 +8,7 @@ * https://github.com/pytest-dev/pytest/blob/d18c75b/testing/test_meta.py * https://twitter.com/codewithanthony/status/1229445110510735361 """ + import os import pkgutil import socket @@ -30,14 +31,16 @@ def _mark_aiohttp_worker_for_skipping( importables: List[str], ) -> List[Union[str, "ParameterSet"]]: return [ - pytest.param( - importable, - marks=pytest.mark.skipif( - not hasattr(socket, "AF_UNIX"), reason="It's a UNIX-only module" - ), + ( + pytest.param( + importable, + marks=pytest.mark.skipif( + not hasattr(socket, "AF_UNIX"), reason="It's a UNIX-only module" + ), + ) + if importable == 
"aiohttp.worker" + else importable ) - if importable == "aiohttp.worker" - else importable for importable in importables ] diff --git a/tests/test_client_exceptions.py b/tests/test_client_exceptions.py index f70ba5d09a6..d863d6674a3 100644 --- a/tests/test_client_exceptions.py +++ b/tests/test_client_exceptions.py @@ -5,6 +5,7 @@ from unittest import mock import pytest +from yarl import URL from aiohttp import client, client_reqrep @@ -298,8 +299,9 @@ def test_repr(self) -> None: class TestInvalidURL: def test_ctor(self) -> None: - err = client.InvalidURL(url=":wrong:url:") + err = client.InvalidURL(url=":wrong:url:", description=":description:") assert err.url == ":wrong:url:" + assert err.description == ":description:" def test_pickle(self) -> None: err = client.InvalidURL(url=":wrong:url:") @@ -310,10 +312,27 @@ def test_pickle(self) -> None: assert err2.url == ":wrong:url:" assert err2.foo == "bar" - def test_repr(self) -> None: + def test_repr_no_description(self) -> None: err = client.InvalidURL(url=":wrong:url:") + assert err.args == (":wrong:url:",) assert repr(err) == "" - def test_str(self) -> None: + def test_repr_yarl_URL(self) -> None: + err = client.InvalidURL(url=URL(":wrong:url:")) + assert repr(err) == "" + + def test_repr_with_description(self) -> None: + err = client.InvalidURL(url=":wrong:url:", description=":description:") + assert repr(err) == "" + + def test_str_no_description(self) -> None: err = client.InvalidURL(url=":wrong:url:") assert str(err) == ":wrong:url:" + + def test_none_description(self) -> None: + err = client.InvalidURL(":wrong:url:") + assert err.description is None + + def test_str_with_description(self) -> None: + err = client.InvalidURL(url=":wrong:url:", description=":description:") + assert str(err) == ":wrong:url: - :description:" diff --git a/tests/test_client_functional.py b/tests/test_client_functional.py index dbb2dff5ac4..872876d4a32 100644 --- a/tests/test_client_functional.py +++ b/tests/test_client_functional.py 
@@ -8,7 +8,9 @@ import pathlib import socket import ssl -from typing import Any, AsyncIterator +import sys +import time +from typing import Any, AsyncIterator, Type from unittest import mock import pytest @@ -18,7 +20,15 @@ import aiohttp from aiohttp import Fingerprint, ServerFingerprintMismatch, hdrs, web from aiohttp.abc import AbstractResolver -from aiohttp.client_exceptions import TooManyRedirects +from aiohttp.client_exceptions import ( + InvalidURL, + InvalidUrlClientError, + InvalidUrlRedirectClientError, + NonHttpUrlClientError, + NonHttpUrlRedirectClientError, + SocketTimeoutError, + TooManyRedirects, +) from aiohttp.pytest_plugin import AiohttpClient, TestClient from aiohttp.test_utils import unused_port @@ -214,6 +224,67 @@ async def handler(request): assert 0 == len(client._session.connector._conns) +async def test_keepalive_timeout_async_sleep() -> None: + async def handler(request): + body = await request.read() + assert b"" == body + return web.Response(body=b"OK") + + app = web.Application() + app.router.add_route("GET", "/", handler) + + runner = web.AppRunner(app, tcp_keepalive=True, keepalive_timeout=0.001) + await runner.setup() + + port = unused_port() + site = web.TCPSite(runner, host="localhost", port=port) + await site.start() + + try: + async with aiohttp.client.ClientSession() as sess: + resp1 = await sess.get(f"http://localhost:{port}/") + await resp1.read() + # wait for server keepalive_timeout + await asyncio.sleep(0.01) + resp2 = await sess.get(f"http://localhost:{port}/") + await resp2.read() + finally: + await asyncio.gather(runner.shutdown(), site.stop()) + + +@pytest.mark.skipif( + sys.version_info[:2] == (3, 11), + reason="https://github.com/pytest-dev/pytest/issues/10763", +) +async def test_keepalive_timeout_sync_sleep() -> None: + async def handler(request): + body = await request.read() + assert b"" == body + return web.Response(body=b"OK") + + app = web.Application() + app.router.add_route("GET", "/", handler) + + runner = 
web.AppRunner(app, tcp_keepalive=True, keepalive_timeout=0.001) + await runner.setup() + + port = unused_port() + site = web.TCPSite(runner, host="localhost", port=port) + await site.start() + + try: + async with aiohttp.client.ClientSession() as sess: + resp1 = await sess.get(f"http://localhost:{port}/") + await resp1.read() + # wait for server keepalive_timeout + # time.sleep is a more challenging scenario than asyncio.sleep + time.sleep(0.01) + resp2 = await sess.get(f"http://localhost:{port}/") + await resp2.read() + finally: + await asyncio.gather(runner.shutdown(), site.stop()) + + async def test_release_early(aiohttp_client) -> None: async def handler(request): await request.read() @@ -509,8 +580,6 @@ async def handler(request): async def test_format_task_get(aiohttp_server) -> None: - loop = asyncio.get_event_loop() - async def handler(request): return web.Response(body=b"OK") @@ -518,7 +587,7 @@ async def handler(request): app.router.add_route("GET", "/", handler) server = await aiohttp_server(app) client = aiohttp.ClientSession() - task = loop.create_task(client.get(server.make_url("/"))) + task = asyncio.create_task(client.get(server.make_url("/"))) assert f"{task}".startswith(" None: + async with aiohttp.ClientSession() as http_session: + with pytest.raises( + expected_exception_class, match=rf"^{error_message_url}( - [A-Za-z ]+)?" 
+ ): + await http_session.get(url) + + +@pytest.mark.parametrize( + ("invalid_redirect_url", "error_message_url", "expected_exception_class"), + ( + *( + (url, message, InvalidUrlRedirectClientError) + for (url, message) in INVALID_URL_WITH_ERROR_MESSAGE_YARL_ORIGIN + + INVALID_URL_WITH_ERROR_MESSAGE_YARL_NEW + ), + *( + (url, message, NonHttpUrlRedirectClientError) + for (url, message) in NON_HTTP_URL_WITH_ERROR_MESSAGE + ), + ), +) +async def test_invalid_redirect_url( + aiohttp_client: Any, + invalid_redirect_url: Any, + error_message_url: str, + expected_exception_class: Any, +) -> None: + headers = {hdrs.LOCATION: invalid_redirect_url} + + async def generate_redirecting_response(request): + return web.Response(status=301, headers=headers) + + app = web.Application() + app.router.add_get("/redirect", generate_redirecting_response) + client = await aiohttp_client(app) + + with pytest.raises( + expected_exception_class, match=rf"^{error_message_url}( - [A-Za-z ]+)?" + ): + await client.get("/redirect") + + +@pytest.mark.parametrize( + ("invalid_redirect_url", "error_message_url", "expected_exception_class"), + ( + *( + (url, message, InvalidUrlRedirectClientError) + for (url, message) in INVALID_URL_WITH_ERROR_MESSAGE_YARL_ORIGIN + + INVALID_URL_WITH_ERROR_MESSAGE_YARL_NEW + ), + *( + (url, message, NonHttpUrlRedirectClientError) + for (url, message) in NON_HTTP_URL_WITH_ERROR_MESSAGE + ), + ), +) +async def test_invalid_redirect_url_multiple_redirects( + aiohttp_client: Any, + invalid_redirect_url: Any, + error_message_url: str, + expected_exception_class: Any, +) -> None: + app = web.Application() + + for path, location in [ + ("/redirect", "/redirect1"), + ("/redirect1", "/redirect2"), + ("/redirect2", invalid_redirect_url), + ]: + + async def generate_redirecting_response(request): + return web.Response(status=301, headers={hdrs.LOCATION: location}) + + app.router.add_get(path, generate_redirecting_response) + + client = await aiohttp_client(app) + + with 
pytest.raises( + expected_exception_class, match=rf"^{error_message_url}( - [A-Za-z ]+)?" + ): + await client.get("/redirect") + + @pytest.mark.parametrize( ("status", "expected_ok"), ( @@ -3001,21 +3196,20 @@ def connection_lost(self, exc): addr = server.sockets[0].getsockname() - connector = aiohttp.TCPConnector(limit=1) - session = aiohttp.ClientSession(connector=connector) - - url = "http://{}:{}/".format(*addr) + async with aiohttp.TCPConnector(limit=1) as connector: + async with aiohttp.ClientSession(connector=connector) as session: + url = "http://{}:{}/".format(*addr) - r = await session.request("GET", url) - await r.read() - assert 1 == len(connector._conns) + r = await session.request("GET", url) + await r.read() + assert 1 == len(connector._conns) + closed_conn = next(iter(connector._conns.values())) - with pytest.raises(aiohttp.ClientConnectionError): - await session.request("GET", url) - assert 0 == len(connector._conns) + await session.request("GET", url) + assert 1 == len(connector._conns) + new_conn = next(iter(connector._conns.values())) + assert closed_conn is not new_conn - await session.close() - await connector.close() server.close() await server.wait_closed() @@ -3164,6 +3358,21 @@ async def handler(request): await client.get("/") +async def test_socket_timeout(aiohttp_client: Any) -> None: + async def handler(request): + await asyncio.sleep(5) + return web.Response() + + app = web.Application() + app.add_routes([web.get("/", handler)]) + + timeout = aiohttp.ClientTimeout(sock_read=0.1) + client = await aiohttp_client(app, timeout=timeout) + + with pytest.raises(SocketTimeoutError): + await client.get("/") + + async def test_read_timeout_closes_connection(aiohttp_client: AiohttpClient) -> None: request_count = 0 @@ -3413,3 +3622,28 @@ async def not_ok_handler(request): "/ok", timeout=aiohttp.ClientTimeout(total=0.01) ) as resp_ok: assert 200 == resp_ok.status + + +@pytest.mark.parametrize( + ("value", "exc_type"), + [(42, TypeError), 
("InvalidUrl", InvalidURL)], +) +async def test_request_with_wrong_proxy( + aiohttp_client: AiohttpClient, value: Any, exc_type: Type[Exception] +) -> None: + app = web.Application() + session = await aiohttp_client(app) + + with pytest.raises(exc_type): + await session.get("/", proxy=value) # type: ignore[arg-type] + + +async def test_raise_for_status_is_none(aiohttp_client: AiohttpClient) -> None: + async def handler(_: web.Request) -> web.Response: + return web.Response() + + app = web.Application() + app.router.add_get("/", handler) + session = await aiohttp_client(app, raise_for_status=None) # type: ignore[arg-type] + + await session.get("/") diff --git a/tests/test_client_request.py b/tests/test_client_request.py index 6084f685405..7d9f69b52f0 100644 --- a/tests/test_client_request.py +++ b/tests/test_client_request.py @@ -996,8 +996,15 @@ async def gen(): req = ClientRequest("POST", URL("http://python.org/"), data=gen(), loop=loop) assert req.chunked assert req.headers["TRANSFER-ENCODING"] == "chunked" + original_write_bytes = req.write_bytes - resp = await req.send(conn) + async def _mock_write_bytes(*args, **kwargs): + # Ensure the task is scheduled + await asyncio.sleep(0) + return await original_write_bytes(*args, **kwargs) + + with mock.patch.object(req, "write_bytes", _mock_write_bytes): + resp = await req.send(conn) assert asyncio.isfuture(req._writer) await resp.wait_for_close() assert req._writer is None @@ -1020,9 +1027,7 @@ async def gen(writer): assert req.headers["TRANSFER-ENCODING"] == "chunked" resp = await req.send(conn) - assert asyncio.isfuture(req._writer) await resp.wait_for_close() - assert req._writer is None assert ( buf.split(b"\r\n\r\n", 1)[1] == b"b\r\nbinary data\r\n7\r\n result\r\n0\r\n\r\n" ) @@ -1203,14 +1208,28 @@ async def test_oserror_on_write_bytes(loop, conn) -> None: async def test_terminate(loop, conn) -> None: req = ClientRequest("get", URL("http://python.org"), loop=loop) - resp = await req.send(conn) + + async def 
_mock_write_bytes(*args, **kwargs): + # Ensure the task is scheduled + await asyncio.sleep(0) + + with mock.patch.object(req, "write_bytes", _mock_write_bytes): + resp = await req.send(conn) + assert req._writer is not None - writer = req._writer = WriterMock() + assert resp._writer is not None + await resp._writer + writer = WriterMock() + writer.done = mock.Mock(return_value=False) writer.cancel = mock.Mock() + req._writer = writer + resp._writer = writer + assert req._writer is not None + assert resp._writer is not None req.terminate() - assert req._writer is None writer.cancel.assert_called_with() + writer.done.assert_called_with() resp.close() await req.close() @@ -1222,9 +1241,19 @@ def test_terminate_with_closed_loop(loop, conn) -> None: async def go(): nonlocal req, resp, writer req = ClientRequest("get", URL("http://python.org")) - resp = await req.send(conn) + + async def _mock_write_bytes(*args, **kwargs): + # Ensure the task is scheduled + await asyncio.sleep(0) + + with mock.patch.object(req, "write_bytes", _mock_write_bytes): + resp = await req.send(conn) + assert req._writer is not None - writer = req._writer = WriterMock() + writer = WriterMock() + writer.done = mock.Mock(return_value=False) + req._writer = writer + resp._writer = writer await asyncio.sleep(0.05) diff --git a/tests/test_client_session.py b/tests/test_client_session.py index 416b6bbce5d..051c0aeba24 100644 --- a/tests/test_client_session.py +++ b/tests/test_client_session.py @@ -4,8 +4,9 @@ import io import json from http.cookies import SimpleCookie -from typing import Any, List +from typing import Any, Awaitable, Callable, List from unittest import mock +from uuid import uuid4 import pytest from multidict import CIMultiDict, MultiDict @@ -15,10 +16,12 @@ import aiohttp from aiohttp import client, hdrs, web from aiohttp.client import ClientSession +from aiohttp.client_proto import ResponseHandler from aiohttp.client_reqrep import ClientRequest -from aiohttp.connector import 
BaseConnector, TCPConnector +from aiohttp.connector import BaseConnector, Connection, TCPConnector, UnixConnector from aiohttp.helpers import DEBUG from aiohttp.test_utils import make_mocked_coro +from aiohttp.tracing import Trace @pytest.fixture @@ -471,7 +474,124 @@ async def create_connection(req, traces, timeout): c.__del__() -async def test_cookie_jar_usage(loop, aiohttp_client) -> None: +@pytest.mark.parametrize("protocol", ["http", "https", "ws", "wss"]) +async def test_ws_connect_allowed_protocols( + create_session: Any, + create_mocked_conn: Any, + protocol: str, + ws_key: Any, + key_data: Any, +) -> None: + resp = mock.create_autospec(aiohttp.ClientResponse) + resp.status = 101 + resp.headers = { + hdrs.UPGRADE: "websocket", + hdrs.CONNECTION: "upgrade", + hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, + } + resp.url = URL(f"{protocol}://example") + resp.cookies = SimpleCookie() + resp.start = mock.AsyncMock() + + req = mock.create_autospec(aiohttp.ClientRequest, spec_set=True) + req_factory = mock.Mock(return_value=req) + req.send = mock.AsyncMock(return_value=resp) + # BaseConnector allows all high level protocols by default + connector = BaseConnector() + + session = await create_session(connector=connector, request_class=req_factory) + + connections = [] + original_connect = session._connector.connect + + async def connect(req, traces, timeout): + conn = await original_connect(req, traces, timeout) + connections.append(conn) + return conn + + async def create_connection(req, traces, timeout): + return create_mocked_conn() + + connector = session._connector + with mock.patch.object(connector, "connect", connect), mock.patch.object( + connector, "_create_connection", create_connection + ), mock.patch.object(connector, "_release"), mock.patch( + "aiohttp.client.os" + ) as m_os: + m_os.urandom.return_value = key_data + await session.ws_connect(f"{protocol}://example") + + # normally called during garbage collection. 
triggers an exception + # if the connection wasn't already closed + for c in connections: + c.close() + c.__del__() + + await session.close() + + +@pytest.mark.parametrize("protocol", ["http", "https", "ws", "wss", "unix"]) +async def test_ws_connect_unix_socket_allowed_protocols( + create_session: Callable[..., Awaitable[ClientSession]], + create_mocked_conn: Callable[[], ResponseHandler], + protocol: str, + ws_key: bytes, + key_data: bytes, +) -> None: + resp = mock.create_autospec(aiohttp.ClientResponse) + resp.status = 101 + resp.headers = { + hdrs.UPGRADE: "websocket", + hdrs.CONNECTION: "upgrade", + hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, + } + resp.url = URL(f"{protocol}://example") + resp.cookies = SimpleCookie() + resp.start = mock.AsyncMock() + + req = mock.create_autospec(aiohttp.ClientRequest, spec_set=True) + req_factory = mock.Mock(return_value=req) + req.send = mock.AsyncMock(return_value=resp) + # UnixConnector allows all high level protocols by default and unix sockets + session = await create_session( + connector=UnixConnector(path=""), request_class=req_factory + ) + + connections = [] + assert session._connector is not None + original_connect = session._connector.connect + + async def connect( + req: ClientRequest, traces: List[Trace], timeout: aiohttp.ClientTimeout + ) -> Connection: + conn = await original_connect(req, traces, timeout) + connections.append(conn) + return conn + + async def create_connection( + req: object, traces: object, timeout: object + ) -> ResponseHandler: + return create_mocked_conn() + + connector = session._connector + with mock.patch.object(connector, "connect", connect), mock.patch.object( + connector, "_create_connection", create_connection + ), mock.patch.object(connector, "_release"), mock.patch( + "aiohttp.client.os" + ) as m_os: + m_os.urandom.return_value = key_data + await session.ws_connect(f"{protocol}://example") + + # normally called during garbage collection. 
triggers an exception + # if the connection wasn't already closed + for c in connections: + c.close() + c.__del__() + + await session.close() + + +async def test_cookie_jar_usage(loop: Any, aiohttp_client: Any) -> None: req_url = None jar = mock.Mock() @@ -895,3 +1015,23 @@ async def test_instantiation_with_invalid_timeout_value(loop): ClientSession(timeout=1) # should not have "Unclosed client session" warning assert not logs + + +@pytest.mark.parametrize( + ("outer_name", "inner_name"), + [ + ("skip_auto_headers", "_skip_auto_headers"), + ("auth", "_default_auth"), + ("json_serialize", "_json_serialize"), + ("connector_owner", "_connector_owner"), + ("raise_for_status", "_raise_for_status"), + ("trust_env", "_trust_env"), + ("trace_configs", "_trace_configs"), + ], +) +async def test_properties( + session: ClientSession, outer_name: str, inner_name: str +) -> None: + value = uuid4() + setattr(session, inner_name, value) + assert value == getattr(session, outer_name) diff --git a/tests/test_client_ws.py b/tests/test_client_ws.py index f0b7757e420..a790fba43ec 100644 --- a/tests/test_client_ws.py +++ b/tests/test_client_ws.py @@ -2,33 +2,20 @@ import base64 import hashlib import os +from typing import Any from unittest import mock import pytest import aiohttp from aiohttp import client, hdrs +from aiohttp.client_exceptions import ServerDisconnectedError from aiohttp.http import WS_KEY from aiohttp.streams import EofStream from aiohttp.test_utils import make_mocked_coro -@pytest.fixture -def key_data(): - return os.urandom(16) - - -@pytest.fixture -def key(key_data): - return base64.b64encode(key_data) - - -@pytest.fixture -def ws_key(key): - return base64.b64encode(hashlib.sha1(key + WS_KEY).digest()).decode() - - -async def test_ws_connect(ws_key, loop, key_data) -> None: +async def test_ws_connect(ws_key: Any, loop: Any, key_data: Any) -> None: resp = mock.Mock() resp.status = 101 resp.headers = { @@ -37,6 +24,7 @@ async def test_ws_connect(ws_key, loop, 
key_data) -> None: hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, hdrs.SEC_WEBSOCKET_PROTOCOL: "chat", } + resp.connection.protocol.read_timeout = None with mock.patch("aiohttp.client.os") as m_os: with mock.patch("aiohttp.client.ClientSession.request") as m_req: m_os.urandom.return_value = key_data @@ -52,6 +40,97 @@ async def test_ws_connect(ws_key, loop, key_data) -> None: assert hdrs.ORIGIN not in m_req.call_args[1]["headers"] +async def test_ws_connect_read_timeout_is_reset_to_inf( + ws_key: Any, loop: Any, key_data: Any +) -> None: + resp = mock.Mock() + resp.status = 101 + resp.headers = { + hdrs.UPGRADE: "websocket", + hdrs.CONNECTION: "upgrade", + hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, + hdrs.SEC_WEBSOCKET_PROTOCOL: "chat", + } + resp.connection.protocol.read_timeout = 0.5 + with mock.patch("aiohttp.client.os") as m_os, mock.patch( + "aiohttp.client.ClientSession.request" + ) as m_req: + m_os.urandom.return_value = key_data + m_req.return_value = loop.create_future() + m_req.return_value.set_result(resp) + + res = await aiohttp.ClientSession().ws_connect( + "http://test.org", protocols=("t1", "t2", "chat") + ) + + assert isinstance(res, client.ClientWebSocketResponse) + assert res.protocol == "chat" + assert hdrs.ORIGIN not in m_req.call_args[1]["headers"] + assert resp.connection.protocol.read_timeout is None + + +async def test_ws_connect_read_timeout_stays_inf( + ws_key: Any, loop: Any, key_data: Any +) -> None: + resp = mock.Mock() + resp.status = 101 + resp.headers = { + hdrs.UPGRADE: "websocket", + hdrs.CONNECTION: "upgrade", + hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, + hdrs.SEC_WEBSOCKET_PROTOCOL: "chat", + } + resp.connection.protocol.read_timeout = None + with mock.patch("aiohttp.client.os") as m_os, mock.patch( + "aiohttp.client.ClientSession.request" + ) as m_req: + m_os.urandom.return_value = key_data + m_req.return_value = loop.create_future() + m_req.return_value.set_result(resp) + + res = await aiohttp.ClientSession().ws_connect( + "http://test.org", + 
protocols=("t1", "t2", "chat"), + receive_timeout=0.5, + ) + + assert isinstance(res, client.ClientWebSocketResponse) + assert res.protocol == "chat" + assert hdrs.ORIGIN not in m_req.call_args[1]["headers"] + assert resp.connection.protocol.read_timeout is None + + +async def test_ws_connect_read_timeout_reset_to_max( + ws_key: Any, loop: Any, key_data: Any +) -> None: + resp = mock.Mock() + resp.status = 101 + resp.headers = { + hdrs.UPGRADE: "websocket", + hdrs.CONNECTION: "upgrade", + hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, + hdrs.SEC_WEBSOCKET_PROTOCOL: "chat", + } + resp.connection.protocol.read_timeout = 0.5 + with mock.patch("aiohttp.client.os") as m_os, mock.patch( + "aiohttp.client.ClientSession.request" + ) as m_req: + m_os.urandom.return_value = key_data + m_req.return_value = loop.create_future() + m_req.return_value.set_result(resp) + + res = await aiohttp.ClientSession().ws_connect( + "http://test.org", + protocols=("t1", "t2", "chat"), + receive_timeout=1.0, + ) + + assert isinstance(res, client.ClientWebSocketResponse) + assert res.protocol == "chat" + assert hdrs.ORIGIN not in m_req.call_args[1]["headers"] + assert resp.connection.protocol.read_timeout == 1.0 + + async def test_ws_connect_with_origin(key_data, loop) -> None: resp = mock.Mock() resp.status = 403 @@ -82,6 +161,7 @@ async def test_ws_connect_with_params(ws_key, loop, key_data) -> None: hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, hdrs.SEC_WEBSOCKET_PROTOCOL: "chat", } + resp.connection.protocol.read_timeout = None with mock.patch("aiohttp.client.os") as m_os: with mock.patch("aiohttp.client.ClientSession.request") as m_req: m_os.urandom.return_value = key_data @@ -107,6 +187,7 @@ def read(self, decode=False): hdrs.CONNECTION: "upgrade", hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, } + resp.connection.protocol.read_timeout = None with mock.patch("aiohttp.client.os") as m_os: with mock.patch("aiohttp.client.ClientSession.request") as m_req: m_os.urandom.return_value = key_data @@ -229,6 +310,7 @@ async def 
mock_get(*args, **kwargs): hdrs.SEC_WEBSOCKET_ACCEPT: accept, hdrs.SEC_WEBSOCKET_PROTOCOL: "chat", } + resp.connection.protocol.read_timeout = None return resp with mock.patch("aiohttp.client.os") as m_os: @@ -259,6 +341,7 @@ async def test_close(loop, ws_key, key_data) -> None: hdrs.CONNECTION: "upgrade", hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, } + resp.connection.protocol.read_timeout = None with mock.patch("aiohttp.client.WebSocketWriter") as WebSocketWriter: with mock.patch("aiohttp.client.os") as m_os: with mock.patch("aiohttp.client.ClientSession.request") as m_req: @@ -299,6 +382,7 @@ async def test_close_eofstream(loop, ws_key, key_data) -> None: hdrs.CONNECTION: "upgrade", hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, } + resp.connection.protocol.read_timeout = None with mock.patch("aiohttp.client.WebSocketWriter") as WebSocketWriter: with mock.patch("aiohttp.client.os") as m_os: with mock.patch("aiohttp.client.ClientSession.request") as m_req: @@ -321,20 +405,56 @@ async def test_close_eofstream(loop, ws_key, key_data) -> None: await session.close() -async def test_close_exc(loop, ws_key, key_data) -> None: - resp = mock.Mock() - resp.status = 101 - resp.headers = { +async def test_close_connection_lost( + loop: asyncio.AbstractEventLoop, ws_key: bytes, key_data: bytes +) -> None: + """Test the websocket client handles the connection being closed out from under it.""" + mresp = mock.Mock(spec_set=client.ClientResponse) + mresp.status = 101 + mresp.headers = { hdrs.UPGRADE: "websocket", hdrs.CONNECTION: "upgrade", hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, } + mresp.connection.protocol.read_timeout = None + with mock.patch("aiohttp.client.WebSocketWriter"), mock.patch( + "aiohttp.client.os" + ) as m_os, mock.patch("aiohttp.client.ClientSession.request") as m_req: + m_os.urandom.return_value = key_data + m_req.return_value = loop.create_future() + m_req.return_value.set_result(mresp) + + session = aiohttp.ClientSession() + resp = await session.ws_connect("http://test.org") + 
assert not resp.closed + + exc = ServerDisconnectedError() + resp._reader.set_exception(exc) + + msg = await resp.receive() + assert msg.type is aiohttp.WSMsgType.CLOSED + assert resp.closed + + await session.close() + + +async def test_close_exc( + loop: asyncio.AbstractEventLoop, ws_key: bytes, key_data: bytes +) -> None: + mresp = mock.Mock() + mresp.status = 101 + mresp.headers = { + hdrs.UPGRADE: "websocket", + hdrs.CONNECTION: "upgrade", + hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, + } + mresp.connection.protocol.read_timeout = None with mock.patch("aiohttp.client.WebSocketWriter") as WebSocketWriter: with mock.patch("aiohttp.client.os") as m_os: with mock.patch("aiohttp.client.ClientSession.request") as m_req: m_os.urandom.return_value = key_data m_req.return_value = loop.create_future() - m_req.return_value.set_result(resp) + m_req.return_value.set_result(mresp) writer = mock.Mock() WebSocketWriter.return_value = writer writer.close = make_mocked_coro() @@ -361,6 +481,7 @@ async def test_close_exc2(loop, ws_key, key_data) -> None: hdrs.CONNECTION: "upgrade", hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, } + resp.connection.protocol.read_timeout = None with mock.patch("aiohttp.client.WebSocketWriter") as WebSocketWriter: with mock.patch("aiohttp.client.os") as m_os: with mock.patch("aiohttp.client.ClientSession.request") as m_req: @@ -395,6 +516,7 @@ async def test_send_data_after_close(ws_key, key_data, loop) -> None: hdrs.CONNECTION: "upgrade", hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, } + resp.connection.protocol.read_timeout = None with mock.patch("aiohttp.client.os") as m_os: with mock.patch("aiohttp.client.ClientSession.request") as m_req: m_os.urandom.return_value = key_data @@ -423,6 +545,7 @@ async def test_send_data_type_errors(ws_key, key_data, loop) -> None: hdrs.CONNECTION: "upgrade", hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, } + resp.connection.protocol.read_timeout = None with mock.patch("aiohttp.client.WebSocketWriter") as WebSocketWriter: with 
mock.patch("aiohttp.client.os") as m_os: with mock.patch("aiohttp.client.ClientSession.request") as m_req: @@ -451,6 +574,7 @@ async def test_reader_read_exception(ws_key, key_data, loop) -> None: hdrs.CONNECTION: "upgrade", hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, } + hresp.connection.protocol.read_timeout = None with mock.patch("aiohttp.client.WebSocketWriter") as WebSocketWriter: with mock.patch("aiohttp.client.os") as m_os: with mock.patch("aiohttp.client.ClientSession.request") as m_req: @@ -515,6 +639,7 @@ async def test_ws_connect_non_overlapped_protocols(ws_key, loop, key_data) -> No hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, hdrs.SEC_WEBSOCKET_PROTOCOL: "other,another", } + resp.connection.protocol.read_timeout = None with mock.patch("aiohttp.client.os") as m_os: with mock.patch("aiohttp.client.ClientSession.request") as m_req: m_os.urandom.return_value = key_data @@ -537,6 +662,7 @@ async def test_ws_connect_non_overlapped_protocols_2(ws_key, loop, key_data) -> hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, hdrs.SEC_WEBSOCKET_PROTOCOL: "other,another", } + resp.connection.protocol.read_timeout = None with mock.patch("aiohttp.client.os") as m_os: with mock.patch("aiohttp.client.ClientSession.request") as m_req: m_os.urandom.return_value = key_data @@ -561,6 +687,7 @@ async def test_ws_connect_deflate(loop, ws_key, key_data) -> None: hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, hdrs.SEC_WEBSOCKET_EXTENSIONS: "permessage-deflate", } + resp.connection.protocol.read_timeout = None with mock.patch("aiohttp.client.os") as m_os: with mock.patch("aiohttp.client.ClientSession.request") as m_req: m_os.urandom.return_value = key_data @@ -584,6 +711,7 @@ async def test_ws_connect_deflate_per_message(loop, ws_key, key_data) -> None: hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, hdrs.SEC_WEBSOCKET_EXTENSIONS: "permessage-deflate", } + resp.connection.protocol.read_timeout = None with mock.patch("aiohttp.client.WebSocketWriter") as WebSocketWriter: with mock.patch("aiohttp.client.os") as m_os: with 
mock.patch("aiohttp.client.ClientSession.request") as m_req: @@ -616,6 +744,7 @@ async def test_ws_connect_deflate_server_not_support(loop, ws_key, key_data) -> hdrs.CONNECTION: "upgrade", hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, } + resp.connection.protocol.read_timeout = None with mock.patch("aiohttp.client.os") as m_os: with mock.patch("aiohttp.client.ClientSession.request") as m_req: m_os.urandom.return_value = key_data @@ -640,6 +769,7 @@ async def test_ws_connect_deflate_notakeover(loop, ws_key, key_data) -> None: hdrs.SEC_WEBSOCKET_EXTENSIONS: "permessage-deflate; " "client_no_context_takeover", } + resp.connection.protocol.read_timeout = None with mock.patch("aiohttp.client.os") as m_os: with mock.patch("aiohttp.client.ClientSession.request") as m_req: m_os.urandom.return_value = key_data @@ -664,6 +794,7 @@ async def test_ws_connect_deflate_client_wbits(loop, ws_key, key_data) -> None: hdrs.SEC_WEBSOCKET_EXTENSIONS: "permessage-deflate; " "client_max_window_bits=10", } + resp.connection.protocol.read_timeout = None with mock.patch("aiohttp.client.os") as m_os: with mock.patch("aiohttp.client.ClientSession.request") as m_req: m_os.urandom.return_value = key_data diff --git a/tests/test_client_ws_functional.py b/tests/test_client_ws_functional.py index 6270675276e..907ae232e9a 100644 --- a/tests/test_client_ws_functional.py +++ b/tests/test_client_ws_functional.py @@ -1,11 +1,14 @@ import asyncio import sys +from typing import Any, NoReturn +from unittest import mock import pytest import aiohttp -from aiohttp import hdrs, web +from aiohttp import ServerTimeoutError, WSMsgType, hdrs, web from aiohttp.http import WSCloseCode +from aiohttp.pytest_plugin import AiohttpClient if sys.version_info >= (3, 11): import asyncio as async_timeout @@ -245,7 +248,7 @@ async def handler(request): await client_ws.close() msg = await ws.receive() - assert msg.type == aiohttp.WSMsgType.CLOSE + assert msg.type is aiohttp.WSMsgType.CLOSE return ws app = web.Application() @@ -256,11 
+259,43 @@ async def handler(request): await ws.send_bytes(b"ask") msg = await ws.receive() - assert msg.type == aiohttp.WSMsgType.CLOSING + assert msg.type is aiohttp.WSMsgType.CLOSING await asyncio.sleep(0.01) msg = await ws.receive() - assert msg.type == aiohttp.WSMsgType.CLOSED + assert msg.type is aiohttp.WSMsgType.CLOSED + + +async def test_concurrent_close_multiple_tasks(aiohttp_client: Any) -> None: + async def handler(request): + ws = web.WebSocketResponse() + await ws.prepare(request) + + await ws.receive_bytes() + await ws.send_str("test") + + msg = await ws.receive() + assert msg.type is aiohttp.WSMsgType.CLOSE + return ws + + app = web.Application() + app.router.add_route("GET", "/", handler) + client = await aiohttp_client(app) + ws = await client.ws_connect("/") + + await ws.send_bytes(b"ask") + + task1 = asyncio.create_task(ws.close()) + task2 = asyncio.create_task(ws.close()) + + msg = await ws.receive() + assert msg.type is aiohttp.WSMsgType.CLOSED + + await task1 + await task2 + + msg = await ws.receive() + assert msg.type is aiohttp.WSMsgType.CLOSED async def test_concurrent_task_close(aiohttp_client) -> None: @@ -565,7 +600,8 @@ async def handler(request): assert ping_received -async def test_heartbeat_no_pong(aiohttp_client) -> None: +async def test_heartbeat_no_pong(aiohttp_client: AiohttpClient) -> None: + """Test that the connection is closed if no pong is received without sending messages.""" ping_received = False async def handler(request): @@ -590,8 +626,155 @@ async def handler(request): assert resp.close_code is WSCloseCode.ABNORMAL_CLOSURE -async def test_send_recv_compress(aiohttp_client) -> None: +async def test_heartbeat_no_pong_after_receive_many_messages( + aiohttp_client: AiohttpClient, +) -> None: + """Test that the connection is closed if no pong is received after receiving many messages.""" + ping_received = False + + async def handler(request: web.Request) -> NoReturn: + nonlocal ping_received + ws = 
web.WebSocketResponse(autoping=False) + await ws.prepare(request) + for _ in range(5): + await ws.send_str("test") + await asyncio.sleep(0.05) + for _ in range(5): + await ws.send_str("test") + msg = await ws.receive() + ping_received = msg.type is aiohttp.WSMsgType.PING + await ws.receive() + assert False + + app = web.Application() + app.router.add_route("GET", "/", handler) + + client = await aiohttp_client(app) + resp = await client.ws_connect("/", heartbeat=0.1) + + for _ in range(10): + test_msg = await resp.receive() + assert test_msg.data == "test" + # Connection should be closed roughly after 1.5x heartbeat. + + await asyncio.sleep(0.2) + assert ping_received + assert resp.close_code is WSCloseCode.ABNORMAL_CLOSURE + + +async def test_heartbeat_no_pong_after_send_many_messages( + aiohttp_client: AiohttpClient, +) -> None: + """Test that the connection is closed if no pong is received after sending many messages.""" + ping_received = False + + async def handler(request: web.Request) -> NoReturn: + nonlocal ping_received + ws = web.WebSocketResponse(autoping=False) + await ws.prepare(request) + for _ in range(10): + msg = await ws.receive() + assert msg.data == "test" + assert msg.type is aiohttp.WSMsgType.TEXT + msg = await ws.receive() + ping_received = msg.type is aiohttp.WSMsgType.PING + await ws.receive() + assert False + + app = web.Application() + app.router.add_route("GET", "/", handler) + + client = await aiohttp_client(app) + resp = await client.ws_connect("/", heartbeat=0.1) + + for _ in range(5): + await resp.send_str("test") + await asyncio.sleep(0.05) + for _ in range(5): + await resp.send_str("test") + # Connection should be closed roughly after 1.5x heartbeat. 
+ await asyncio.sleep(0.2) + assert ping_received + assert resp.close_code is WSCloseCode.ABNORMAL_CLOSURE + + +async def test_heartbeat_no_pong_concurrent_receive( + aiohttp_client: AiohttpClient, +) -> None: + ping_received = False + async def handler(request): + nonlocal ping_received + ws = web.WebSocketResponse(autoping=False) + await ws.prepare(request) + msg = await ws.receive() + ping_received = msg.type is aiohttp.WSMsgType.PING + ws._reader.feed_eof = lambda: None + await asyncio.sleep(10.0) + + app = web.Application() + app.router.add_route("GET", "/", handler) + + client = await aiohttp_client(app) + resp = await client.ws_connect("/", heartbeat=0.1) + resp._reader.feed_eof = lambda: None + + # Connection should be closed roughly after 1.5x heartbeat. + msg = await resp.receive(5.0) + assert ping_received + assert resp.close_code is WSCloseCode.ABNORMAL_CLOSURE + assert msg + assert msg.type is WSMsgType.ERROR + assert isinstance(msg.data, ServerTimeoutError) + + +async def test_close_websocket_while_ping_inflight( + aiohttp_client: AiohttpClient, +) -> None: + """Test closing the websocket while a ping is in-flight.""" + ping_received = False + + async def handler(request: web.Request) -> NoReturn: + nonlocal ping_received + ws = web.WebSocketResponse(autoping=False) + await ws.prepare(request) + msg = await ws.receive() + assert msg.type is aiohttp.WSMsgType.BINARY + msg = await ws.receive() + ping_received = msg.type is aiohttp.WSMsgType.PING + await ws.receive() + assert False + + app = web.Application() + app.router.add_route("GET", "/", handler) + + client = await aiohttp_client(app) + resp = await client.ws_connect("/", heartbeat=0.1) + await resp.send_bytes(b"ask") + + cancelled = False + ping_stated = False + + async def delayed_ping() -> None: + nonlocal cancelled, ping_stated + ping_stated = True + try: + await asyncio.sleep(1) + except asyncio.CancelledError: + cancelled = True + raise + + with mock.patch.object(resp._writer, "ping", 
delayed_ping): + await asyncio.sleep(0.1) + + await resp.close() + await asyncio.sleep(0) + assert ping_stated is True + assert cancelled is True + + +async def test_send_recv_compress(aiohttp_client: AiohttpClient) -> None: + async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse() await ws.prepare(request) @@ -835,3 +1018,11 @@ async def handler(request): assert "answer" == msg.data await resp.close() + + +async def test_ws_connect_with_wrong_ssl_type(aiohttp_client: AiohttpClient) -> None: + app = web.Application() + session = await aiohttp_client(app) + + with pytest.raises(TypeError, match="ssl should be SSLContext, .*"): + await session.ws_connect("/", ssl=42) diff --git a/tests/test_connector.py b/tests/test_connector.py index 02e48bc108b..d146fb4ee51 100644 --- a/tests/test_connector.py +++ b/tests/test_connector.py @@ -10,9 +10,11 @@ import uuid from collections import deque from contextlib import closing +from typing import Any, List, Optional from unittest import mock import pytest +from aiohappyeyeballs import AddrInfoType from yarl import URL import aiohttp @@ -539,7 +541,9 @@ async def test__drop_acquire_per_host3(loop) -> None: assert conn._acquired_per_host[123] == {789} -async def test_tcp_connector_certificate_error(loop) -> None: +async def test_tcp_connector_certificate_error( + loop: Any, start_connection: mock.AsyncMock +) -> None: req = ClientRequest("GET", URL("https://127.0.0.1:443"), loop=loop) async def certificate_error(*args, **kwargs): @@ -556,8 +560,10 @@ async def certificate_error(*args, **kwargs): assert isinstance(ctx.value, aiohttp.ClientSSLError) -async def test_tcp_connector_server_hostname_default(loop) -> None: - conn = aiohttp.TCPConnector(loop=loop) +async def test_tcp_connector_server_hostname_default( + loop: Any, start_connection: mock.AsyncMock +) -> None: + conn = aiohttp.TCPConnector() with mock.patch.object( conn._loop, "create_connection", autospec=True, spec_set=True @@ -570,8 
+576,10 @@ async def test_tcp_connector_server_hostname_default(loop) -> None: assert create_connection.call_args.kwargs["server_hostname"] == "127.0.0.1" -async def test_tcp_connector_server_hostname_override(loop) -> None: - conn = aiohttp.TCPConnector(loop=loop) +async def test_tcp_connector_server_hostname_override( + loop: Any, start_connection: mock.AsyncMock +) -> None: + conn = aiohttp.TCPConnector() with mock.patch.object( conn._loop, "create_connection", autospec=True, spec_set=True @@ -595,6 +603,7 @@ async def test_tcp_connector_multiple_hosts_errors(loop) -> None: ip4 = "192.168.1.4" ip5 = "192.168.1.5" ips = [ip1, ip2, ip3, ip4, ip5] + addrs_tried = [] ips_tried = [] fingerprint = hashlib.sha256(b"foo").digest() @@ -624,11 +633,24 @@ async def _resolve_host(host, port, traces=None): os_error = certificate_error = ssl_error = fingerprint_error = False connected = False + async def start_connection(*args, **kwargs): + addr_infos: List[AddrInfoType] = kwargs["addr_infos"] + + first_addr_info = addr_infos[0] + first_addr_info_addr = first_addr_info[-1] + addrs_tried.append(first_addr_info_addr) + + mock_socket = mock.create_autospec(socket.socket, spec_set=True, instance=True) + mock_socket.getpeername.return_value = first_addr_info_addr + return mock_socket + async def create_connection(*args, **kwargs): nonlocal os_error, certificate_error, ssl_error, fingerprint_error nonlocal connected - ip = args[1] + sock = kwargs["sock"] + addr_info = sock.getpeername() + ip = addr_info[0] ips_tried.append(ip) @@ -645,6 +667,12 @@ async def create_connection(*args, **kwargs): raise ssl.SSLError if ip == ip4: + sock: socket.socket = kwargs["sock"] + + # Close the socket since we are not actually connecting + # and we don't want to leak it. 
+ sock.close() + fingerprint_error = True tr, pr = mock.Mock(), mock.Mock() @@ -660,12 +688,21 @@ def get_extra_info(param): if param == "peername": return ("192.168.1.5", 12345) + if param == "socket": + return sock + assert False, param tr.get_extra_info = get_extra_info return tr, pr if ip == ip5: + sock: socket.socket = kwargs["sock"] + + # Close the socket since we are not actually connecting + # and we don't want to leak it. + sock.close() + connected = True tr, pr = mock.Mock(), mock.Mock() @@ -687,8 +724,13 @@ def get_extra_info(param): conn._loop.create_connection = create_connection - established_connection = await conn.connect(req, [], ClientTimeout()) - assert ips == ips_tried + with mock.patch( + "aiohttp.connector.aiohappyeyeballs.start_connection", start_connection + ): + established_connection = await conn.connect(req, [], ClientTimeout()) + + assert ips_tried == ips + assert addrs_tried == [(ip, 443) for ip in ips] assert os_error assert certificate_error @@ -699,8 +741,214 @@ def get_extra_info(param): established_connection.close() -async def test_tcp_connector_resolve_host(loop) -> None: - conn = aiohttp.TCPConnector(loop=loop, use_dns_cache=True) +@pytest.mark.parametrize( + ("happy_eyeballs_delay"), + [0.1, 0.25, None], +) +async def test_tcp_connector_happy_eyeballs( + loop: Any, happy_eyeballs_delay: Optional[float] +) -> None: + conn = aiohttp.TCPConnector(happy_eyeballs_delay=happy_eyeballs_delay) + + ip1 = "dead::beef::" + ip2 = "192.168.1.1" + ips = [ip1, ip2] + addrs_tried = [] + + req = ClientRequest( + "GET", + URL("https://mocked.host"), + loop=loop, + ) + + async def _resolve_host(host, port, traces=None): + return [ + { + "hostname": host, + "host": ip, + "port": port, + "family": socket.AF_INET6 if ":" in ip else socket.AF_INET, + "proto": 0, + "flags": socket.AI_NUMERICHOST, + } + for ip in ips + ] + + conn._resolve_host = _resolve_host + + os_error = False + connected = False + + async def sock_connect(*args, **kwargs): + addr = 
args[1] + nonlocal os_error + + addrs_tried.append(addr) + + if addr[0] == ip1: + os_error = True + raise OSError + + async def create_connection(*args, **kwargs): + sock: socket.socket = kwargs["sock"] + + # Close the socket since we are not actually connecting + # and we don't want to leak it. + sock.close() + + nonlocal connected + connected = True + tr = create_mocked_conn(loop) + pr = create_mocked_conn(loop) + return tr, pr + + conn._loop.sock_connect = sock_connect + conn._loop.create_connection = create_connection + + established_connection = await conn.connect(req, [], ClientTimeout()) + + assert addrs_tried == [(ip1, 443, 0, 0), (ip2, 443)] + + assert os_error + assert connected + + established_connection.close() + + +async def test_tcp_connector_interleave(loop: Any) -> None: + conn = aiohttp.TCPConnector(interleave=2) + + ip1 = "192.168.1.1" + ip2 = "192.168.1.2" + ip3 = "dead::beef::" + ip4 = "aaaa::beef::" + ip5 = "192.168.1.5" + ips = [ip1, ip2, ip3, ip4, ip5] + success_ips = [] + interleave = None + + req = ClientRequest( + "GET", + URL("https://mocked.host"), + loop=loop, + ) + + async def _resolve_host(host, port, traces=None): + return [ + { + "hostname": host, + "host": ip, + "port": port, + "family": socket.AF_INET6 if ":" in ip else socket.AF_INET, + "proto": 0, + "flags": socket.AI_NUMERICHOST, + } + for ip in ips + ] + + conn._resolve_host = _resolve_host + + async def start_connection(*args, **kwargs): + nonlocal interleave + addr_infos: List[AddrInfoType] = kwargs["addr_infos"] + interleave = kwargs["interleave"] + # Mock the 4th host connecting successfully + fourth_addr_info = addr_infos[3] + fourth_addr_info_addr = fourth_addr_info[-1] + mock_socket = mock.create_autospec(socket.socket, spec_set=True, instance=True) + mock_socket.getpeername.return_value = fourth_addr_info_addr + return mock_socket + + async def create_connection(*args, **kwargs): + sock = kwargs["sock"] + addr_info = sock.getpeername() + ip = addr_info[0] + + 
success_ips.append(ip) + + sock: socket.socket = kwargs["sock"] + # Close the socket since we are not actually connecting + # and we don't want to leak it. + sock.close() + tr = create_mocked_conn(loop) + pr = create_mocked_conn(loop) + return tr, pr + + conn._loop.create_connection = create_connection + + with mock.patch( + "aiohttp.connector.aiohappyeyeballs.start_connection", start_connection + ): + established_connection = await conn.connect(req, [], ClientTimeout()) + + assert success_ips == [ip4] + assert interleave == 2 + established_connection.close() + + +async def test_tcp_connector_family_is_respected(loop: Any) -> None: + conn = aiohttp.TCPConnector(family=socket.AF_INET) + + ip1 = "dead::beef::" + ip2 = "192.168.1.1" + ips = [ip1, ip2] + addrs_tried = [] + + req = ClientRequest( + "GET", + URL("https://mocked.host"), + loop=loop, + ) + + async def _resolve_host(host, port, traces=None): + return [ + { + "hostname": host, + "host": ip, + "port": port, + "family": socket.AF_INET6 if ":" in ip else socket.AF_INET, + "proto": 0, + "flags": socket.AI_NUMERICHOST, + } + for ip in ips + ] + + conn._resolve_host = _resolve_host + connected = False + + async def sock_connect(*args, **kwargs): + addr = args[1] + addrs_tried.append(addr) + + async def create_connection(*args, **kwargs): + sock: socket.socket = kwargs["sock"] + + # Close the socket since we are not actually connecting + # and we don't want to leak it. 
+ sock.close() + + nonlocal connected + connected = True + tr = create_mocked_conn(loop) + pr = create_mocked_conn(loop) + return tr, pr + + conn._loop.sock_connect = sock_connect + conn._loop.create_connection = create_connection + + established_connection = await conn.connect(req, [], ClientTimeout()) + + # We should only try the IPv4 address since we specified + # the family to be AF_INET + assert addrs_tried == [(ip2, 443)] + + assert connected + + established_connection.close() + + +async def test_tcp_connector_resolve_host(loop: Any) -> None: + conn = aiohttp.TCPConnector(use_dns_cache=True) res = await conn._resolve_host("localhost", 8080) assert res @@ -1233,7 +1481,19 @@ async def test_tcp_connector_ctor() -> None: assert conn.family == 0 -async def test_tcp_connector_ctor_fingerprint_valid(loop) -> None: +async def test_tcp_connector_allowed_protocols(loop: asyncio.AbstractEventLoop) -> None: + conn = aiohttp.TCPConnector() + assert conn.allowed_protocol_schema_set == {"", "tcp", "http", "https", "ws", "wss"} + + +async def test_invalid_ssl_param() -> None: + with pytest.raises(TypeError): + aiohttp.TCPConnector(ssl=object()) # type: ignore[arg-type] + + +async def test_tcp_connector_ctor_fingerprint_valid( + loop: asyncio.AbstractEventLoop, +) -> None: valid = aiohttp.Fingerprint(hashlib.sha256(b"foo").digest()) conn = aiohttp.TCPConnector(ssl=valid, loop=loop) assert conn._ssl is valid @@ -1391,8 +1651,23 @@ async def test_ctor_with_default_loop(loop) -> None: assert loop is conn._loop -async def test_connect_with_limit(loop, key) -> None: - proto = mock.Mock() +async def test_base_connector_allows_high_level_protocols( + loop: asyncio.AbstractEventLoop, +) -> None: + conn = aiohttp.BaseConnector() + assert conn.allowed_protocol_schema_set == { + "", + "http", + "https", + "ws", + "wss", + } + + +async def test_connect_with_limit( + loop: asyncio.AbstractEventLoop, key: ConnectionKey +) -> None: + proto = create_mocked_conn(loop) 
proto.is_connected.return_value = True req = ClientRequest( @@ -2047,7 +2322,8 @@ async def handler(request): session = aiohttp.ClientSession(connector=conn) url = srv.make_url("/") - with pytest.raises(aiohttp.ClientConnectorCertificateError) as ctx: + err = aiohttp.ClientConnectorCertificateError + with pytest.raises(err) as ctx: await session.get(url) assert isinstance(ctx.value, aiohttp.ClientConnectorCertificateError) @@ -2163,6 +2439,14 @@ async def handler(request): connector = aiohttp.UnixConnector(unix_sockname) assert unix_sockname == connector.path + assert connector.allowed_protocol_schema_set == { + "", + "http", + "https", + "ws", + "wss", + "unix", + } session = client.ClientSession(connector=connector) r = await session.get(url) @@ -2188,6 +2472,14 @@ async def handler(request): connector = aiohttp.NamedPipeConnector(pipe_name) assert pipe_name == connector.path + assert connector.allowed_protocol_schema_set == { + "", + "http", + "https", + "ws", + "wss", + "npipe", + } session = client.ClientSession(connector=connector) r = await session.get(url) diff --git a/tests/test_cookiejar.py b/tests/test_cookiejar.py index 9c608959c39..91352f50c3d 100644 --- a/tests/test_cookiejar.py +++ b/tests/test_cookiejar.py @@ -153,28 +153,6 @@ def test_domain_matching() -> None: assert not test_func("test.com", "127.0.0.1") -def test_path_matching() -> None: - test_func = CookieJar._is_path_match - - assert test_func("/", "") - assert test_func("", "/") - assert test_func("/file", "") - assert test_func("/folder/file", "") - assert test_func("/", "/") - assert test_func("/file", "/") - assert test_func("/file", "/file") - assert test_func("/folder/", "/folder/") - assert test_func("/folder/", "/") - assert test_func("/folder/file", "/") - - assert not test_func("/", "/file") - assert not test_func("/", "/folder/") - assert not test_func("/file", "/folder/file") - assert not test_func("/folder/", "/folder/file") - assert not test_func("/different-file", "/file") - 
assert not test_func("/different-folder/", "/folder/") - - async def test_constructor(loop, cookies_to_send, cookies_to_receive) -> None: jar = CookieJar(loop=loop) jar.update_cookies(cookies_to_send) @@ -243,8 +221,98 @@ async def test_filter_cookie_with_unicode_domain(loop) -> None: assert len(jar.filter_cookies(URL("http://xn--9caa.com"))) == 1 -async def test_domain_filter_ip_cookie_send(loop) -> None: - jar = CookieJar(loop=loop) +@pytest.mark.parametrize( + ("url", "expected_cookies"), + ( + ( + "http://pathtest.com/one/two/", + ( + "no-path-cookie", + "path1-cookie", + "path2-cookie", + "shared-cookie", + "path3-cookie", + "path4-cookie", + ), + ), + ( + "http://pathtest.com/one/two", + ( + "no-path-cookie", + "path1-cookie", + "path2-cookie", + "shared-cookie", + "path3-cookie", + ), + ), + ( + "http://pathtest.com/one/two/three/", + ( + "no-path-cookie", + "path1-cookie", + "path2-cookie", + "shared-cookie", + "path3-cookie", + "path4-cookie", + ), + ), + ( + "http://test1.example.com/", + ( + "shared-cookie", + "domain-cookie", + "subdomain1-cookie", + "dotted-domain-cookie", + ), + ), + ( + "http://pathtest.com/", + ( + "shared-cookie", + "no-path-cookie", + "path1-cookie", + ), + ), + ), +) +async def test_filter_cookies_with_domain_path_lookup_multilevelpath( + loop, + url, + expected_cookies, +) -> None: + jar = CookieJar() + cookies = SimpleCookie( + "shared-cookie=first; " + "domain-cookie=second; Domain=example.com; " + "subdomain1-cookie=third; Domain=test1.example.com; " + "subdomain2-cookie=fourth; Domain=test2.example.com; " + "dotted-domain-cookie=fifth; Domain=.example.com; " + "different-domain-cookie=sixth; Domain=different.org; " + "secure-cookie=seventh; Domain=secure.com; Secure; " + "no-path-cookie=eighth; Domain=pathtest.com; " + "path1-cookie=ninth; Domain=pathtest.com; Path=/; " + "path2-cookie=tenth; Domain=pathtest.com; Path=/one; " + "path3-cookie=eleventh; Domain=pathtest.com; Path=/one/two; " + "path4-cookie=twelfth; 
Domain=pathtest.com; Path=/one/two/; " + "expires-cookie=thirteenth; Domain=expirestest.com; Path=/;" + " Expires=Tue, 1 Jan 1980 12:00:00 GMT; " + "max-age-cookie=fourteenth; Domain=maxagetest.com; Path=/;" + " Max-Age=60; " + "invalid-max-age-cookie=fifteenth; Domain=invalid-values.com; " + " Max-Age=string; " + "invalid-expires-cookie=sixteenth; Domain=invalid-values.com; " + " Expires=string;" + ) + jar.update_cookies(cookies) + cookies = jar.filter_cookies(URL(url)) + + assert len(cookies) == len(expected_cookies) + for c in cookies: + assert c in expected_cookies + + +async def test_domain_filter_ip_cookie_send() -> None: + jar = CookieJar() cookies = SimpleCookie( "shared-cookie=first; " "domain-cookie=second; Domain=example.com; " @@ -486,11 +554,11 @@ def test_domain_filter_diff_host(self) -> None: def test_domain_filter_host_only(self) -> None: self.jar.update_cookies(self.cookies_to_receive, URL("http://example.com/")) + sub_cookie = SimpleCookie("subdomain=spam; Path=/;") + self.jar.update_cookies(sub_cookie, URL("http://foo.example.com/")) - cookies_sent = self.jar.filter_cookies(URL("http://example.com/")) - self.assertIn("unconstrained-cookie", set(cookies_sent.keys())) - - cookies_sent = self.jar.filter_cookies(URL("http://different.org/")) + cookies_sent = self.jar.filter_cookies(URL("http://foo.example.com/")) + self.assertIn("subdomain", set(cookies_sent.keys())) self.assertNotIn("unconstrained-cookie", set(cookies_sent.keys())) def test_secure_filter(self) -> None: @@ -784,7 +852,28 @@ async def test_cookie_jar_clear_expired(): assert len(sut) == 0 -async def test_cookie_jar_clear_domain(): +async def test_cookie_jar_filter_cookies_expires(): + """Test that calling filter_cookies will expire stale cookies.""" + jar = CookieJar() + assert len(jar) == 0 + + cookie = SimpleCookie() + + cookie["foo"] = "bar" + cookie["foo"]["expires"] = "Tue, 1 Jan 1990 12:00:00 GMT" + + with freeze_time("1980-01-01"): + jar.update_cookies(cookie) + + assert 
len(jar) == 1 + + # filter_cookies should expire stale cookies + jar.filter_cookies(URL("http://any.com/")) + + assert len(jar) == 0 + + +async def test_cookie_jar_clear_domain() -> None: sut = CookieJar() cookie = SimpleCookie() cookie["foo"] = "bar" @@ -825,7 +914,7 @@ async def test_pickle_format(cookies_to_send) -> None: with file_path.open("wb") as f: pickle.dump(cookies, f, pickle.HIGHEST_PROTOCOL) """ - pickled = b"\x80\x05\x95\xc5\x07\x00\x00\x00\x00\x00\x00\x8c\x0bcollections\x94\x8c\x0bdefaultdict\x94\x93\x94\x8c\x0chttp.cookies\x94\x8c\x0cSimpleCookie\x94\x93\x94\x85\x94R\x94(\x8c\x00\x94\x8c\x01/\x94\x86\x94h\x05)\x81\x94\x8c\rshared-cookie\x94h\x03\x8c\x06Morsel\x94\x93\x94)\x81\x94(\x8c\x07expires\x94h\x08\x8c\x04path\x94h\t\x8c\x07comment\x94h\x08\x8c\x06domain\x94h\x08\x8c\x07max-age\x94h\x08\x8c\x06secure\x94h\x08\x8c\x08httponly\x94h\x08\x8c\x07version\x94h\x08\x8c\x08samesite\x94h\x08u}\x94(\x8c\x03key\x94h\x0c\x8c\x05value\x94\x8c\x05first\x94\x8c\x0bcoded_value\x94h\x1cubs\x8c\x0bexample.com\x94h\t\x86\x94h\x05)\x81\x94(\x8c\rdomain-cookie\x94h\x0e)\x81\x94(h\x10h\x08h\x11h\th\x12h\x08h\x13h\x1eh\x14h\x08h\x15h\x08h\x16h\x08h\x17h\x08h\x18h\x08u}\x94(h\x1ah!h\x1b\x8c\x06second\x94h\x1dh$ub\x8c\x14dotted-domain-cookie\x94h\x0e)\x81\x94(h\x10h\x08h\x11h\th\x12h\x08h\x13\x8c\x0bexample.com\x94h\x14h\x08h\x15h\x08h\x16h\x08h\x17h\x08h\x18h\x08u}\x94(h\x1ah%h\x1b\x8c\x05fifth\x94h\x1dh)ubu\x8c\x11test1.example.com\x94h\t\x86\x94h\x05)\x81\x94\x8c\x11subdomain1-cookie\x94h\x0e)\x81\x94(h\x10h\x08h\x11h\th\x12h\x08h\x13h*h\x14h\x08h\x15h\x08h\x16h\x08h\x17h\x08h\x18h\x08u}\x94(h\x1ah-h\x1b\x8c\x05third\x94h\x1dh0ubs\x8c\x11test2.example.com\x94h\t\x86\x94h\x05)\x81\x94\x8c\x11subdomain2-cookie\x94h\x0e)\x81\x94(h\x10h\x08h\x11h\th\x12h\x08h\x13h1h\x14h\x08h\x15h\x08h\x16h\x08h\x17h\x08h\x18h\x08u}\x94(h\x1ah4h\x1b\x8c\x06fourth\x94h\x1dh7ubs\x8c\rdifferent.org\x94h\t\x86\x94h\x05)\x81\x94\x8c\x17different-domain-cookie\x94h\x0e)\x81\x94(h\x10h\x08h\x11
h\th\x12h\x08h\x13h8h\x14h\x08h\x15h\x08h\x16h\x08h\x17h\x08h\x18h\x08u}\x94(h\x1ah;h\x1b\x8c\x05sixth\x94h\x1dh>ubs\x8c\nsecure.com\x94h\t\x86\x94h\x05)\x81\x94\x8c\rsecure-cookie\x94h\x0e)\x81\x94(h\x10h\x08h\x11h\th\x12h\x08h\x13h?h\x14h\x08h\x15\x88h\x16h\x08h\x17h\x08h\x18h\x08u}\x94(h\x1ahBh\x1b\x8c\x07seventh\x94h\x1dhEubs\x8c\x0cpathtest.com\x94h\t\x86\x94h\x05)\x81\x94(\x8c\x0eno-path-cookie\x94h\x0e)\x81\x94(h\x10h\x08h\x11h\th\x12h\x08h\x13hFh\x14h\x08h\x15h\x08h\x16h\x08h\x17h\x08h\x18h\x08u}\x94(h\x1ahIh\x1b\x8c\x06eighth\x94h\x1dhLub\x8c\x0cpath1-cookie\x94h\x0e)\x81\x94(h\x10h\x08h\x11h\th\x12h\x08h\x13\x8c\x0cpathtest.com\x94h\x14h\x08h\x15h\x08h\x16h\x08h\x17h\x08h\x18h\x08u}\x94(h\x1ahMh\x1b\x8c\x05ninth\x94h\x1dhQubu\x8c\x0cpathtest.com\x94\x8c\x04/one\x94\x86\x94h\x05)\x81\x94\x8c\x0cpath2-cookie\x94h\x0e)\x81\x94(h\x10h\x08h\x11hSh\x12h\x08h\x13hRh\x14h\x08h\x15h\x08h\x16h\x08h\x17h\x08h\x18h\x08u}\x94(h\x1ahVh\x1b\x8c\x05tenth\x94h\x1dhYubs\x8c\x0cpathtest.com\x94\x8c\x08/one/two\x94\x86\x94h\x05)\x81\x94\x8c\x0cpath3-cookie\x94h\x0e)\x81\x94(h\x10h\x08h\x11h[h\x12h\x08h\x13hZh\x14h\x08h\x15h\x08h\x16h\x08h\x17h\x08h\x18h\x08u}\x94(h\x1ah^h\x1b\x8c\x08eleventh\x94h\x1dhaubs\x8c\x0cpathtest.com\x94\x8c\t/one/two/\x94\x86\x94h\x05)\x81\x94\x8c\x0cpath4-cookie\x94h\x0e)\x81\x94(h\x10h\x08h\x11hch\x12h\x08h\x13hbh\x14h\x08h\x15h\x08h\x16h\x08h\x17h\x08h\x18h\x08u}\x94(h\x1ahfh\x1b\x8c\x07twelfth\x94h\x1dhiubs\x8c\x0fexpirestest.com\x94h\t\x86\x94h\x05)\x81\x94\x8c\x0eexpires-cookie\x94h\x0e)\x81\x94(h\x10\x8c\x1cTue, 1 Jan 2999 12:00:00 
GMT\x94h\x11h\th\x12h\x08h\x13hjh\x14h\x08h\x15h\x08h\x16h\x08h\x17h\x08h\x18h\x08u}\x94(h\x1ahmh\x1b\x8c\nthirteenth\x94h\x1dhqubs\x8c\x0emaxagetest.com\x94h\t\x86\x94h\x05)\x81\x94\x8c\x0emax-age-cookie\x94h\x0e)\x81\x94(h\x10h\x08h\x11h\th\x12h\x08h\x13hrh\x14\x8c\x0260\x94h\x15h\x08h\x16h\x08h\x17h\x08h\x18h\x08u}\x94(h\x1ahuh\x1b\x8c\nfourteenth\x94h\x1dhyubs\x8c\x12invalid-values.com\x94h\t\x86\x94h\x05)\x81\x94(\x8c\x16invalid-max-age-cookie\x94h\x0e)\x81\x94(h\x10h\x08h\x11h\th\x12h\x08h\x13hzh\x14h\x08h\x15h\x08h\x16h\x08h\x17h\x08h\x18h\x08u}\x94(h\x1ah}h\x1b\x8c\tfifteenth\x94h\x1dh\x80ub\x8c\x16invalid-expires-cookie\x94h\x0e)\x81\x94(h\x10h\x08h\x11h\th\x12h\x08h\x13\x8c\x12invalid-values.com\x94h\x14h\x08h\x15h\x08h\x16h\x08h\x17h\x08h\x18h\x08u}\x94(h\x1ah\x81h\x1b\x8c\tsixteenth\x94h\x1dh\x85ubuu." + pickled = b"\x80\x04\x95\xc8\x0b\x00\x00\x00\x00\x00\x00\x8c\x0bcollections\x94\x8c\x0bdefaultdict\x94\x93\x94\x8c\x0chttp.cookies\x94\x8c\x0cSimpleCookie\x94\x93\x94\x85\x94R\x94(\x8c\x00\x94h\x08\x86\x94h\x05)\x81\x94\x8c\rshared-cookie\x94h\x03\x8c\x06Morsel\x94\x93\x94)\x81\x94(\x8c\x07expires\x94h\x08\x8c\x04path\x94\x8c\x01/\x94\x8c\x07comment\x94h\x08\x8c\x06domain\x94h\x08\x8c\x07max-age\x94h\x08\x8c\x06secure\x94h\x08\x8c\x08httponly\x94h\x08\x8c\x07version\x94h\x08\x8c\x08samesite\x94h\x08u}\x94(\x8c\x03key\x94h\x0b\x8c\x05value\x94\x8c\x05first\x94\x8c\x0bcoded_value\x94h\x1cubs\x8c\x0bexample.com\x94h\x08\x86\x94h\x05)\x81\x94(\x8c\rdomain-cookie\x94h\r)\x81\x94(\x8c\x07expires\x94h\x08\x8c\x04path\x94h\x11\x8c\x07comment\x94h\x08\x8c\x06domain\x94h\x1e\x8c\x07max-age\x94h\x08\x8c\x06secure\x94h\x08\x8c\x08httponly\x94h\x08\x8c\x07version\x94h\x08\x8c\x08samesite\x94h\x08u}\x94(h\x1ah!h\x1b\x8c\x06second\x94h\x1dh-ub\x8c\x14dotted-domain-cookie\x94h\r)\x81\x94(\x8c\x07expires\x94h\x08\x8c\x04path\x94h\x11\x8c\x07comment\x94h\x08\x8c\x06domain\x94\x8c\x0bexample.com\x94\x8c\x07max-age\x94h\x08\x8c\x06secure\x94h\x08\x8c\x08httponly\x94h\x08\x8
c\x07version\x94h\x08\x8c\x08samesite\x94h\x08u}\x94(h\x1ah.h\x1b\x8c\x05fifth\x94h\x1dh;ubu\x8c\x11test1.example.com\x94h\x08\x86\x94h\x05)\x81\x94\x8c\x11subdomain1-cookie\x94h\r)\x81\x94(\x8c\x07expires\x94h\x08\x8c\x04path\x94h\x11\x8c\x07comment\x94h\x08\x8c\x06domain\x94h<\x8c\x07max-age\x94h\x08\x8c\x06secure\x94h\x08\x8c\x08httponly\x94h\x08\x8c\x07version\x94h\x08\x8c\x08samesite\x94h\x08u}\x94(h\x1ah?h\x1b\x8c\x05third\x94h\x1dhKubs\x8c\x11test2.example.com\x94h\x08\x86\x94h\x05)\x81\x94\x8c\x11subdomain2-cookie\x94h\r)\x81\x94(\x8c\x07expires\x94h\x08\x8c\x04path\x94h\x11\x8c\x07comment\x94h\x08\x8c\x06domain\x94hL\x8c\x07max-age\x94h\x08\x8c\x06secure\x94h\x08\x8c\x08httponly\x94h\x08\x8c\x07version\x94h\x08\x8c\x08samesite\x94h\x08u}\x94(h\x1ahOh\x1b\x8c\x06fourth\x94h\x1dh[ubs\x8c\rdifferent.org\x94h\x08\x86\x94h\x05)\x81\x94\x8c\x17different-domain-cookie\x94h\r)\x81\x94(\x8c\x07expires\x94h\x08\x8c\x04path\x94h\x11\x8c\x07comment\x94h\x08\x8c\x06domain\x94h\\\x8c\x07max-age\x94h\x08\x8c\x06secure\x94h\x08\x8c\x08httponly\x94h\x08\x8c\x07version\x94h\x08\x8c\x08samesite\x94h\x08u}\x94(h\x1ah_h\x1b\x8c\x05sixth\x94h\x1dhkubs\x8c\nsecure.com\x94h\x08\x86\x94h\x05)\x81\x94\x8c\rsecure-cookie\x94h\r)\x81\x94(\x8c\x07expires\x94h\x08\x8c\x04path\x94h\x11\x8c\x07comment\x94h\x08\x8c\x06domain\x94hl\x8c\x07max-age\x94h\x08\x8c\x06secure\x94\x88\x8c\x08httponly\x94h\x08\x8c\x07version\x94h\x08\x8c\x08samesite\x94h\x08u}\x94(h\x1ahoh\x1b\x8c\x07seventh\x94h\x1dh{ubs\x8c\x0cpathtest.com\x94h\x08\x86\x94h\x05)\x81\x94(\x8c\x0eno-path-cookie\x94h\r)\x81\x94(\x8c\x07expires\x94h\x08\x8c\x04path\x94h\x11\x8c\x07comment\x94h\x08\x8c\x06domain\x94h|\x8c\x07max-age\x94h\x08\x8c\x06secure\x94h\x08\x8c\x08httponly\x94h\x08\x8c\x07version\x94h\x08\x8c\x08samesite\x94h\x08u}\x94(h\x1ah\x7fh\x1b\x8c\x06eighth\x94h\x1dh\x8bub\x8c\x0cpath1-cookie\x94h\r)\x81\x94(\x8c\x07expires\x94h\x08\x8c\x04path\x94h\x11\x8c\x07comment\x94h\x08\x8c\x06domain\x94\x8c\x0cpathtest.com\x94\x8
c\x07max-age\x94h\x08\x8c\x06secure\x94h\x08\x8c\x08httponly\x94h\x08\x8c\x07version\x94h\x08\x8c\x08samesite\x94h\x08u}\x94(h\x1ah\x8ch\x1b\x8c\x05ninth\x94h\x1dh\x99ubu\x8c\x0cpathtest.com\x94\x8c\x04/one\x94\x86\x94h\x05)\x81\x94\x8c\x0cpath2-cookie\x94h\r)\x81\x94(\x8c\x07expires\x94h\x08\x8c\x04path\x94h\x9b\x8c\x07comment\x94h\x08\x8c\x06domain\x94h\x9a\x8c\x07max-age\x94h\x08\x8c\x06secure\x94h\x08\x8c\x08httponly\x94h\x08\x8c\x07version\x94h\x08\x8c\x08samesite\x94h\x08u}\x94(h\x1ah\x9eh\x1b\x8c\x05tenth\x94h\x1dh\xaaubs\x8c\x0cpathtest.com\x94\x8c\x08/one/two\x94\x86\x94h\x05)\x81\x94(\x8c\x0cpath3-cookie\x94h\r)\x81\x94(\x8c\x07expires\x94h\x08\x8c\x04path\x94h\xac\x8c\x07comment\x94h\x08\x8c\x06domain\x94h\xab\x8c\x07max-age\x94h\x08\x8c\x06secure\x94h\x08\x8c\x08httponly\x94h\x08\x8c\x07version\x94h\x08\x8c\x08samesite\x94h\x08u}\x94(h\x1ah\xafh\x1b\x8c\x08eleventh\x94h\x1dh\xbbub\x8c\x0cpath4-cookie\x94h\r)\x81\x94(\x8c\x07expires\x94h\x08\x8c\x04path\x94\x8c\t/one/two/\x94\x8c\x07comment\x94h\x08\x8c\x06domain\x94\x8c\x0cpathtest.com\x94\x8c\x07max-age\x94h\x08\x8c\x06secure\x94h\x08\x8c\x08httponly\x94h\x08\x8c\x07version\x94h\x08\x8c\x08samesite\x94h\x08u}\x94(h\x1ah\xbch\x1b\x8c\x07twelfth\x94h\x1dh\xcaubu\x8c\x0fexpirestest.com\x94h\x08\x86\x94h\x05)\x81\x94\x8c\x0eexpires-cookie\x94h\r)\x81\x94(\x8c\x07expires\x94\x8c\x1cTue, 1 Jan 2999 12:00:00 
GMT\x94\x8c\x04path\x94h\x11\x8c\x07comment\x94h\x08\x8c\x06domain\x94h\xcb\x8c\x07max-age\x94h\x08\x8c\x06secure\x94h\x08\x8c\x08httponly\x94h\x08\x8c\x07version\x94h\x08\x8c\x08samesite\x94h\x08u}\x94(h\x1ah\xceh\x1b\x8c\nthirteenth\x94h\x1dh\xdbubs\x8c\x0emaxagetest.com\x94h\x08\x86\x94h\x05)\x81\x94\x8c\x0emax-age-cookie\x94h\r)\x81\x94(\x8c\x07expires\x94h\x08\x8c\x04path\x94h\x11\x8c\x07comment\x94h\x08\x8c\x06domain\x94h\xdc\x8c\x07max-age\x94\x8c\x0260\x94\x8c\x06secure\x94h\x08\x8c\x08httponly\x94h\x08\x8c\x07version\x94h\x08\x8c\x08samesite\x94h\x08u}\x94(h\x1ah\xdfh\x1b\x8c\nfourteenth\x94h\x1dh\xecubs\x8c\x12invalid-values.com\x94h\x08\x86\x94h\x05)\x81\x94(\x8c\x16invalid-max-age-cookie\x94h\r)\x81\x94(\x8c\x07expires\x94h\x08\x8c\x04path\x94h\x11\x8c\x07comment\x94h\x08\x8c\x06domain\x94h\xed\x8c\x07max-age\x94h\x08\x8c\x06secure\x94h\x08\x8c\x08httponly\x94h\x08\x8c\x07version\x94h\x08\x8c\x08samesite\x94h\x08u}\x94(h\x1ah\xf0h\x1b\x8c\tfifteenth\x94h\x1dh\xfcub\x8c\x16invalid-expires-cookie\x94h\r)\x81\x94(\x8c\x07expires\x94h\x08\x8c\x04path\x94h\x11\x8c\x07comment\x94h\x08\x8c\x06domain\x94\x8c\x12invalid-values.com\x94\x8c\x07max-age\x94h\x08\x8c\x06secure\x94h\x08\x8c\x08httponly\x94h\x08\x8c\x07version\x94h\x08\x8c\x08samesite\x94h\x08u}\x94(h\x1ah\xfdh\x1b\x8c\tsixteenth\x94h\x1dj\n\x01\x00\x00ubuu." 
cookies = pickle.loads(pickled) cj = CookieJar() diff --git a/tests/test_helpers.py b/tests/test_helpers.py index b59528d3468..67af32dc3be 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -607,18 +607,6 @@ def test_proxies_from_env_http_with_auth(url_input, expected_scheme) -> None: assert proxy_auth.encoding == "latin1" -# ------------ get_running_loop --------------------------------- - - -def test_get_running_loop_not_running(loop) -> None: - with pytest.warns(DeprecationWarning): - helpers.get_running_loop() - - -async def test_get_running_loop_ok(loop) -> None: - assert helpers.get_running_loop() is loop - - # --------------------- get_env_proxy_for_url ------------------------------ diff --git a/tests/test_http_parser.py b/tests/test_http_parser.py index ee7dc4aabc5..0e9aff68dc2 100644 --- a/tests/test_http_parser.py +++ b/tests/test_http_parser.py @@ -294,9 +294,20 @@ def test_parse_headers_longline(parser: Any) -> None: parser.feed_data(text) +@pytest.fixture +def xfail_c_parser_status(request) -> None: + if isinstance(request.getfixturevalue("parser"), HttpRequestParserPy): + return + request.node.add_marker( + pytest.mark.xfail( + reason="Regression test for Py parser. May match C behaviour later.", + raises=http_exceptions.BadStatusLine, + ) + ) + + +@pytest.mark.usefixtures("xfail_c_parser_status") def test_parse_unusual_request_line(parser) -> None: - if not isinstance(response, HttpResponseParserPy): - pytest.xfail("Regression test for Py parser. 
May match C behaviour later.") text = b"#smol //a HTTP/1.3\r\n\r\n" messages, upgrade, tail = parser.feed_data(text) assert len(messages) == 1 @@ -632,9 +643,6 @@ def test_invalid_header_spacing(parser, pad1: bytes, pad2: bytes, hdr: bytes) -> if pad1 == pad2 == b"" and hdr != b"": # one entry in param matrix is correct: non-empty name, not padded expectation = nullcontext() - if pad1 == pad2 == hdr == b"": - if not isinstance(response, HttpResponseParserPy): - pytest.xfail("Regression test for Py parser. May match C behaviour later.") with expectation: parser.feed_data(text) @@ -815,9 +823,40 @@ def test_http_request_upgrade(parser: Any) -> None: assert tail == b"some raw data" +async def test_http_request_upgrade_unknown(parser: Any) -> None: + text = ( + b"POST / HTTP/1.1\r\n" + b"Connection: Upgrade\r\n" + b"Content-Length: 2\r\n" + b"Upgrade: unknown\r\n" + b"Content-Type: application/json\r\n\r\n" + b"{}" + ) + messages, upgrade, tail = parser.feed_data(text) + + msg = messages[0][0] + assert not msg.should_close + assert msg.upgrade + assert not upgrade + assert not msg.chunked + assert tail == b"" + assert await messages[0][-1].read() == b"{}" + + +@pytest.fixture +def xfail_c_parser_url(request) -> None: + if isinstance(request.getfixturevalue("parser"), HttpRequestParserPy): + return + request.node.add_marker( + pytest.mark.xfail( + reason="Regression test for Py parser. May match C behaviour later.", + raises=http_exceptions.InvalidURLError, + ) + ) + + +@pytest.mark.usefixtures("xfail_c_parser_url") def test_http_request_parser_utf8_request_line(parser) -> None: - if not isinstance(response, HttpResponseParserPy): - pytest.xfail("Regression test for Py parser. 
May match C behaviour later.") messages, upgrade, tail = parser.feed_data( # note the truncated unicode sequence b"GET /P\xc3\xbcnktchen\xa0\xef\xb7 HTTP/1.1\r\n" + @@ -837,7 +876,9 @@ def test_http_request_parser_utf8_request_line(parser) -> None: assert msg.compression is None assert not msg.upgrade assert not msg.chunked - assert msg.url.path == URL("/P%C3%BCnktchen\udca0\udcef\udcb7").path + # python HTTP parser depends on Cython and CPython URL to match + # .. but yarl.URL("/abs") is not equal to URL.build(path="/abs"), see #6409 + assert msg.url == URL.build(path="/Pünktchen\udca0\udcef\udcb7", encoded=True) def test_http_request_parser_utf8(parser) -> None: @@ -1209,8 +1250,8 @@ def test_parse_chunked_payload_chunk_extension(parser) -> None: assert payload.is_eof() -def _test_parse_no_length_or_te_on_post(loop, protocol, request_cls): - parser = request_cls(protocol, loop, readall=True) +def test_parse_no_length_or_te_on_post(loop: Any, protocol: Any, request_cls: Any): + parser = request_cls(protocol, loop, limit=2**16) text = b"POST /test HTTP/1.1\r\n\r\n" msg, payload = parser.feed_data(text)[0][0] @@ -1454,29 +1495,16 @@ def test_parse_bad_method_for_c_parser_raises(loop, protocol): class TestParsePayload: async def test_parse_eof_payload(self, stream) -> None: - out = aiohttp.FlowControlDataQueue( - stream, 2**16, loop=asyncio.get_event_loop() - ) - p = HttpPayloadParser(out, readall=True) + out = aiohttp.FlowControlDataQueue(stream, 2**16, loop=asyncio.get_event_loop()) + p = HttpPayloadParser(out) p.feed_data(b"data") p.feed_eof() assert out.is_eof() assert [(bytearray(b"data"), 4)] == list(out._buffer) - async def test_parse_no_body(self, stream) -> None: - out = aiohttp.FlowControlDataQueue( - stream, 2**16, loop=asyncio.get_event_loop() - ) - p = HttpPayloadParser(out, method="PUT") - - assert out.is_eof() - assert p.done - async def test_parse_length_payload_eof(self, stream) -> None: - out = aiohttp.FlowControlDataQueue( - stream, 2**16, 
loop=asyncio.get_event_loop() - ) + out = aiohttp.FlowControlDataQueue(stream, 2**16, loop=asyncio.get_event_loop()) p = HttpPayloadParser(out, length=4) p.feed_data(b"da") @@ -1485,9 +1513,7 @@ async def test_parse_length_payload_eof(self, stream) -> None: p.feed_eof() async def test_parse_chunked_payload_size_error(self, stream) -> None: - out = aiohttp.FlowControlDataQueue( - stream, 2**16, loop=asyncio.get_event_loop() - ) + out = aiohttp.FlowControlDataQueue(stream, 2**16, loop=asyncio.get_event_loop()) p = HttpPayloadParser(out, chunked=True) with pytest.raises(http_exceptions.TransferEncodingError): p.feed_data(b"blah\r\n") @@ -1550,9 +1576,7 @@ async def test_parse_chunked_payload_split_end_trailers4(self, protocol) -> None assert b"asdf" == b"".join(out._buffer) async def test_http_payload_parser_length(self, stream) -> None: - out = aiohttp.FlowControlDataQueue( - stream, 2**16, loop=asyncio.get_event_loop() - ) + out = aiohttp.FlowControlDataQueue(stream, 2**16, loop=asyncio.get_event_loop()) p = HttpPayloadParser(out, length=2) eof, tail = p.feed_data(b"1245") assert eof @@ -1565,9 +1589,7 @@ async def test_http_payload_parser_deflate(self, stream) -> None: COMPRESSED = b"x\x9cKI,I\x04\x00\x04\x00\x01\x9b" length = len(COMPRESSED) - out = aiohttp.FlowControlDataQueue( - stream, 2**16, loop=asyncio.get_event_loop() - ) + out = aiohttp.FlowControlDataQueue(stream, 2**16, loop=asyncio.get_event_loop()) p = HttpPayloadParser(out, length=length, compression="deflate") p.feed_data(COMPRESSED) assert b"data" == b"".join(d for d, _ in out._buffer) @@ -1579,9 +1601,7 @@ async def test_http_payload_parser_deflate_no_hdrs(self, stream: Any) -> None: COMPRESSED = b"KI,I\x04\x00" length = len(COMPRESSED) - out = aiohttp.FlowControlDataQueue( - stream, 2**16, loop=asyncio.get_event_loop() - ) + out = aiohttp.FlowControlDataQueue(stream, 2**16, loop=asyncio.get_event_loop()) p = HttpPayloadParser(out, length=length, compression="deflate") p.feed_data(COMPRESSED) 
assert b"data" == b"".join(d for d, _ in out._buffer) @@ -1592,19 +1612,15 @@ async def test_http_payload_parser_deflate_light(self, stream) -> None: COMPRESSED = b"\x18\x95KI,I\x04\x00\x04\x00\x01\x9b" length = len(COMPRESSED) - out = aiohttp.FlowControlDataQueue( - stream, 2**16, loop=asyncio.get_event_loop() - ) + out = aiohttp.FlowControlDataQueue(stream, 2**16, loop=asyncio.get_event_loop()) p = HttpPayloadParser(out, length=length, compression="deflate") p.feed_data(COMPRESSED) assert b"data" == b"".join(d for d, _ in out._buffer) assert out.is_eof() async def test_http_payload_parser_deflate_split(self, stream) -> None: - out = aiohttp.FlowControlDataQueue( - stream, 2**16, loop=asyncio.get_event_loop() - ) - p = HttpPayloadParser(out, compression="deflate", readall=True) + out = aiohttp.FlowControlDataQueue(stream, 2**16, loop=asyncio.get_event_loop()) + p = HttpPayloadParser(out, compression="deflate") # Feeding one correct byte should be enough to choose exact # deflate decompressor p.feed_data(b"x", 1) @@ -1613,10 +1629,8 @@ async def test_http_payload_parser_deflate_split(self, stream) -> None: assert b"data" == b"".join(d for d, _ in out._buffer) async def test_http_payload_parser_deflate_split_err(self, stream) -> None: - out = aiohttp.FlowControlDataQueue( - stream, 2**16, loop=asyncio.get_event_loop() - ) - p = HttpPayloadParser(out, compression="deflate", readall=True) + out = aiohttp.FlowControlDataQueue(stream, 2**16, loop=asyncio.get_event_loop()) + p = HttpPayloadParser(out, compression="deflate") # Feeding one wrong byte should be enough to choose exact # deflate decompressor p.feed_data(b"K", 1) @@ -1625,9 +1639,7 @@ async def test_http_payload_parser_deflate_split_err(self, stream) -> None: assert b"data" == b"".join(d for d, _ in out._buffer) async def test_http_payload_parser_length_zero(self, stream) -> None: - out = aiohttp.FlowControlDataQueue( - stream, 2**16, loop=asyncio.get_event_loop() - ) + out = 
aiohttp.FlowControlDataQueue(stream, 2**16, loop=asyncio.get_event_loop()) p = HttpPayloadParser(out, length=0) assert p.done assert out.is_eof() @@ -1635,9 +1647,7 @@ async def test_http_payload_parser_length_zero(self, stream) -> None: @pytest.mark.skipif(brotli is None, reason="brotli is not installed") async def test_http_payload_brotli(self, stream) -> None: compressed = brotli.compress(b"brotli data") - out = aiohttp.FlowControlDataQueue( - stream, 2**16, loop=asyncio.get_event_loop() - ) + out = aiohttp.FlowControlDataQueue(stream, 2**16, loop=asyncio.get_event_loop()) p = HttpPayloadParser(out, length=len(compressed), compression="br") p.feed_data(compressed) assert b"brotli data" == b"".join(d for d, _ in out._buffer) @@ -1646,9 +1656,7 @@ async def test_http_payload_brotli(self, stream) -> None: class TestDeflateBuffer: async def test_feed_data(self, stream) -> None: - buf = aiohttp.FlowControlDataQueue( - stream, 2**16, loop=asyncio.get_event_loop() - ) + buf = aiohttp.FlowControlDataQueue(stream, 2**16, loop=asyncio.get_event_loop()) dbuf = DeflateBuffer(buf, "deflate") dbuf.decompressor = mock.Mock() @@ -1659,9 +1667,7 @@ async def test_feed_data(self, stream) -> None: assert [b"line"] == list(d for d, _ in buf._buffer) async def test_feed_data_err(self, stream) -> None: - buf = aiohttp.FlowControlDataQueue( - stream, 2**16, loop=asyncio.get_event_loop() - ) + buf = aiohttp.FlowControlDataQueue(stream, 2**16, loop=asyncio.get_event_loop()) dbuf = DeflateBuffer(buf, "deflate") exc = ValueError() @@ -1674,9 +1680,7 @@ async def test_feed_data_err(self, stream) -> None: dbuf.feed_data(b"xsomedata", 9) async def test_feed_eof(self, stream) -> None: - buf = aiohttp.FlowControlDataQueue( - stream, 2**16, loop=asyncio.get_event_loop() - ) + buf = aiohttp.FlowControlDataQueue(stream, 2**16, loop=asyncio.get_event_loop()) dbuf = DeflateBuffer(buf, "deflate") dbuf.decompressor = mock.Mock() @@ -1687,9 +1691,7 @@ async def test_feed_eof(self, stream) -> None: 
assert buf._eof async def test_feed_eof_err_deflate(self, stream) -> None: - buf = aiohttp.FlowControlDataQueue( - stream, 2**16, loop=asyncio.get_event_loop() - ) + buf = aiohttp.FlowControlDataQueue(stream, 2**16, loop=asyncio.get_event_loop()) dbuf = DeflateBuffer(buf, "deflate") dbuf.decompressor = mock.Mock() @@ -1700,9 +1702,7 @@ async def test_feed_eof_err_deflate(self, stream) -> None: dbuf.feed_eof() async def test_feed_eof_no_err_gzip(self, stream) -> None: - buf = aiohttp.FlowControlDataQueue( - stream, 2**16, loop=asyncio.get_event_loop() - ) + buf = aiohttp.FlowControlDataQueue(stream, 2**16, loop=asyncio.get_event_loop()) dbuf = DeflateBuffer(buf, "gzip") dbuf.decompressor = mock.Mock() @@ -1713,9 +1713,7 @@ async def test_feed_eof_no_err_gzip(self, stream) -> None: assert [b"line"] == list(d for d, _ in buf._buffer) async def test_feed_eof_no_err_brotli(self, stream) -> None: - buf = aiohttp.FlowControlDataQueue( - stream, 2**16, loop=asyncio.get_event_loop() - ) + buf = aiohttp.FlowControlDataQueue(stream, 2**16, loop=asyncio.get_event_loop()) dbuf = DeflateBuffer(buf, "br") dbuf.decompressor = mock.Mock() @@ -1726,9 +1724,7 @@ async def test_feed_eof_no_err_brotli(self, stream) -> None: assert [b"line"] == list(d for d, _ in buf._buffer) async def test_empty_body(self, stream) -> None: - buf = aiohttp.FlowControlDataQueue( - stream, 2**16, loop=asyncio.get_event_loop() - ) + buf = aiohttp.FlowControlDataQueue(stream, 2**16, loop=asyncio.get_event_loop()) dbuf = DeflateBuffer(buf, "deflate") dbuf.feed_eof() diff --git a/tests/test_proxy.py b/tests/test_proxy.py index 6366a13d573..f335e42c254 100644 --- a/tests/test_proxy.py +++ b/tests/test_proxy.py @@ -4,6 +4,7 @@ import ssl import sys import unittest +from typing import Any from unittest import mock import pytest @@ -40,7 +41,12 @@ def tearDown(self): gc.collect() @mock.patch("aiohttp.connector.ClientRequest") - def test_connect(self, ClientRequestMock) -> None: + @mock.patch( + 
"aiohttp.connector.aiohappyeyeballs.start_connection", + autospec=True, + spec_set=True, + ) + def test_connect(self, start_connection: Any, ClientRequestMock: Any) -> None: req = ClientRequest( "GET", URL("http://www.python.org"), @@ -54,7 +60,18 @@ async def make_conn(): return aiohttp.TCPConnector() connector = self.loop.run_until_complete(make_conn()) - connector._resolve_host = make_mocked_coro([mock.MagicMock()]) + connector._resolve_host = make_mocked_coro( + [ + { + "hostname": "hostname", + "host": "127.0.0.1", + "port": 80, + "family": socket.AF_INET, + "proto": 0, + "flags": 0, + } + ] + ) proto = mock.Mock( **{ @@ -81,7 +98,12 @@ async def make_conn(): conn.close() @mock.patch("aiohttp.connector.ClientRequest") - def test_proxy_headers(self, ClientRequestMock) -> None: + @mock.patch( + "aiohttp.connector.aiohappyeyeballs.start_connection", + autospec=True, + spec_set=True, + ) + def test_proxy_headers(self, start_connection: Any, ClientRequestMock: Any) -> None: req = ClientRequest( "GET", URL("http://www.python.org"), @@ -96,7 +118,18 @@ async def make_conn(): return aiohttp.TCPConnector() connector = self.loop.run_until_complete(make_conn()) - connector._resolve_host = make_mocked_coro([mock.MagicMock()]) + connector._resolve_host = make_mocked_coro( + [ + { + "hostname": "hostname", + "host": "127.0.0.1", + "port": 80, + "family": socket.AF_INET, + "proto": 0, + "flags": 0, + } + ] + ) proto = mock.Mock( **{ @@ -122,7 +155,12 @@ async def make_conn(): conn.close() - def test_proxy_auth(self) -> None: + @mock.patch( + "aiohttp.connector.aiohappyeyeballs.start_connection", + autospec=True, + spec_set=True, + ) + def test_proxy_auth(self, start_connection: Any) -> None: with self.assertRaises(ValueError) as ctx: ClientRequest( "GET", @@ -136,11 +174,16 @@ def test_proxy_auth(self) -> None: "proxy_auth must be None or BasicAuth() tuple", ) - def test_proxy_dns_error(self) -> None: + @mock.patch( + "aiohttp.connector.aiohappyeyeballs.start_connection", + 
autospec=True, + spec_set=True, + ) + def test_proxy_dns_error(self, start_connection: Any) -> None: async def make_conn(): return aiohttp.TCPConnector() - connector = self.loop.run_until_complete(make_conn()) + connector: aiohttp.TCPConnector = self.loop.run_until_complete(make_conn()) connector._resolve_host = make_mocked_coro( raise_exception=OSError("dont take it serious") ) @@ -159,7 +202,12 @@ async def make_conn(): self.assertEqual(req.url.path, "/") self.assertEqual(dict(req.headers), expected_headers) - def test_proxy_connection_error(self) -> None: + @mock.patch( + "aiohttp.connector.aiohappyeyeballs.start_connection", + autospec=True, + spec_set=True, + ) + def test_proxy_connection_error(self, start_connection: Any) -> None: async def make_conn(): return aiohttp.TCPConnector() @@ -192,7 +240,14 @@ async def make_conn(): ) @mock.patch("aiohttp.connector.ClientRequest") - def test_proxy_server_hostname_default(self, ClientRequestMock) -> None: + @mock.patch( + "aiohttp.connector.aiohappyeyeballs.start_connection", + autospec=True, + spec_set=True, + ) + def test_proxy_server_hostname_default( + self, start_connection: Any, ClientRequestMock: Any + ) -> None: proxy_req = ClientRequest( "GET", URL("http://proxy.example.com"), loop=self.loop ) @@ -252,7 +307,14 @@ async def make_conn(): self.loop.run_until_complete(req.close()) @mock.patch("aiohttp.connector.ClientRequest") - def test_proxy_server_hostname_override(self, ClientRequestMock) -> None: + @mock.patch( + "aiohttp.connector.aiohappyeyeballs.start_connection", + autospec=True, + spec_set=True, + ) + def test_proxy_server_hostname_override( + self, start_connection: Any, ClientRequestMock: Any + ) -> None: proxy_req = ClientRequest( "GET", URL("http://proxy.example.com"), @@ -316,7 +378,12 @@ async def make_conn(): self.loop.run_until_complete(req.close()) @mock.patch("aiohttp.connector.ClientRequest") - def test_https_connect(self, ClientRequestMock) -> None: + @mock.patch( + 
"aiohttp.connector.aiohappyeyeballs.start_connection", + autospec=True, + spec_set=True, + ) + def test_https_connect(self, start_connection: Any, ClientRequestMock: Any) -> None: proxy_req = ClientRequest( "GET", URL("http://proxy.example.com"), loop=self.loop ) @@ -376,7 +443,14 @@ async def make_conn(): self.loop.run_until_complete(req.close()) @mock.patch("aiohttp.connector.ClientRequest") - def test_https_connect_certificate_error(self, ClientRequestMock) -> None: + @mock.patch( + "aiohttp.connector.aiohappyeyeballs.start_connection", + autospec=True, + spec_set=True, + ) + def test_https_connect_certificate_error( + self, start_connection: Any, ClientRequestMock: Any + ) -> None: proxy_req = ClientRequest( "GET", URL("http://proxy.example.com"), loop=self.loop ) @@ -430,7 +504,14 @@ async def make_conn(): ) @mock.patch("aiohttp.connector.ClientRequest") - def test_https_connect_ssl_error(self, ClientRequestMock) -> None: + @mock.patch( + "aiohttp.connector.aiohappyeyeballs.start_connection", + autospec=True, + spec_set=True, + ) + def test_https_connect_ssl_error( + self, start_connection: Any, ClientRequestMock: Any + ) -> None: proxy_req = ClientRequest( "GET", URL("http://proxy.example.com"), loop=self.loop ) @@ -486,7 +567,14 @@ async def make_conn(): ) @mock.patch("aiohttp.connector.ClientRequest") - def test_https_connect_http_proxy_error(self, ClientRequestMock) -> None: + @mock.patch( + "aiohttp.connector.aiohappyeyeballs.start_connection", + autospec=True, + spec_set=True, + ) + def test_https_connect_http_proxy_error( + self, start_connection: Any, ClientRequestMock: Any + ) -> None: proxy_req = ClientRequest( "GET", URL("http://proxy.example.com"), loop=self.loop ) @@ -545,7 +633,14 @@ async def make_conn(): self.loop.run_until_complete(req.close()) @mock.patch("aiohttp.connector.ClientRequest") - def test_https_connect_resp_start_error(self, ClientRequestMock) -> None: + @mock.patch( + "aiohttp.connector.aiohappyeyeballs.start_connection", + 
autospec=True, + spec_set=True, + ) + def test_https_connect_resp_start_error( + self, start_connection: Any, ClientRequestMock: Any + ) -> None: proxy_req = ClientRequest( "GET", URL("http://proxy.example.com"), loop=self.loop ) @@ -598,7 +693,12 @@ async def make_conn(): ) @mock.patch("aiohttp.connector.ClientRequest") - def test_request_port(self, ClientRequestMock) -> None: + @mock.patch( + "aiohttp.connector.aiohappyeyeballs.start_connection", + autospec=True, + spec_set=True, + ) + def test_request_port(self, start_connection: Any, ClientRequestMock: Any) -> None: proxy_req = ClientRequest( "GET", URL("http://proxy.example.com"), loop=self.loop ) @@ -656,7 +756,14 @@ def test_proxy_auth_property_default(self) -> None: self.assertIsNone(req.proxy_auth) @mock.patch("aiohttp.connector.ClientRequest") - def test_https_connect_pass_ssl_context(self, ClientRequestMock) -> None: + @mock.patch( + "aiohttp.connector.aiohappyeyeballs.start_connection", + autospec=True, + spec_set=True, + ) + def test_https_connect_pass_ssl_context( + self, start_connection: Any, ClientRequestMock: Any + ) -> None: proxy_req = ClientRequest( "GET", URL("http://proxy.example.com"), loop=self.loop ) @@ -724,7 +831,12 @@ async def make_conn(): self.loop.run_until_complete(req.close()) @mock.patch("aiohttp.connector.ClientRequest") - def test_https_auth(self, ClientRequestMock) -> None: + @mock.patch( + "aiohttp.connector.aiohappyeyeballs.start_connection", + autospec=True, + spec_set=True, + ) + def test_https_auth(self, start_connection: Any, ClientRequestMock: Any) -> None: proxy_req = ClientRequest( "GET", URL("http://proxy.example.com"), diff --git a/tests/test_proxy_functional.py b/tests/test_proxy_functional.py index 099922ac77f..c15ca326288 100644 --- a/tests/test_proxy_functional.py +++ b/tests/test_proxy_functional.py @@ -16,18 +16,6 @@ from aiohttp.client_exceptions import ClientConnectionError from aiohttp.helpers import IS_MACOS, IS_WINDOWS -pytestmark = [ - 
pytest.mark.filterwarnings( - "ignore:unclosed = (3, 11) @@ -121,16 +109,14 @@ async def test_secure_https_proxy_absolute_path( conn = aiohttp.TCPConnector() sess = aiohttp.ClientSession(connector=conn) - response = await sess.get( + async with sess.get( web_server_endpoint_url, proxy=secure_proxy_url, ssl=client_ssl_ctx, # used for both proxy and endpoint connections - ) - - assert response.status == 200 - assert await response.text() == web_server_endpoint_payload + ) as response: + assert response.status == 200 + assert await response.text() == web_server_endpoint_payload - response.close() await sess.close() await conn.close() @@ -192,13 +178,17 @@ async def test_https_proxy_unsupported_tls_in_tls( r"$" ) - with pytest.warns(RuntimeWarning, match=expected_warning_text,), pytest.raises( + with pytest.warns( + RuntimeWarning, + match=expected_warning_text, + ), pytest.raises( ClientConnectionError, match=expected_exception_reason, ) as conn_err: - await sess.get(url, proxy=secure_proxy_url, ssl=client_ssl_ctx) + async with sess.get(url, proxy=secure_proxy_url, ssl=client_ssl_ctx): + pass - assert type(conn_err.value.__cause__) == TypeError + assert isinstance(conn_err.value.__cause__, TypeError) assert match_regex(f"^{type_err!s}$", str(conn_err.value.__cause__)) await sess.close() @@ -256,13 +246,11 @@ async def proxy_server(): def get_request(loop): async def _request(method="GET", *, url, trust_env=False, **kwargs): connector = aiohttp.TCPConnector(ssl=False, loop=loop) - client = aiohttp.ClientSession(connector=connector, trust_env=trust_env) - try: - resp = await client.request(method, url, **kwargs) - await resp.release() - return resp - finally: - await client.close() + async with aiohttp.ClientSession( + connector=connector, trust_env=trust_env + ) as client: + async with client.request(method, url, **kwargs) as resp: + return resp return _request @@ -402,11 +390,8 @@ async def test_proxy_http_acquired_cleanup_force(proxy_test_server, loop) -> Non assert 
0 == len(conn._acquired) async def request(): - resp = await sess.get(url, proxy=proxy.url) - - assert 1 == len(conn._acquired) - - await resp.release() + async with sess.get(url, proxy=proxy.url): + assert 1 == len(conn._acquired) await request() @@ -430,13 +415,11 @@ async def request(pid): # process requests only one by one nonlocal current_pid - resp = await sess.get(url, proxy=proxy.url) - - current_pid = pid - await asyncio.sleep(0.2, loop=loop) - assert current_pid == pid + async with sess.get(url, proxy=proxy.url) as resp: + current_pid = pid + await asyncio.sleep(0.2, loop=loop) + assert current_pid == pid - await resp.release() return resp requests = [request(pid) for pid in range(multi_conn_num)] @@ -487,9 +470,8 @@ async def xtest_proxy_https_send_body(proxy_test_server, loop): proxy.return_value = {"status": 200, "body": b"1" * (2**20)} url = "https://www.google.com.ua/search?q=aiohttp proxy" - resp = await sess.get(url, proxy=proxy.url) - body = await resp.read() - await resp.release() + async with sess.get(url, proxy=proxy.url) as resp: + body = await resp.read() await sess.close() assert body == b"1" * (2**20) @@ -583,11 +565,8 @@ async def xtest_proxy_https_acquired_cleanup(proxy_test_server, loop): assert 0 == len(conn._acquired) async def request(): - resp = await sess.get(url, proxy=proxy.url) - - assert 1 == len(conn._acquired) - - await resp.release() + async with sess.get(url, proxy=proxy.url): + assert 1 == len(conn._acquired) await request() @@ -607,11 +586,8 @@ async def xtest_proxy_https_acquired_cleanup_force(proxy_test_server, loop): assert 0 == len(conn._acquired) async def request(): - resp = await sess.get(url, proxy=proxy.url) - - assert 1 == len(conn._acquired) - - await resp.release() + async with sess.get(url, proxy=proxy.url): + assert 1 == len(conn._acquired) await request() @@ -635,13 +611,11 @@ async def request(pid): # process requests only one by one nonlocal current_pid - resp = await sess.get(url, proxy=proxy.url) - - 
current_pid = pid - await asyncio.sleep(0.2, loop=loop) - assert current_pid == pid + async with sess.get(url, proxy=proxy.url) as resp: + current_pid = pid + await asyncio.sleep(0.2, loop=loop) + assert current_pid == pid - await resp.release() return resp requests = [request(pid) for pid in range(multi_conn_num)] @@ -847,8 +821,9 @@ async def test_proxy_auth() -> None: with pytest.raises( ValueError, match=r"proxy_auth must be None or BasicAuth\(\) tuple" ): - await session.get( + async with session.get( "http://python.org", proxy="http://proxy.example.com", proxy_auth=("user", "pass"), - ) + ): + pass diff --git a/tests/test_pytest_plugin.py b/tests/test_pytest_plugin.py index b25a553b868..ad222545294 100644 --- a/tests/test_pytest_plugin.py +++ b/tests/test_pytest_plugin.py @@ -19,6 +19,8 @@ def test_aiohttp_plugin(testdir) -> None: from aiohttp import web +value = web.AppKey('value', str) + async def hello(request): return web.Response(body=b'Hello, world') @@ -75,10 +77,10 @@ async def test_noop() -> None: async def previous(request): if request.method == 'POST': with pytest.deprecated_call(): # FIXME: this isn't actually called - request.app['value'] = (await request.post())['value'] + request.app[value] = (await request.post())['value'] return web.Response(body=b'thanks for the data') else: - v = request.app.get('value', 'unknown') + v = request.app.get(value, 'unknown') return web.Response(body='value: {}'.format(v).encode()) @@ -98,7 +100,7 @@ async def test_set_value(cli) -> None: assert resp.status == 200 text = await resp.text() assert text == 'thanks for the data' - assert cli.server.app['value'] == 'foo' + assert cli.server.app[value] == 'foo' async def test_get_value(cli) -> None: @@ -107,7 +109,7 @@ async def test_get_value(cli) -> None: text = await resp.text() assert text == 'value: unknown' with pytest.warns(DeprecationWarning): - cli.server.app['value'] = 'bar' + cli.server.app[value] = 'bar' resp = await cli.get('/') assert resp.status == 200 
text = await resp.text() @@ -119,7 +121,6 @@ def test_noncoro() -> None: async def test_failed_to_create_client(aiohttp_client) -> None: - def make_app(loop): raise RuntimeError() @@ -142,7 +143,6 @@ async def test_custom_port_test_server(aiohttp_server, aiohttp_unused_port): port = aiohttp_unused_port() server = await aiohttp_server(app, port=port) assert server.port == port - """ ) testdir.makeconftest(CONFTEST) diff --git a/tests/test_resolver.py b/tests/test_resolver.py index 1b389f3601b..f51506a6999 100644 --- a/tests/test_resolver.py +++ b/tests/test_resolver.py @@ -1,25 +1,57 @@ import asyncio import ipaddress import socket -from typing import Any, List +from ipaddress import ip_address +from typing import Any, Awaitable, Callable, Collection, List, NamedTuple, Tuple, Union from unittest.mock import Mock, patch import pytest -from aiohttp.resolver import AsyncResolver, DefaultResolver, ThreadedResolver +from aiohttp.resolver import ( + _NUMERIC_SOCKET_FLAGS, + _SUPPORTS_SCOPE_ID, + AsyncResolver, + DefaultResolver, + ThreadedResolver, +) try: import aiodns - gethostbyname = hasattr(aiodns.DNSResolver, "gethostbyname") + getaddrinfo: Any = hasattr(aiodns.DNSResolver, "getaddrinfo") except ImportError: - aiodns = None - gethostbyname = False + aiodns = None # type: ignore[assignment] + getaddrinfo = False -class FakeResult: - def __init__(self, addresses): - self.addresses = addresses +class FakeAIODNSAddrInfoNode(NamedTuple): + + family: int + addr: Union[Tuple[bytes, int], Tuple[bytes, int, int, int]] + + +class FakeAIODNSAddrInfoIPv4Result: + def __init__(self, hosts: Collection[str]) -> None: + self.nodes = [ + FakeAIODNSAddrInfoNode(socket.AF_INET, (h.encode(), 0)) for h in hosts + ] + + +class FakeAIODNSAddrInfoIPv6Result: + def __init__(self, hosts: Collection[str]) -> None: + self.nodes = [ + FakeAIODNSAddrInfoNode( + socket.AF_INET6, + (h.encode(), 0, 0, 3 if ip_address(h).is_link_local else 0), + ) + for h in hosts + ] + + +class 
FakeAIODNSNameInfoIPv6Result: + def __init__(self, host: str) -> None: + self.node = host + self.service = None class FakeQueryResult: @@ -27,16 +59,30 @@ def __init__(self, host): self.host = host -async def fake_result(addresses): - return FakeResult(addresses=tuple(addresses)) +async def fake_aiodns_getaddrinfo_ipv4_result( + hosts: Collection[str], +) -> FakeAIODNSAddrInfoIPv4Result: + return FakeAIODNSAddrInfoIPv4Result(hosts=hosts) + + +async def fake_aiodns_getaddrinfo_ipv6_result( + hosts: Collection[str], +) -> FakeAIODNSAddrInfoIPv6Result: + return FakeAIODNSAddrInfoIPv6Result(hosts=hosts) + + +async def fake_aiodns_getnameinfo_ipv6_result( + host: str, +) -> FakeAIODNSNameInfoIPv6Result: + return FakeAIODNSNameInfoIPv6Result(host) async def fake_query_result(result): return [FakeQueryResult(host=h) for h in result] -def fake_addrinfo(hosts): - async def fake(*args, **kwargs): +def fake_addrinfo(hosts: Collection[str]) -> Callable[..., Awaitable[Any]]: + async def fake(*args: Any, **kwargs: Any) -> List[Any]: if not hosts: raise socket.gaierror @@ -45,33 +91,83 @@ async def fake(*args, **kwargs): return fake -@pytest.mark.skipif(not gethostbyname, reason="aiodns 1.1 required") -async def test_async_resolver_positive_lookup(loop) -> None: +def fake_ipv6_addrinfo(hosts: Collection[str]) -> Callable[..., Awaitable[Any]]: + async def fake(*args: Any, **kwargs: Any) -> List[Any]: + if not hosts: + raise socket.gaierror + + return [ + ( + socket.AF_INET6, + None, + socket.SOCK_STREAM, + None, + (h, 0, 0, 3 if ip_address(h).is_link_local else 0), + ) + for h in hosts + ] + + return fake + + +def fake_ipv6_nameinfo(host: str) -> Callable[..., Awaitable[Any]]: + async def fake(*args: Any, **kwargs: Any) -> Tuple[str, int]: + return host, 0 + + return fake + + +@pytest.mark.skipif(not getaddrinfo, reason="aiodns >=3.2.0 required") +async def test_async_resolver_positive_ipv4_lookup(loop: Any) -> None: with patch("aiodns.DNSResolver") as mock: - 
mock().gethostbyname.return_value = fake_result(["127.0.0.1"]) - resolver = AsyncResolver(loop=loop) + mock().getaddrinfo.return_value = fake_aiodns_getaddrinfo_ipv4_result( + ["127.0.0.1"] + ) + resolver = AsyncResolver() real = await resolver.resolve("www.python.org") ipaddress.ip_address(real[0]["host"]) - mock().gethostbyname.assert_called_with("www.python.org", socket.AF_INET) - - -@pytest.mark.skipif(aiodns is None, reason="aiodns required") -async def test_async_resolver_query_positive_lookup(loop) -> None: + mock().getaddrinfo.assert_called_with( + "www.python.org", + family=socket.AF_INET, + flags=socket.AI_ADDRCONFIG, + port=0, + type=socket.SOCK_STREAM, + ) + + +@pytest.mark.skipif(not getaddrinfo, reason="aiodns >=3.2.0 required") +@pytest.mark.skipif( + not _SUPPORTS_SCOPE_ID, reason="python version does not support scope id" +) +async def test_async_resolver_positive_link_local_ipv6_lookup(loop: Any) -> None: with patch("aiodns.DNSResolver") as mock: - del mock().gethostbyname - mock().query.return_value = fake_query_result(["127.0.0.1"]) - resolver = AsyncResolver(loop=loop) + mock().getaddrinfo.return_value = fake_aiodns_getaddrinfo_ipv6_result( + ["fe80::1"] + ) + mock().getnameinfo.return_value = fake_aiodns_getnameinfo_ipv6_result( + "fe80::1%eth0" + ) + resolver = AsyncResolver() real = await resolver.resolve("www.python.org") ipaddress.ip_address(real[0]["host"]) - mock().query.assert_called_with("www.python.org", "A") - - -@pytest.mark.skipif(not gethostbyname, reason="aiodns 1.1 required") -async def test_async_resolver_multiple_replies(loop) -> None: + mock().getaddrinfo.assert_called_with( + "www.python.org", + family=socket.AF_INET, + flags=socket.AI_ADDRCONFIG, + port=0, + type=socket.SOCK_STREAM, + ) + mock().getnameinfo.assert_called_with( + ("fe80::1", 0, 0, 3), _NUMERIC_SOCKET_FLAGS + ) + + +@pytest.mark.skipif(not getaddrinfo, reason="aiodns >=3.2.0 required") +async def test_async_resolver_multiple_replies(loop: Any) -> None: with 
patch("aiodns.DNSResolver") as mock: ips = ["127.0.0.1", "127.0.0.2", "127.0.0.3", "127.0.0.4"] - mock().gethostbyname.return_value = fake_result(ips) - resolver = AsyncResolver(loop=loop) + mock().getaddrinfo.return_value = fake_aiodns_getaddrinfo_ipv4_result(ips) + resolver = AsyncResolver() real = await resolver.resolve("www.google.com") ips = [ipaddress.ip_address(x["host"]) for x in real] assert len(ips) > 3, "Expecting multiple addresses" @@ -88,40 +184,20 @@ async def test_async_resolver_query_multiple_replies(loop) -> None: ips = [ipaddress.ip_address(x["host"]) for x in real] -@pytest.mark.skipif(not gethostbyname, reason="aiodns 1.1 required") -async def test_async_resolver_negative_lookup(loop) -> None: - with patch("aiodns.DNSResolver") as mock: - mock().gethostbyname.side_effect = aiodns.error.DNSError() - resolver = AsyncResolver(loop=loop) - with pytest.raises(OSError): - await resolver.resolve("doesnotexist.bla") - - -@pytest.mark.skipif(aiodns is None, reason="aiodns required") -async def test_async_resolver_query_negative_lookup(loop) -> None: +@pytest.mark.skipif(not getaddrinfo, reason="aiodns >=3.2.0 required") +async def test_async_resolver_negative_lookup(loop: Any) -> None: with patch("aiodns.DNSResolver") as mock: - del mock().gethostbyname - mock().query.side_effect = aiodns.error.DNSError() - resolver = AsyncResolver(loop=loop) + mock().getaddrinfo.side_effect = aiodns.error.DNSError() + resolver = AsyncResolver() with pytest.raises(OSError): await resolver.resolve("doesnotexist.bla") -@pytest.mark.skipif(aiodns is None, reason="aiodns required") -async def test_async_resolver_no_hosts_in_query(loop) -> None: +@pytest.mark.skipif(not getaddrinfo, reason="aiodns >=3.2.0 required") +async def test_async_resolver_no_hosts_in_getaddrinfo(loop: Any) -> None: with patch("aiodns.DNSResolver") as mock: - del mock().gethostbyname - mock().query.return_value = fake_query_result([]) - resolver = AsyncResolver(loop=loop) - with 
pytest.raises(OSError): - await resolver.resolve("doesnotexist.bla") - - -@pytest.mark.skipif(not gethostbyname, reason="aiodns 1.1 required") -async def test_async_resolver_no_hosts_in_gethostbyname(loop) -> None: - with patch("aiodns.DNSResolver") as mock: - mock().gethostbyname.return_value = fake_result([]) - resolver = AsyncResolver(loop=loop) + mock().getaddrinfo.return_value = fake_aiodns_getaddrinfo_ipv4_result([]) + resolver = AsyncResolver() with pytest.raises(OSError): await resolver.resolve("doesnotexist.bla") @@ -135,6 +211,20 @@ async def test_threaded_resolver_positive_lookup() -> None: ipaddress.ip_address(real[0]["host"]) +@pytest.mark.skipif( + not _SUPPORTS_SCOPE_ID, reason="python version does not support scope id" +) +async def test_threaded_resolver_positive_ipv6_link_local_lookup() -> None: + loop = Mock() + loop.getaddrinfo = fake_ipv6_addrinfo(["fe80::1"]) + loop.getnameinfo = fake_ipv6_nameinfo("fe80::1%eth0") + resolver = ThreadedResolver() + resolver._loop = loop + real = await resolver.resolve("www.python.org") + assert real[0]["hostname"] == "www.python.org" + ipaddress.ip_address(real[0]["host"]) + + async def test_threaded_resolver_multiple_replies() -> None: loop = Mock() ips = ["127.0.0.1", "127.0.0.2", "127.0.0.3", "127.0.0.4"] @@ -154,6 +244,16 @@ async def test_threaded_negative_lookup() -> None: await resolver.resolve("doesnotexist.bla") +async def test_threaded_negative_ipv6_lookup() -> None: + loop = Mock() + ips: List[Any] = [] + loop.getaddrinfo = fake_ipv6_addrinfo(ips) + resolver = ThreadedResolver() + resolver._loop = loop + with pytest.raises(socket.gaierror): + await resolver.resolve("doesnotexist.bla") + + async def test_threaded_negative_lookup_with_unknown_result() -> None: loop = Mock() @@ -195,21 +295,20 @@ async def test_default_loop_for_threaded_resolver(loop) -> None: assert resolver._loop is loop -@pytest.mark.skipif(aiodns is None, reason="aiodns required") -async def 
test_default_loop_for_async_resolver(loop) -> None: - asyncio.set_event_loop(loop) - resolver = AsyncResolver() - assert resolver._loop is loop - - -@pytest.mark.skipif(not gethostbyname, reason="aiodns 1.1 required") -async def test_async_resolver_ipv6_positive_lookup(loop) -> None: +@pytest.mark.skipif(not getaddrinfo, reason="aiodns >=3.2.0 required") +async def test_async_resolver_ipv6_positive_lookup(loop: Any) -> None: with patch("aiodns.DNSResolver") as mock: - mock().gethostbyname.return_value = fake_result(["::1"]) - resolver = AsyncResolver(loop=loop) - real = await resolver.resolve("www.python.org", family=socket.AF_INET6) + mock().getaddrinfo.return_value = fake_aiodns_getaddrinfo_ipv6_result(["::1"]) + resolver = AsyncResolver() + real = await resolver.resolve("www.python.org") ipaddress.ip_address(real[0]["host"]) - mock().gethostbyname.assert_called_with("www.python.org", socket.AF_INET6) + mock().getaddrinfo.assert_called_with( + "www.python.org", + family=socket.AF_INET, + flags=socket.AI_ADDRCONFIG, + port=0, + type=socket.SOCK_STREAM, + ) @pytest.mark.skipif(aiodns is None, reason="aiodns required") @@ -229,9 +328,11 @@ async def test_async_resolver_aiodns_not_present(loop, monkeypatch) -> None: AsyncResolver(loop=loop) -def test_default_resolver() -> None: - # if gethostbyname: - # assert DefaultResolver is AsyncResolver - # else: - # assert DefaultResolver is ThreadedResolver +@pytest.mark.skipif(not getaddrinfo, reason="aiodns >=3.2.0 required") +def test_aio_dns_is_default() -> None: + assert DefaultResolver is AsyncResolver + + +@pytest.mark.skipif(getaddrinfo, reason="aiodns <3.2.0 required") +def test_threaded_resolver_is_default() -> None: assert DefaultResolver is ThreadedResolver diff --git a/tests/test_run_app.py b/tests/test_run_app.py index 5696928b219..c1d5f8e14f4 100644 --- a/tests/test_run_app.py +++ b/tests/test_run_app.py @@ -15,7 +15,7 @@ import pytest -from aiohttp import ClientConnectorError, ClientSession, WSCloseCode, web 
+from aiohttp import ClientConnectorError, ClientSession, ClientTimeout, WSCloseCode, web from aiohttp.test_utils import make_mocked_coro from aiohttp.web_runner import BaseRunner @@ -915,13 +915,34 @@ async def stop(self, request: web.Request) -> web.Response: return web.Response() def run_app(self, port: int, timeout: int, task, extra_test=None) -> asyncio.Task: + num_connections = -1 + + class DictRecordClear(dict): + def clear(self): + nonlocal num_connections + # During Server.shutdown() we want to know how many connections still + # remained before it got cleared. If the handler completed successfully + # the connection should've been removed already. If not, this may + # indicate a memory leak. + num_connections = len(self) + super().clear() + + class ServerWithRecordClear(web.Server): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._connections = DictRecordClear() + async def test() -> None: await asyncio.sleep(0.5) async with ClientSession() as sess: for _ in range(5): # pragma: no cover try: - async with sess.get(f"http://localhost:{port}/"): - pass + with pytest.raises(asyncio.TimeoutError): + async with sess.get( + f"http://localhost:{port}/", + timeout=ClientTimeout(total=0.1), + ): + pass except ClientConnectorError: await asyncio.sleep(0.5) else: @@ -941,6 +962,7 @@ async def run_test(app: web.Application) -> None: async def handler(request: web.Request) -> web.Response: nonlocal t t = asyncio.create_task(task()) + await t return web.Response(text="FOO") t = test_task = None @@ -949,11 +971,12 @@ async def handler(request: web.Request) -> web.Response: app.router.add_get("/", handler) app.router.add_get("/stop", self.stop) - web.run_app(app, port=port, shutdown_timeout=timeout) + with mock.patch("aiohttp.web_app.Server", ServerWithRecordClear): + web.run_app(app, port=port, shutdown_timeout=timeout) assert test_task.exception() is None - return t + return t, num_connections - def test_shutdown_wait_for_task( + def 
test_shutdown_wait_for_handler( self, aiohttp_unused_port: Callable[[], int] ) -> None: port = aiohttp_unused_port() @@ -964,13 +987,14 @@ async def task(): await asyncio.sleep(2) finished = True - t = self.run_app(port, 3, task) + t, connection_count = self.run_app(port, 3, task) assert finished is True assert t.done() assert not t.cancelled() + assert connection_count == 0 - def test_shutdown_timeout_task( + def test_shutdown_timeout_handler( self, aiohttp_unused_port: Callable[[], int] ) -> None: port = aiohttp_unused_port() @@ -981,39 +1005,12 @@ async def task(): await asyncio.sleep(2) finished = True - t = self.run_app(port, 1, task) + t, connection_count = self.run_app(port, 1, task) assert finished is False assert t.done() assert t.cancelled() - - def test_shutdown_wait_for_spawned_task( - self, aiohttp_unused_port: Callable[[], int] - ) -> None: - port = aiohttp_unused_port() - finished = False - finished_sub = False - sub_t = None - - async def sub_task(): - nonlocal finished_sub - await asyncio.sleep(1.5) - finished_sub = True - - async def task(): - nonlocal finished, sub_t - await asyncio.sleep(0.5) - sub_t = asyncio.create_task(sub_task()) - finished = True - - t = self.run_app(port, 3, task) - - assert finished is True - assert t.done() - assert not t.cancelled() - assert finished_sub is True - assert sub_t.done() - assert not sub_t.cancelled() + assert connection_count == 1 def test_shutdown_timeout_not_reached( self, aiohttp_unused_port: Callable[[], int] @@ -1027,10 +1024,11 @@ async def task(): finished = True start_time = time.time() - t = self.run_app(port, 15, task) + t, connection_count = self.run_app(port, 15, task) assert finished is True assert t.done() + assert connection_count == 0 # Verify run_app has not waited for timeout. 
assert time.time() - start_time < 10 @@ -1055,10 +1053,11 @@ async def test(sess: ClientSession) -> None: pass assert finished is False - t = self.run_app(port, 10, task, test) + t, connection_count = self.run_app(port, 10, task, test) assert finished is True assert t.done() + assert connection_count == 0 def test_shutdown_pending_handler_responds( self, aiohttp_unused_port: Callable[[], int] @@ -1191,3 +1190,54 @@ async def run_test(app: web.Application) -> None: assert time.time() - start < 5 assert client_finished assert server_finished + + def test_shutdown_handler_cancellation_suppressed( + self, aiohttp_unused_port: Callable[[], int] + ) -> None: + port = aiohttp_unused_port() + actions = [] + + async def test() -> None: + async def test_resp(sess): + t = ClientTimeout(total=0.4) + with pytest.raises(asyncio.TimeoutError): + async with sess.get(f"http://localhost:{port}/", timeout=t) as resp: + assert await resp.text() == "FOO" + actions.append("CANCELLED") + + async with ClientSession() as sess: + t = asyncio.create_task(test_resp(sess)) + await asyncio.sleep(0.5) + # Handler is in-progress while we trigger server shutdown. + actions.append("PRESTOP") + async with sess.get(f"http://localhost:{port}/stop"): + pass + + actions.append("STOPPING") + # Handler should still complete and produce a response. 
+ await t + + async def run_test(app: web.Application) -> None: + nonlocal t + t = asyncio.create_task(test()) + yield + await t + + async def handler(request: web.Request) -> web.Response: + try: + await asyncio.sleep(5) + except asyncio.CancelledError: + actions.append("SUPPRESSED") + await asyncio.sleep(2) + actions.append("DONE") + return web.Response(text="FOO") + + t = None + app = web.Application() + app.cleanup_ctx.append(run_test) + app.router.add_get("/", handler) + app.router.add_get("/stop", self.stop) + + web.run_app(app, port=port, shutdown_timeout=2, handler_cancellation=True) + assert t.exception() is None + assert actions == ["CANCELLED", "SUPPRESSED", "PRESTOP", "STOPPING", "DONE"] diff --git a/tests/test_test_utils.py b/tests/test_test_utils.py index 1ac742f78b1..328f83c3fd4 100644 --- a/tests/test_test_utils.py +++ b/tests/test_test_utils.py @@ -259,7 +259,7 @@ async def test_test_client_props(loop) -> None: async def test_test_client_raw_server_props(loop) -> None: async def hello(request): - return web.Response(body=_hello_world_bytes) + return web.Response() # pragma: no cover client = _TestClient(_RawTestServer(hello, host="127.0.0.1", loop=loop), loop=loop) assert client.host == "127.0.0.1" diff --git a/tests/test_urldispatch.py b/tests/test_urldispatch.py index 4f3abb8bcd7..2453ab5a235 100644 --- a/tests/test_urldispatch.py +++ b/tests/test_urldispatch.py @@ -339,6 +339,21 @@ def test_route_dynamic(router) -> None: assert route is route2 +def test_add_static_path_checks(router: any, tmp_path: pathlib.Path) -> None: + """Test that static paths must exist and be directories.""" + with pytest.raises(ValueError, match="does not exist"): + router.add_static("/", tmp_path / "does-not-exist") + with pytest.raises(ValueError, match="is not a directory"): + router.add_static("/", __file__) + + +def test_add_static_path_resolution(router: any) -> None: + """Test that static paths are expanded and absolute.""" + res = router.add_static("/", "~/..") + 
directory = str(res.get_info()["directory"]) + assert directory == str(pathlib.Path.home().parent) + + def test_add_static(router) -> None: resource = router.add_static( "/st", pathlib.Path(aiohttp.__file__).parent, name="static" @@ -1258,10 +1273,17 @@ async def test_prefixed_subapp_overlap(app) -> None: subapp2.router.add_get("/b", handler2) app.add_subapp("/ss", subapp2) + subapp3 = web.Application() + handler3 = make_handler() + subapp3.router.add_get("/c", handler3) + app.add_subapp("/s/s", subapp3) + match_info = await app.router.resolve(make_mocked_request("GET", "/s/a")) assert match_info.route.handler is handler1 match_info = await app.router.resolve(make_mocked_request("GET", "/ss/b")) assert match_info.route.handler is handler2 + match_info = await app.router.resolve(make_mocked_request("GET", "/s/s/c")) + assert match_info.route.handler is handler3 async def test_prefixed_subapp_empty_route(app) -> None: diff --git a/tests/test_web_app.py b/tests/test_web_app.py index 3688cf2b492..3d3aa2479f6 100644 --- a/tests/test_web_app.py +++ b/tests/test_web_app.py @@ -331,7 +331,7 @@ def test_app_run_middlewares() -> None: @web.middleware async def middleware(request: web.Request, handler: Handler) -> web.StreamResponse: - return await handler(request) + return await handler(request) # pragma: no cover root = web.Application(middlewares=[middleware]) sub = web.Application() diff --git a/tests/test_web_request_handler.py b/tests/test_web_request_handler.py index 06f99be76c0..4837cab030e 100644 --- a/tests/test_web_request_handler.py +++ b/tests/test_web_request_handler.py @@ -22,19 +22,21 @@ async def test_connections() -> None: manager = web.Server(serve) assert manager.connections == [] - handler = object() + handler = mock.Mock(spec_set=web.RequestHandler) + handler._task_handler = None transport = object() manager.connection_made(handler, transport) # type: ignore[arg-type] assert manager.connections == [handler] - manager.connection_lost(handler, None) # 
type: ignore[arg-type] + manager.connection_lost(handler, None) assert manager.connections == [] async def test_shutdown_no_timeout() -> None: manager = web.Server(serve) - handler = mock.Mock() + handler = mock.Mock(spec_set=web.RequestHandler) + handler._task_handler = None handler.shutdown = make_mocked_coro(mock.Mock()) transport = mock.Mock() manager.connection_made(handler, transport) diff --git a/tests/test_web_runner.py b/tests/test_web_runner.py index c4843d298ab..c7c94263234 100644 --- a/tests/test_web_runner.py +++ b/tests/test_web_runner.py @@ -16,7 +16,7 @@ def app(): @pytest.fixture -def make_runner(loop, app): +def make_runner(loop: Any, app: Any): asyncio.set_event_loop(loop) runners = [] @@ -30,7 +30,7 @@ def go(**kwargs): loop.run_until_complete(runner.cleanup()) -async def test_site_for_nonfrozen_app(make_runner) -> None: +async def test_site_for_nonfrozen_app(make_runner: Any) -> None: runner = make_runner() with pytest.raises(RuntimeError): web.TCPSite(runner) @@ -40,7 +40,7 @@ async def test_site_for_nonfrozen_app(make_runner) -> None: @pytest.mark.skipif( platform.system() == "Windows", reason="the test is not valid for Windows" ) -async def test_runner_setup_handle_signals(make_runner) -> None: +async def test_runner_setup_handle_signals(make_runner: Any) -> None: runner = make_runner(handle_signals=True) await runner.setup() assert signal.getsignal(signal.SIGTERM) is not signal.SIG_DFL @@ -51,7 +51,7 @@ async def test_runner_setup_handle_signals(make_runner) -> None: @pytest.mark.skipif( platform.system() == "Windows", reason="the test is not valid for Windows" ) -async def test_runner_setup_without_signal_handling(make_runner) -> None: +async def test_runner_setup_without_signal_handling(make_runner: Any) -> None: runner = make_runner(handle_signals=False) await runner.setup() assert signal.getsignal(signal.SIGTERM) is signal.SIG_DFL @@ -59,7 +59,7 @@ async def test_runner_setup_without_signal_handling(make_runner) -> None: assert 
signal.getsignal(signal.SIGTERM) is signal.SIG_DFL -async def test_site_double_added(make_runner) -> None: +async def test_site_double_added(make_runner: Any) -> None: _sock = get_unused_port_socket("127.0.0.1") runner = make_runner() await runner.setup() @@ -71,7 +71,7 @@ async def test_site_double_added(make_runner) -> None: assert len(runner.sites) == 1 -async def test_site_stop_not_started(make_runner) -> None: +async def test_site_stop_not_started(make_runner: Any) -> None: runner = make_runner() await runner.setup() site = web.TCPSite(runner) @@ -81,13 +81,13 @@ async def test_site_stop_not_started(make_runner) -> None: assert len(runner.sites) == 0 -async def test_custom_log_format(make_runner) -> None: +async def test_custom_log_format(make_runner: Any) -> None: runner = make_runner(access_log_format="abc") await runner.setup() assert runner.server._kwargs["access_log_format"] == "abc" -async def test_unreg_site(make_runner) -> None: +async def test_unreg_site(make_runner: Any) -> None: runner = make_runner() await runner.setup() site = web.TCPSite(runner) @@ -95,7 +95,7 @@ async def test_unreg_site(make_runner) -> None: runner._unreg_site(site) -async def test_app_property(make_runner, app) -> None: +async def test_app_property(make_runner: Any, app: Any) -> None: runner = make_runner() assert runner.app is app @@ -121,7 +121,9 @@ async def test_addresses(make_runner, unix_sockname) -> None: @pytest.mark.skipif( platform.system() != "Windows", reason="Proactor Event loop present only in Windows" ) -async def test_named_pipe_runner_wrong_loop(app, selector_loop, pipe_name) -> None: +async def test_named_pipe_runner_wrong_loop( + app: Any, selector_loop: Any, pipe_name: Any +) -> None: runner = web.AppRunner(app) await runner.setup() with pytest.raises(RuntimeError): @@ -131,7 +133,9 @@ async def test_named_pipe_runner_wrong_loop(app, selector_loop, pipe_name) -> No @pytest.mark.skipif( platform.system() != "Windows", reason="Proactor Event loop present only 
in Windows" ) -async def test_named_pipe_runner_proactor_loop(proactor_loop, app, pipe_name) -> None: +async def test_named_pipe_runner_proactor_loop( + proactor_loop: Any, app: Any, pipe_name: Any +) -> None: runner = web.AppRunner(app) await runner.setup() pipe = web.NamedPipeSite(runner, pipe_name) @@ -139,7 +143,7 @@ async def test_named_pipe_runner_proactor_loop(proactor_loop, app, pipe_name) -> await runner.cleanup() -async def test_tcpsite_default_host(make_runner): +async def test_tcpsite_default_host(make_runner: Any) -> None: runner = make_runner() await runner.setup() site = web.TCPSite(runner) diff --git a/tests/test_web_sendfile.py b/tests/test_web_sendfile.py index d472c407b7a..58a46ec602c 100644 --- a/tests/test_web_sendfile.py +++ b/tests/test_web_sendfile.py @@ -1,10 +1,13 @@ from pathlib import Path +from stat import S_IFREG, S_IRUSR, S_IWUSR from unittest import mock from aiohttp import hdrs from aiohttp.test_utils import make_mocked_coro, make_mocked_request from aiohttp.web_fileresponse import FileResponse +MOCK_MODE = S_IFREG | S_IRUSR | S_IWUSR + def test_using_gzip_if_header_present_and_file_available(loop) -> None: request = make_mocked_request( @@ -15,12 +18,13 @@ def test_using_gzip_if_header_present_and_file_available(loop) -> None: ) gz_filepath = mock.create_autospec(Path, spec_set=True) - gz_filepath.stat.return_value.st_size = 1024 - gz_filepath.stat.return_value.st_mtime_ns = 1603733507222449291 + gz_filepath.lstat.return_value.st_size = 1024 + gz_filepath.lstat.return_value.st_mtime_ns = 1603733507222449291 + gz_filepath.lstat.return_value.st_mode = MOCK_MODE filepath = mock.create_autospec(Path, spec_set=True) filepath.name = "logo.png" - filepath.with_name.return_value = gz_filepath + filepath.with_suffix.return_value = gz_filepath file_sender = FileResponse(filepath) file_sender._path = filepath @@ -36,14 +40,16 @@ def test_gzip_if_header_not_present_and_file_available(loop) -> None: request = make_mocked_request("GET", 
"http://python.org/logo.png", headers={}) gz_filepath = mock.create_autospec(Path, spec_set=True) - gz_filepath.stat.return_value.st_size = 1024 - gz_filepath.stat.return_value.st_mtime_ns = 1603733507222449291 + gz_filepath.lstat.return_value.st_size = 1024 + gz_filepath.lstat.return_value.st_mtime_ns = 1603733507222449291 + gz_filepath.lstat.return_value.st_mode = MOCK_MODE filepath = mock.create_autospec(Path, spec_set=True) filepath.name = "logo.png" - filepath.with_name.return_value = gz_filepath + filepath.with_suffix.return_value = gz_filepath filepath.stat.return_value.st_size = 1024 filepath.stat.return_value.st_mtime_ns = 1603733507222449291 + filepath.stat.return_value.st_mode = MOCK_MODE file_sender = FileResponse(filepath) file_sender._path = filepath @@ -63,9 +69,10 @@ def test_gzip_if_header_not_present_and_file_not_available(loop) -> None: filepath = mock.create_autospec(Path, spec_set=True) filepath.name = "logo.png" - filepath.with_name.return_value = gz_filepath + filepath.with_suffix.return_value = gz_filepath filepath.stat.return_value.st_size = 1024 filepath.stat.return_value.st_mtime_ns = 1603733507222449291 + filepath.stat.return_value.st_mode = MOCK_MODE file_sender = FileResponse(filepath) file_sender._path = filepath @@ -83,13 +90,14 @@ def test_gzip_if_header_present_and_file_not_available(loop) -> None: ) gz_filepath = mock.create_autospec(Path, spec_set=True) - gz_filepath.stat.side_effect = OSError(2, "No such file or directory") + gz_filepath.lstat.side_effect = OSError(2, "No such file or directory") filepath = mock.create_autospec(Path, spec_set=True) filepath.name = "logo.png" - filepath.with_name.return_value = gz_filepath + filepath.with_suffix.return_value = gz_filepath filepath.stat.return_value.st_size = 1024 filepath.stat.return_value.st_mtime_ns = 1603733507222449291 + filepath.stat.return_value.st_mode = MOCK_MODE file_sender = FileResponse(filepath) file_sender._path = filepath @@ -108,6 +116,7 @@ def 
test_status_controlled_by_user(loop) -> None: filepath.name = "logo.png" filepath.stat.return_value.st_size = 1024 filepath.stat.return_value.st_mtime_ns = 1603733507222449291 + filepath.stat.return_value.st_mode = MOCK_MODE file_sender = FileResponse(filepath, status=203) file_sender._path = filepath diff --git a/tests/test_web_sendfile_functional.py b/tests/test_web_sendfile_functional.py index 57ac0849efa..e2cfb7a1f0e 100644 --- a/tests/test_web_sendfile_functional.py +++ b/tests/test_web_sendfile_functional.py @@ -1,4 +1,5 @@ import asyncio +import bz2 import gzip import pathlib import socket @@ -10,6 +11,11 @@ import aiohttp from aiohttp import web +try: + import brotlicffi as brotli +except ImportError: + import brotli + try: import ssl except ImportError: @@ -27,9 +33,16 @@ def hello_txt(request, tmp_path_factory) -> pathlib.Path: indirect parameter can be passed with an encoding to get a compressed path. """ txt = tmp_path_factory.mktemp("hello-") / "hello.txt" - hello = {None: txt, "gzip": txt.with_suffix(f"{txt.suffix}.gz")} - hello[None].write_bytes(HELLO_AIOHTTP) + hello = { + None: txt, + "gzip": txt.with_suffix(f"{txt.suffix}.gz"), + "br": txt.with_suffix(f"{txt.suffix}.br"), + "bzip2": txt.with_suffix(f"{txt.suffix}.bz2"), + } + # Uncompressed file is not actually written to test it is not required. 
hello["gzip"].write_bytes(gzip.compress(HELLO_AIOHTTP)) + hello["br"].write_bytes(brotli.compress(HELLO_AIOHTTP)) + hello["bzip2"].write_bytes(bz2.compress(HELLO_AIOHTTP)) encoding = getattr(request, "param", None) return hello[encoding] @@ -220,7 +233,7 @@ async def handler(request): await client.close() -@pytest.mark.parametrize("hello_txt", ["gzip"], indirect=True) +@pytest.mark.parametrize("hello_txt", ["gzip", "br"], indirect=True) async def test_static_file_custom_content_type( hello_txt: pathlib.Path, aiohttp_client: Any, sender: Any ) -> None: @@ -245,8 +258,16 @@ async def handler(request): await client.close() +@pytest.mark.parametrize( + ("accept_encoding", "expect_encoding"), + [("gzip, deflate", "gzip"), ("gzip, deflate, br", "br")], +) async def test_static_file_custom_content_type_compress( - hello_txt: pathlib.Path, aiohttp_client: Any, sender: Any + hello_txt: pathlib.Path, + aiohttp_client: Any, + sender: Any, + accept_encoding: str, + expect_encoding: str, ): """Test that custom type with encoding is returned for unencoded requests.""" @@ -259,9 +280,9 @@ async def handler(request): app.router.add_get("/", handler) client = await aiohttp_client(app) - resp = await client.get("/") + resp = await client.get("/", headers={"Accept-Encoding": accept_encoding}) assert resp.status == 200 - assert resp.headers.get("Content-Encoding") == "gzip" + assert resp.headers.get("Content-Encoding") == expect_encoding assert resp.headers["Content-Type"] == "application/pdf" assert await resp.read() == HELLO_AIOHTTP resp.close() @@ -269,11 +290,17 @@ async def handler(request): await client.close() +@pytest.mark.parametrize( + ("accept_encoding", "expect_encoding"), + [("gzip, deflate", "gzip"), ("gzip, deflate, br", "br")], +) @pytest.mark.parametrize("forced_compression", [None, web.ContentCoding.gzip]) async def test_static_file_with_encoding_and_enable_compression( hello_txt: pathlib.Path, aiohttp_client: Any, sender: Any, + accept_encoding: str, + 
expect_encoding: str, forced_compression: Optional[web.ContentCoding], ): """Test that enable_compression does not double compress when an encoded file is also present.""" @@ -287,9 +314,9 @@ async def handler(request): app.router.add_get("/", handler) client = await aiohttp_client(app) - resp = await client.get("/") + resp = await client.get("/", headers={"Accept-Encoding": accept_encoding}) assert resp.status == 200 - assert resp.headers.get("Content-Encoding") == "gzip" + assert resp.headers.get("Content-Encoding") == expect_encoding assert resp.headers["Content-Type"] == "text/plain" assert await resp.read() == HELLO_AIOHTTP resp.close() @@ -298,10 +325,16 @@ async def handler(request): @pytest.mark.parametrize( - ("hello_txt", "expect_encoding"), [["gzip"] * 2], indirect=["hello_txt"] + ("hello_txt", "expect_type"), + [ + ("gzip", "application/gzip"), + ("br", "application/x-brotli"), + ("bzip2", "application/x-bzip2"), + ], + indirect=["hello_txt"], ) async def test_static_file_with_content_encoding( - hello_txt: pathlib.Path, aiohttp_client: Any, sender: Any, expect_encoding: str + hello_txt: pathlib.Path, aiohttp_client: Any, sender: Any, expect_type: str ) -> None: """Test requesting static compressed files returns the correct content type and encoding.""" @@ -314,9 +347,9 @@ async def handler(request): resp = await client.get("/") assert resp.status == 200 - assert resp.headers.get("Content-Encoding") == expect_encoding - assert resp.headers["Content-Type"] == "text/plain" - assert await resp.read() == HELLO_AIOHTTP + assert resp.headers.get("Content-Encoding") is None + assert resp.headers["Content-Type"] == expect_type + assert await resp.read() == hello_txt.read_bytes() resp.close() await resp.release() @@ -571,15 +604,6 @@ async def test_static_file_directory_traversal_attack(aiohttp_client) -> None: await client.close() -def test_static_route_path_existence_check() -> None: - directory = pathlib.Path(__file__).parent - web.StaticResource("/", 
directory) - - nodirectory = directory / "nonexistent-uPNiOEAg5d" - with pytest.raises(ValueError): - web.StaticResource("/", nodirectory) - - async def test_static_file_huge(aiohttp_client, tmp_path) -> None: file_path = tmp_path / "huge_data.unknown_mime_type" diff --git a/tests/test_web_server.py b/tests/test_web_server.py index d0fd95acdb4..14d78e23a85 100644 --- a/tests/test_web_server.py +++ b/tests/test_web_server.py @@ -4,7 +4,7 @@ import pytest -from aiohttp import client, helpers, web +from aiohttp import client, web async def test_simple_server(aiohttp_raw_server, aiohttp_client) -> None: @@ -19,12 +19,6 @@ async def handler(request): assert txt == "/path/to" -@pytest.mark.xfail( - not helpers.NO_EXTENSIONS, - raises=client.ServerDisconnectedError, - reason="The behavior of C-extensions differs from pure-Python: " - "https://github.com/aio-libs/aiohttp/issues/6446", -) async def test_unsupported_upgrade(aiohttp_raw_server, aiohttp_client) -> None: # don't fail if a client probes for an unsupported protocol upgrade # https://github.com/aio-libs/aiohttp/issues/6446#issuecomment-999032039 diff --git a/tests/test_web_urldispatcher.py b/tests/test_web_urldispatcher.py index 0441890c10b..3a45b9355f5 100644 --- a/tests/test_web_urldispatcher.py +++ b/tests/test_web_urldispatcher.py @@ -1,17 +1,18 @@ import asyncio import functools +import os import pathlib +import socket import sys -from typing import Optional -from unittest import mock -from unittest.mock import MagicMock +from stat import S_IFIFO, S_IMODE +from typing import Any, Generator, Optional import pytest import yarl from aiohttp import abc, web from aiohttp.pytest_plugin import AiohttpClient -from aiohttp.web_urldispatcher import SystemRoute +from aiohttp.web_urldispatcher import Resource, SystemRoute @pytest.mark.parametrize( @@ -330,7 +331,6 @@ async def test_access_to_the_file_with_spaces( r = await client.get(url) assert r.status == 200 assert (await r.text()) == data - await r.release() async 
def test_access_non_existing_resource( @@ -380,7 +380,7 @@ async def test_handler_metadata_persistence() -> None: async def async_handler(request: web.Request) -> web.Response: """Doc""" - return web.Response() + return web.Response() # pragma: no cover def sync_handler(request): """Doc""" @@ -395,31 +395,111 @@ def sync_handler(request): assert route.handler.__doc__ == "Doc" -async def test_unauthorized_folder_access( - tmp_path: pathlib.Path, aiohttp_client: AiohttpClient +@pytest.mark.skipif( + sys.platform.startswith("win32"), reason="Cannot remove read access on Windows" +) +@pytest.mark.parametrize("file_request", ["", "my_file.txt"]) +async def test_static_directory_without_read_permission( + tmp_path: pathlib.Path, aiohttp_client: AiohttpClient, file_request: str +) -> None: + """Test static directory without read permission receives forbidden response.""" + my_dir = tmp_path / "my_dir" + my_dir.mkdir() + my_dir.chmod(0o000) + + app = web.Application() + app.router.add_static("/", str(tmp_path), show_index=True) + client = await aiohttp_client(app) + + r = await client.get(f"/{my_dir.name}/{file_request}") + assert r.status == 403 + + +@pytest.mark.parametrize("file_request", ["", "my_file.txt"]) +async def test_static_directory_with_mock_permission_error( + monkeypatch: pytest.MonkeyPatch, + tmp_path: pathlib.Path, + aiohttp_client: AiohttpClient, + file_request: str, ) -> None: - # Tests the unauthorized access to a folder of static file server. - # Try to list a folder content of static file server when server does not - # have permissions to do so for the folder. 
+ """Test static directory with mock permission errors receives forbidden response.""" my_dir = tmp_path / "my_dir" my_dir.mkdir() + real_iterdir = pathlib.Path.iterdir + real_is_dir = pathlib.Path.is_dir + + def mock_iterdir(self: pathlib.Path) -> Generator[pathlib.Path, None, None]: + if my_dir.samefile(self): + raise PermissionError() + return real_iterdir(self) + + def mock_is_dir(self: pathlib.Path, **kwargs: Any) -> bool: + if my_dir.samefile(self.parent): + raise PermissionError() + return real_is_dir(self, **kwargs) + + monkeypatch.setattr("pathlib.Path.iterdir", mock_iterdir) + monkeypatch.setattr("pathlib.Path.is_dir", mock_is_dir) + + app = web.Application() + app.router.add_static("/", str(tmp_path), show_index=True) + client = await aiohttp_client(app) + + r = await client.get("/") + assert r.status == 200 + r = await client.get(f"/{my_dir.name}/{file_request}") + assert r.status == 403 + + +@pytest.mark.skipif( + sys.platform.startswith("win32"), reason="Cannot remove read access on Windows" +) +async def test_static_file_without_read_permission( + tmp_path: pathlib.Path, aiohttp_client: AiohttpClient +) -> None: + """Test static file without read permission receives forbidden response.""" + my_file = tmp_path / "my_file.txt" + my_file.write_text("secret") + my_file.chmod(0o000) + app = web.Application() + app.router.add_static("/", str(tmp_path)) + client = await aiohttp_client(app) - with mock.patch("pathlib.Path.__new__") as path_constructor: - path = MagicMock() - path.joinpath.return_value = path - path.resolve.return_value = path - path.iterdir.return_value.__iter__.side_effect = PermissionError() - path_constructor.return_value = path + r = await client.get(f"/{my_file.name}") + assert r.status == 403 - # Register global static route: - app.router.add_static("/", str(tmp_path), show_index=True) - client = await aiohttp_client(app) - # Request the root of the static directory. 
- r = await client.get("/" + my_dir.name) - assert r.status == 403 +async def test_static_file_with_mock_permission_error( + monkeypatch: pytest.MonkeyPatch, + tmp_path: pathlib.Path, + aiohttp_client: AiohttpClient, +) -> None: + """Test static file with mock permission errors receives forbidden response.""" + my_file = tmp_path / "my_file.txt" + my_file.write_text("secret") + my_readable = tmp_path / "my_readable.txt" + my_readable.write_text("info") + + real_open = pathlib.Path.open + + def mock_open(self: pathlib.Path, *args: Any, **kwargs: Any) -> Any: + if my_file.samefile(self): + raise PermissionError() + return real_open(self, *args, **kwargs) + + monkeypatch.setattr("pathlib.Path.open", mock_open) + + app = web.Application() + app.router.add_static("/", str(tmp_path)) + client = await aiohttp_client(app) + + # Test the mock only applies to my_file, then test the permission error. + r = await client.get(f"/{my_readable.name}") + assert r.status == 200 + r = await client.get(f"/{my_file.name}") + assert r.status == 403 async def test_access_symlink_loop( @@ -440,33 +520,87 @@ async def test_access_symlink_loop( assert r.status == 404 -async def test_access_special_resource( +async def test_access_compressed_file_as_symlink( tmp_path: pathlib.Path, aiohttp_client: AiohttpClient ) -> None: - # Tests the access to a resource that is neither a file nor a directory. - # Checks that if a special resource is accessed (f.e. named pipe or UNIX - # domain socket) then 404 HTTP status returned. + """Test that compressed file variants as symlinks are ignored.""" + private_file = tmp_path / "private.txt" + private_file.write_text("private info") + www_dir = tmp_path / "www" + www_dir.mkdir() + gz_link = www_dir / "file.txt.gz" + gz_link.symlink_to(f"../{private_file.name}") + app = web.Application() + app.router.add_static("/", www_dir) + client = await aiohttp_client(app) + + # Symlink should be ignored; response reflects missing uncompressed file. 
+ resp = await client.get(f"/{gz_link.stem}", auto_decompress=False) + assert resp.status == 404 + resp.release() + + # Again symlin is ignored, and then uncompressed is served. + txt_file = gz_link.with_suffix("") + txt_file.write_text("public data") + resp = await client.get(f"/{txt_file.name}") + assert resp.status == 200 + assert resp.headers.get("Content-Encoding") is None + assert resp.content_type == "text/plain" + assert await resp.text() == "public data" + resp.release() + await client.close() + + +async def test_access_special_resource( + tmp_path_factory: pytest.TempPathFactory, aiohttp_client: AiohttpClient +) -> None: + """Test access to non-regular files is forbidden using a UNIX domain socket.""" + if not getattr(socket, "AF_UNIX", None): + pytest.skip("UNIX domain sockets not supported") + + tmp_path = tmp_path_factory.mktemp("special") + my_special = tmp_path / "sock" + my_socket = socket.socket(socket.AF_UNIX) + my_socket.bind(str(my_special)) + assert my_special.is_socket() - with mock.patch("pathlib.Path.__new__") as path_constructor: - special = MagicMock() - special.is_dir.return_value = False - special.is_file.return_value = False + app = web.Application() + app.router.add_static("/", str(tmp_path)) + + client = await aiohttp_client(app) + r = await client.get(f"/{my_special.name}") + assert r.status == 403 + my_socket.close() - path = MagicMock() - path.joinpath.side_effect = lambda p: (special if p == "special" else path) - path.resolve.return_value = path - special.resolve.return_value = special - path_constructor.return_value = path +async def test_access_mock_special_resource( + monkeypatch: pytest.MonkeyPatch, + tmp_path: pathlib.Path, + aiohttp_client: AiohttpClient, +) -> None: + """Test access to non-regular files is forbidden using a mock FIFO.""" + my_special = tmp_path / "my_special" + my_special.touch() - # Register global static route: - app.router.add_static("/", str(tmp_path), show_index=True) - client = await 
aiohttp_client(app) + real_result = my_special.stat() + real_stat = pathlib.Path.stat - # Request the root of the static directory. - r = await client.get("/special") - assert r.status == 403 + def mock_stat(self: pathlib.Path, **kwargs: Any) -> os.stat_result: + s = real_stat(self, **kwargs) + if os.path.samestat(s, real_result): + mock_mode = S_IFIFO | S_IMODE(s.st_mode) + s = os.stat_result([mock_mode] + list(s)[1:]) + return s + + monkeypatch.setattr("pathlib.Path.stat", mock_stat) + + app = web.Application() + app.router.add_static("/", str(tmp_path)) + client = await aiohttp_client(app) + + r = await client.get(f"/{my_special.name}") + assert r.status == 403 async def test_partially_applied_handler(aiohttp_client: AiohttpClient) -> None: @@ -580,7 +714,7 @@ def test_reuse_last_added_resource(path: str) -> None: app = web.Application() async def handler(request: web.Request) -> web.Response: - return web.Response() + return web.Response() # pragma: no cover app.router.add_get(path, handler, name="a") app.router.add_post(path, handler, name="a") @@ -592,7 +726,7 @@ def test_resource_raw_match() -> None: app = web.Application() async def handler(request: web.Request) -> web.Response: - return web.Response() + return web.Response() # pragma: no cover route = app.router.add_get("/a", handler, name="a") assert route.resource is not None @@ -752,3 +886,110 @@ async def handler(request: web.Request) -> web.Response: r = await client.get(yarl.URL(urlencoded_path, encoded=True)) assert r.status == expected_http_resp_status await r.release() + + +async def test_order_is_preserved(aiohttp_client: AiohttpClient) -> None: + """Test route order is preserved. + + Note that fixed/static paths are always preferred over a regex path. 
+ """ + app = web.Application() + + async def handler(request: web.Request) -> web.Response: + assert isinstance(request.match_info._route.resource, Resource) + return web.Response(text=request.match_info._route.resource.canonical) + + app.router.add_get("/first/x/{b}/", handler) + app.router.add_get(r"/first/{x:.*/b}", handler) + + app.router.add_get(r"/second/{user}/info", handler) + app.router.add_get("/second/bob/info", handler) + + app.router.add_get("/third/bob/info", handler) + app.router.add_get(r"/third/{user}/info", handler) + + app.router.add_get(r"/forth/{name:\d+}", handler) + app.router.add_get("/forth/42", handler) + + app.router.add_get("/fifth/42", handler) + app.router.add_get(r"/fifth/{name:\d+}", handler) + + client = await aiohttp_client(app) + + r = await client.get("/first/x/b/") + assert r.status == 200 + assert await r.text() == "/first/x/{b}/" + + r = await client.get("/second/frank/info") + assert r.status == 200 + assert await r.text() == "/second/{user}/info" + + # Fixed/static paths are always preferred over regex paths + r = await client.get("/second/bob/info") + assert r.status == 200 + assert await r.text() == "/second/bob/info" + + r = await client.get("/third/bob/info") + assert r.status == 200 + assert await r.text() == "/third/bob/info" + + r = await client.get("/third/frank/info") + assert r.status == 200 + assert await r.text() == "/third/{user}/info" + + r = await client.get("/forth/21") + assert r.status == 200 + assert await r.text() == "/forth/{name}" + + # Fixed/static paths are always preferred over regex paths + r = await client.get("/forth/42") + assert r.status == 200 + assert await r.text() == "/forth/42" + + r = await client.get("/fifth/21") + assert r.status == 200 + assert await r.text() == "/fifth/{name}" + + r = await client.get("/fifth/42") + assert r.status == 200 + assert await r.text() == "/fifth/42" + + +async def test_url_with_many_slashes(aiohttp_client: AiohttpClient) -> None: + app = web.Application() + 
+ class MyView(web.View): + async def get(self) -> web.Response: + return web.Response() + + app.router.add_routes([web.view("/a", MyView)]) + + client = await aiohttp_client(app) + + r = await client.get("///a") + assert r.status == 200 + await r.release() + + +async def test_route_with_regex(aiohttp_client: AiohttpClient) -> None: + """Test a route with a regex preceded by a fixed string.""" + app = web.Application() + + async def handler(request: web.Request) -> web.Response: + assert isinstance(request.match_info._route.resource, Resource) + return web.Response(text=request.match_info._route.resource.canonical) + + app.router.add_get("/core/locations{tail:.*}", handler) + client = await aiohttp_client(app) + + r = await client.get("/core/locations/tail/here") + assert r.status == 200 + assert await r.text() == "/core/locations{tail}" + + r = await client.get("/core/locations_tail_here") + assert r.status == 200 + assert await r.text() == "/core/locations{tail}" + + r = await client.get("/core/locations_tail;id=abcdef") + assert r.status == 200 + assert await r.text() == "/core/locations{tail}" diff --git a/tests/test_web_websocket_functional.py b/tests/test_web_websocket_functional.py index b471b131c1e..15ef33e3648 100644 --- a/tests/test_web_websocket_functional.py +++ b/tests/test_web_websocket_functional.py @@ -313,6 +313,47 @@ async def handler(request): assert msg.type == WSMsgType.CLOSED +async def test_concurrent_close_multiple_tasks(loop: Any, aiohttp_client: Any) -> None: + srv_ws = None + + async def handler(request): + nonlocal srv_ws + ws = srv_ws = web.WebSocketResponse(autoclose=False, protocols=("foo", "bar")) + await ws.prepare(request) + + msg = await ws.receive() + assert msg.type == WSMsgType.CLOSING + + msg = await ws.receive() + assert msg.type == WSMsgType.CLOSING + + await asyncio.sleep(0) + + msg = await ws.receive() + assert msg.type == WSMsgType.CLOSED + + return ws + + app = web.Application() + app.router.add_get("/", handler) + 
client = await aiohttp_client(app) + + ws = await client.ws_connect("/", autoclose=False, protocols=("eggs", "bar")) + + task1 = asyncio.create_task(srv_ws.close(code=WSCloseCode.INVALID_TEXT)) + task2 = asyncio.create_task(srv_ws.close(code=WSCloseCode.INVALID_TEXT)) + + msg = await ws.receive() + assert msg.type == WSMsgType.CLOSE + + await task1 + await task2 + + await asyncio.sleep(0) + msg = await ws.receive() + assert msg.type == WSMsgType.CLOSED + + async def test_close_op_code_from_client(loop: Any, aiohttp_client: Any) -> None: srv_ws: Optional[web.WebSocketResponse] = None @@ -681,7 +722,64 @@ async def handler(request): await ws.close() -async def test_server_ws_async_for(loop, aiohttp_server) -> None: +async def test_heartbeat_no_pong_send_many_messages( + loop: Any, aiohttp_client: Any +) -> None: + """Test no pong after sending many messages.""" + + async def handler(request): + ws = web.WebSocketResponse(heartbeat=0.05) + await ws.prepare(request) + for _ in range(10): + await ws.send_str("test") + + await ws.receive() + return ws + + app = web.Application() + app.router.add_get("/", handler) + + client = await aiohttp_client(app) + ws = await client.ws_connect("/", autoping=False) + for _ in range(10): + msg = await ws.receive() + assert msg.type is aiohttp.WSMsgType.TEXT + assert msg.data == "test" + + msg = await ws.receive() + assert msg.type is aiohttp.WSMsgType.PING + await ws.close() + + +async def test_heartbeat_no_pong_receive_many_messages( + loop: Any, aiohttp_client: Any +) -> None: + """Test no pong after receiving many messages.""" + + async def handler(request): + ws = web.WebSocketResponse(heartbeat=0.05) + await ws.prepare(request) + for _ in range(10): + server_msg = await ws.receive() + assert server_msg.type is aiohttp.WSMsgType.TEXT + + await ws.receive() + return ws + + app = web.Application() + app.router.add_get("/", handler) + + client = await aiohttp_client(app) + ws = await client.ws_connect("/", autoping=False) + for _ in 
range(10): + await ws.send_str("test") + + msg = await ws.receive() + assert msg.type is aiohttp.WSMsgType.PING + await ws.close() + + +async def test_server_ws_async_for(loop: Any, aiohttp_server: Any) -> None: closed = loop.create_future() async def handler(request):