diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..3eb2612 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,109 @@ +name: CI +on: + push: + branches: [ main ] + pull_request: + branches: [ main ] + workflow_dispatch: + +jobs: + tox: + name: ${{ matrix.name }} + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + include: + - {name: '3.12', python: '3.12', tox: py312} + - {name: '3.11', python: '3.11', tox: py311} + - {name: '3.10', python: '3.10', tox: py310} + - {name: '3.9', python: '3.9', tox: py39} + - {name: '3.8', python: '3.8', tox: py38} + - {name: 'format', python: '3.12', tox: format} + - {name: 'mypy', python: '3.12', tox: mypy} + - {name: 'pep8', python: '3.12', tox: pep8} + - {name: 'package', python: '3.12', tox: package} + + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python }} + + - name: update pip + run: | + pip install -U wheel + pip install -U setuptools + python -m pip install -U pip + - run: pip install tox + + - run: tox -e ${{ matrix.tox }} + + + h2spec: + name: ${{ matrix.name }} + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + include: + - {name: 'asyncio', worker: 'asyncio'} + - {name: 'trio', worker: 'trio'} + + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: "3.12" + + - name: update pip + run: | + pip install -U wheel + pip install -U setuptools + python -m pip install -U pip + - run: pip install trio . + + - name: Run server + working-directory: compliance/h2spec + run: nohup hypercorn --keyfile key.pem --certfile cert.pem -k ${{ matrix.worker }} server:app & + + - name: Download h2spec + run: | + wget https://github.com/summerwind/h2spec/releases/download/v2.6.0/h2spec_linux_amd64.tar.gz + tar -xvf h2spec_linux_amd64.tar.gz + + - name: Run h2spec + run: ./h2spec -tk -h 127.0.0.1 -p 8000 -o 10 + + autobahn: + name: ${{ matrix.name }} + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + include: + - {name: 'asyncio', worker: 'asyncio'} + - {name: 'trio', worker: 'trio'} + + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: "3.10" + + - name: update pip + run: | + pip install -U wheel + pip install -U setuptools + python -m pip install -U pip + - run: python3 -m pip install trio . + - name: Run server + working-directory: compliance/autobahn + run: nohup hypercorn -k ${{ matrix.worker }} server:app & + + - name: Run Unit Tests + working-directory: compliance/autobahn + run: docker run --rm --network=host -v "${PWD}/:/config" -v "${PWD}/reports:/reports" --name fuzzingclient crossbario/autobahn-testsuite wstest -m fuzzingclient -s /config/fuzzingclient.json && python3 summarise.py diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml new file mode 100644 index 0000000..5e011a7 --- /dev/null +++ b/.github/workflows/publish.yml @@ -0,0 +1,38 @@ +name: Publish +on: + push: + tags: + - '*' +jobs: + build: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + + - uses: actions/setup-python@v3 + with: + python-version: 3.12 + + - run: | + pip install poetry + poetry build + - uses: actions/upload-artifact@v3 + with: + path: ./dist + + pypi-publish: + needs: ['build'] + environment: 'publish' + + name: upload release to PyPI + runs-on: ubuntu-latest + permissions: + # IMPORTANT: this permission is mandatory for trusted publishing + id-token: write + steps: + - uses: actions/download-artifact@v3 + + - name: Publish package distributions to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 + with: + packages_dir: artifact/ diff --git a/.gitignore b/.gitignore index 8ef153f..c436bdd 100644 --- a/.gitignore +++ b/.gitignore @@ -13,3 +13,5 @@ docs/reference/source/ dist/ .coverage poetry.lock +.idea/ +.DS_Store diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml deleted file mode 100644 index 47f8680..0000000 --- a/.gitlab-ci.yml +++ /dev/null @@ -1,80 +0,0 @@ -py37: - image: python:3.7 - script: - - pip install tox - - tox -e py37 - -py38: - image: python:3.8 - script: - - pip install tox - - tox -e py38 - -py39: - image: python:3.9 - script: - - pip install tox - - tox -e py39 - -py310: - image: python:3.10 - script: - - pip install tox - - tox -e docs,format,mypy,py310,package,pep8 - -pages: - image: python:3.10 - script: - - pip install sphinx pydata-sphinx-theme . - - rm -rf docs/source - - sphinx-apidoc -e -f -o docs/reference/source/ src/hypercorn - - sphinx-build -b html docs/ docs/_build/html/ - - mv docs/_build/html/ public/ - artifacts: - paths: - - public - rules: - - if: $CI_COMMIT_BRANCH == $CI_DEFAULT_BRANCH - -.h2spec-script: &h2spec-script - image: python:3.10 - script: - - python3 -m pip install trio . - - cd compliance/h2spec && nohup hypercorn --keyfile key.pem --certfile cert.pem -k $WORKER_CLASS server:App & - - wget https://github.com/summerwind/h2spec/releases/download/v2.2.0/h2spec_linux_amd64.tar.gz - - tar -xvf h2spec_linux_amd64.tar.gz - - sleep 10 - - ./h2spec -tk -h 127.0.0.1 -p 8000 -o 10 - -h2spec: - <<: *h2spec-script - variables: - WORKER_CLASS: "asyncio" - -h2spec-trio: - <<: *h2spec-script - variables: - WORKER_CLASS: "trio" - -.autobahn-script: &autobahn-script - image: python:2.7.16-alpine3.10 - script: - - apk --update add build-base libressl libressl-dev ca-certificates libffi-dev python3 python3-dev - - pip install pyopenssl==19.1.0 cryptography==2.3.1 autobahntestsuite - - python3 -m pip install trio . - - cd compliance/autobahn && nohup hypercorn -k $WORKER_CLASS server:App & - - while ! netstat -l -t | grep -q 8000; do sleep 1; done - - cd compliance/autobahn && wstest -m fuzzingclient && python summarise.py - artifacts: - paths: - - compliance/autobahn/reports/servers/ - -autobahn: - <<: *autobahn-script - variables: - WORKER_CLASS: "asyncio" - -autobahn-trio: - <<: *autobahn-script - variables: - WORKER_CLASS: "trio" diff --git a/.readthedocs.yaml b/.readthedocs.yaml new file mode 100644 index 0000000..7ae0e68 --- /dev/null +++ b/.readthedocs.yaml @@ -0,0 +1,16 @@ +version: 2 + +build: + os: ubuntu-22.04 + tools: + python: "3.11" + +python: + install: + - method: pip + path: . + extra_requirements: + - docs + +sphinx: + configuration: docs/conf.py diff --git a/CHANGELOG.rst b/CHANGELOG.rst index aff3f5b..bf32d1b 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,3 +1,136 @@ +0.17.3 2024-05-28 +----------------- + +* Restore set TCP_NODELAY on TCP sockets +* Support uvloop >= 0.18 and the loop_factory argument +* Bugfix ensure ExceptionGroup lifespan failures crash the server. + +0.17.2 2024-05-27 +----------------- + +* Bugfix pass the correct quic connection to the H3 Protocol. + +0.17.1 2024-05-27 +----------------- + +* Bugfix revert set TCP_NODELAY on sockets. + +0.17.0 2024-05-27 +----------------- + +* Set TCP_NODELAY on sockets. +* Support sending trailing headers on h2/h3. +* Add support for lifespan state. +* Allow sending of the response before body data arrives. +* Bugfix properly set host header to ascii string in + ProxyFixMiddleware. +* Bugfix encode headers using latin-1. +* Bugfix don't double-access log if the response was sent. +* Bugfix a statsd logging bug. +* Bugfix handle already-closed on StreamEnded. +* Bugfix send a 400 response if data is received before the websocket + is accepted. +* Bugfix ensure only a single QUIC timer task per connection. +* Bugfix ensure responses are sent with empty bodies for WSGI. + +0.16.0 2024-01-01 +----------------- + +* Add a max keep alive requests configuration option, this mitigates + the HTTP/2 rapid reset attack. +* Return subprocess exit code if non-zero. +* Add ProxyFix middleware to make it easier to run Hypercorn behind a + proxy. +* Support restarting workers after max requests to make it easier to + manage memory leaks in apps. +* Bugfix ensure the idle task is stopped on error. +* Bugfix revert autoreload error because reausing old sockets. +* Bugfix send the hinted error from h11 on RemoteProtocolErrors. +* Bugfix handle asyncio.CancelledError when socket is closed without + flushing. +* Bugfix improve WSGI compliance by closing iterators, only sending + headers on first response byte, erroring if ``start_response`` is + not called, and switching wsgi.errors to stdout. +* Don't error on LocalProtoclErrors for ws streams to better cope with + race conditions. + +0.15.0 2023-10-29 +----------------- + +* Improve the NoAppError to help diagnose why the app has not been + found. +* Log cancelled requests as well as successful to aid diagnositics of + failures. +* Use more modern asyncio apis. This will hopefully fix reported + memory leak issues. +* Bugfix only load the application in the main process if the reloader + is being used. +* Bugfix Autoreload error because reausing old sockets. +* Bugfix scope client usage for sock binding. +* Bugfix disable multiprocessing if number of workers is 0 to support + systems that don't support multiprocessing. + +0.14.4 2023-07-08 +----------------- + +* Bugfix Use tomllib/tomli for .toml support replacing the + unmaintained toml library. +* Bugfix server hanging on startup failure. +* Bugfix close websocket with 1011 on internal error (1006 is a + client-only code). +* Bugfix support trio > 0.22 utilising exception groups (note trio <= + 0.22 is not supported). +* Bugfix except ConnectionAbortedError which can be raised on Windows + machines. +* Bugfix ensure that closed is sent on reading end. +* Bugfix handle read_timeout exception on trio. +* Support and test against Python 3.11. +* Add explanation of PicklingErrors. +* Add config option to pass raw h11 headers. + +0.14.3 2022-09-04 +----------------- + +* Revert Preserve response headers casing for HTTP/1 as this breaks + ASGI frameworks. +* Bugfix stream WSGI responses + +0.14.2 2022-09-03 +----------------- + +* Bugfix add missing ASGI version to lifespan scope. +* Bugfix preserve the HTTP/1 request header casing through to the ASGI + app. +* Bugifx ensure the config loglevel is respected. +* Bugfix ensure new processes are spawned not forked. +* Bugfix ignore dunder vars in config objects. +* Bugfix clarify the subprotocol exception. + +0.14.1 2022-08-29 +----------------- + +* Fix Python3.7 compatibility. + +0.14.0 2022-08-29 +----------------- + +* Bugfix only recycle a HTTP/1.1 connection if client is DONE. +* Bugfix uvloop may raise a RuntimeError. +* Bugfix ensure 100ms sleep between Windows workers starting. +* Bugfix ensure lifespan shutdowns occur. +* Bugfix close idle Keep-Alive connections on graceful exit. +* Bugfix don't suppress 412 bodies. +* Bugfix don't idle close upgrade requests. +* Allow control over date header addition. +* Allow for logging configuration to be loaded from JSON or TOML + files. +* Preserve response headers casing for HTTP/1. +* Support the early hint ASGI-extension. +* Alter the process and reloading system such that it should work + correctly in all configurations. +* Directly support serving WSGI applications (and drop support for + ASGI-2, now ASGI-3 only). + 0.13.2 2021-12-23 ----------------- diff --git a/README.rst b/README.rst index ac4cb1d..3c676b9 100644 --- a/README.rst +++ b/README.rst @@ -1,19 +1,19 @@ Hypercorn ========= -.. image:: https://assets.gitlab-static.net/pgjones/hypercorn/raw/main/artwork/logo.png +.. image:: https://github.com/pgjones/hypercorn/raw/main/artwork/logo.png :alt: Hypercorn logo |Build Status| |docs| |pypi| |http| |python| |license| Hypercorn is an `ASGI -`_ web -server based on the sans-io hyper, `h11 +`_ and +WSGI web server based on the sans-io hyper, `h11 `_, `h2 `_, and `wsproto `_ libraries and inspired by Gunicorn. Hypercorn supports HTTP/1, HTTP/2, WebSockets (over HTTP/1 -and HTTP/2), ASGI/2, and ASGI/3 specifications. Hypercorn can utilise +and HTTP/2), ASGI, and WSGI specifications. Hypercorn can utilise asyncio, uvloop, or trio worker types. Hypercorn can optionally serve the current draft of the HTTP/3 @@ -24,8 +24,8 @@ choose a quic binding e.g. ``hypercorn --quic-bind localhost:4433 ...``. Hypercorn was initially part of `Quart -`_ before being separated out into a -standalone ASGI server. Hypercorn forked from version 0.5.0 of Quart. +`_ before being separated out into a +standalone server. Hypercorn forked from version 0.5.0 of Quart. Quickstart ---------- @@ -37,7 +37,7 @@ Hypercorn can be installed via `pip $ pip install hypercorn -and requires Python 3.7.0 or higher. +and requires Python 3.8 or higher. With hypercorn installed ASGI frameworks (or apps) can be served via Hypercorn via the command line, @@ -59,19 +59,19 @@ Alternatively Hypercorn can be used programatically, asyncio.run(serve(app, Config())) learn more (including a Trio example of the above) in the `API usage -`_ +`_ docs. Contributing ------------ -Hypercorn is developed on `GitLab -`_. If you come across an issue, +Hypercorn is developed on `Github +`_. If you come across an issue, or have a feature request please open an `issue -`_. If you want to +`_. If you want to contribute a fix or the feature-implementation please do (typo fixes -welcome), by proposing a `merge request -`_. +welcome), by proposing a `pull request +`_. Testing ~~~~~~~ @@ -89,17 +89,17 @@ this will check the code style and run the tests. Help ---- -The Hypercorn `documentation `_ -is the best place to start, after that try searching stack overflow, -if you still can't find an answer please `open an issue -`_. +The Hypercorn `documentation `_ is +the best place to start, after that try searching stack overflow, if +you still can't find an answer please `open an issue +`_. -.. |Build Status| image:: https://gitlab.com/pgjones/hypercorn/badges/main/pipeline.svg - :target: https://gitlab.com/pgjones/hypercorn/commits/main +.. |Build Status| image:: https://github.com/pgjones/hypercorn/actions/workflows/ci.yml/badge.svg + :target: https://github.com/pgjones/hypercorn/commits/main .. |docs| image:: https://img.shields.io/badge/docs-passing-brightgreen.svg - :target: https://pgjones.gitlab.io/hypercorn/ + :target: https://hypercorn.readthedocs.io .. |pypi| image:: https://img.shields.io/pypi/v/hypercorn.svg :target: https://pypi.python.org/pypi/Hypercorn/ @@ -111,4 +111,4 @@ if you still can't find an answer please `open an issue :target: https://pypi.python.org/pypi/Hypercorn/ .. |license| image:: https://img.shields.io/badge/license-MIT-blue.svg - :target: https://gitlab.com/pgjones/hypercorn/blob/main/LICENSE + :target: https://github.com/pgjones/hypercorn/blob/main/LICENSE diff --git a/compliance/autobahn/server.py b/compliance/autobahn/server.py index 5bd5d0d..1d22ea2 100644 --- a/compliance/autobahn/server.py +++ b/compliance/autobahn/server.py @@ -1,23 +1,18 @@ -class App: - - def __init__(self, scope): - pass - - async def __call__(self, receive, send): - while True: - event = await receive() - if event['type'] == 'websocket.disconnect': - break - elif event['type'] == 'websocket.connect': - await send({'type': 'websocket.accept'}) - elif event['type'] == 'websocket.receive': - await send({ - 'type': 'websocket.send', - 'bytes': event['bytes'], - 'text': event['text'], - }) - elif event['type'] == 'lifespan.startup': - await send({'type': 'lifespan.startup.complete'}) - elif event['type'] == 'lifespan.shutdown': - await send({'type': 'lifespan.shutdown.complete'}) - break +async def app(scope, receive, send): + while True: + event = await receive() + if event['type'] == 'websocket.disconnect': + break + elif event['type'] == 'websocket.connect': + await send({'type': 'websocket.accept'}) + elif event['type'] == 'websocket.receive': + await send({ + 'type': 'websocket.send', + 'bytes': event['bytes'], + 'text': event['text'], + }) + elif event['type'] == 'lifespan.startup': + await send({'type': 'lifespan.startup.complete'}) + elif event['type'] == 'lifespan.shutdown': + await send({'type': 'lifespan.shutdown.complete'}) + break diff --git a/compliance/h2spec/server.py b/compliance/h2spec/server.py index 2f87ce3..3c3a433 100644 --- a/compliance/h2spec/server.py +++ b/compliance/h2spec/server.py @@ -1,30 +1,25 @@ -class App: +async def app(scope, receive, send): + while True: + event = await receive() + if event['type'] == 'http.disconnect': + break + elif event['type'] == 'http.request' and not event.get('more_body', False): + await send_data(send) + break + elif event['type'] == 'lifespan.startup': + await send({'type': 'lifespan.startup.complete'}) + elif event['type'] == 'lifespan.shutdown': + await send({'type': 'lifespan.shutdown.complete'}) + break - def __init__(self, scope): - pass - - async def __call__(self, receive, send): - while True: - event = await receive() - if event['type'] == 'http.disconnect': - break - elif event['type'] == 'http.request' and not event.get('more_body', False): - await self.send_data(send) - break - elif event['type'] == 'lifespan.startup': - await send({'type': 'lifespan.startup.complete'}) - elif event['type'] == 'lifespan.shutdown': - await send({'type': 'lifespan.shutdown.complete'}) - break - - async def send_data(self, send): - await send({ - 'type': 'http.response.start', - 'status': 200, - 'headers': [(b'content-length', b'5')], - }) - await send({ - 'type': 'http.response.body', - 'body': b'Hello', - 'more_body': False, - }) +async def send_data(send): + await send({ + 'type': 'http.response.start', + 'status': 200, + 'headers': [(b'content-length', b'5')], + }) + await send({ + 'type': 'http.response.body', + 'body': b'Hello', + 'more_body': False, + }) diff --git a/docs/conf.py b/docs/conf.py index e7c54b1..45b063e 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -32,7 +32,7 @@ # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. -extensions = ['sphinx.ext.autodoc', 'sphinx.ext.napoleon'] +extensions = ['sphinx.ext.autodoc', 'sphinx.ext.napoleon', 'sphinxcontrib.mermaid'] # Add any paths that contain templates here, relative to this directory. # templates_path = ['_templates'] @@ -65,7 +65,7 @@ # # This is also used if you do content translation via gettext catalogs. # Usually you set "language" from the command line for these cases. -language = None +language = "en" # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. @@ -93,14 +93,14 @@ # html_theme_options = { "external_links": [ - {"name": "Source code", "url": "https://gitlab.com/pgjones/hypercorn"}, - {"name": "Issues", "url": "https://gitlab.com/pgjones/hypercorn/issues"}, + {"name": "Source code", "url": "https://github.com/pgjones/hypercorn"}, + {"name": "Issues", "url": "https://github.com/pgjones/hypercorn/issues"}, ], "icon_links": [ { - "name": "GitLab", - "url": "https://gitlab.com/pgjones/hypercorn", - "icon": "fab fa-gitlab", + "name": "Github", + "url": "https://github.com/pgjones/hypercorn", + "icon": "fab fa-github", }, ], } diff --git a/docs/discussion/dos_mitigations.rst b/docs/discussion/dos_mitigations.rst index 358ba98..88cc48b 100644 --- a/docs/discussion/dos_mitigations.rst +++ b/docs/discussion/dos_mitigations.rst @@ -169,3 +169,14 @@ data that it cannot send to the client. To mitigate this Hypercorn responds to the backpressure and pauses (blocks) the coroutine writing the response. + +Rapid reset +^^^^^^^^^^^ + +This attack works by opening and closing streams in quick succession +in the expectation that this is more costly for the server than the +client. + +To mitigate Hypercorn will only allow a maximum number of requests per +kept-alive connection before closing it. This ensures that cost of the +attack is equally born by the client. diff --git a/docs/discussion/flow.rst b/docs/discussion/flow.rst new file mode 100644 index 0000000..4f3b416 --- /dev/null +++ b/docs/discussion/flow.rst @@ -0,0 +1,49 @@ +Flow +==== + +These are the expected event flows/sequences. + +H11/H2 +------ + +A typical HTTP/1 or HTTP/2 request with response with the connection +specified to close on response. + +.. mermaid:: + + sequenceDiagram + TCPServer->>H11/H2: RawData + H11/H2->>HTTPStream: Request + H11/H2->>HTTPStream: Body + HTTPStream->>App: http.request[more_body=True] + H11/H2->>HTTPStream: EndBody + HTTPStream->>App: http.request[more_body=False] + App->>HTTPStream: http.response.start + App->>HTTPStream: http.response.body + HTTPStream->>H11/H2: Response + H11/H2->>TCPServer: RawData + HTTPStream->>H11/H2: Body + H11/H2->>TCPServer: RawData + HTTPStream->>H11/H2: EndBody + H11/H2->>TCPServer: RawData + H11/H2->>HTTPStream: StreamClosed + HTTPStream->>App: http.disconnect + H11/H2->>TCPServer: Closed + + +H11 early client cancel +----------------------- + +The flow as expected if the connection is closed before the server has +the opportunity to respond. + +.. mermaid:: + + sequenceDiagram + TCPServer->>H11/H2: RawData + H11/H2->>HTTPStream: Request + H11/H2->>HTTPStream: Body + HTTPStream->>App: http.request[more_body=True] + TCPServer->>H11/H2: Closed + H11/H2->>HTTPStream: StreamClosed + HTTPStream->>App: http.disconnect diff --git a/docs/discussion/index.rst b/docs/discussion/index.rst index aa1bd0e..509e1c2 100644 --- a/docs/discussion/index.rst +++ b/docs/discussion/index.rst @@ -9,5 +9,6 @@ Discussions closing.rst design_choices.rst dos_mitigations.rst + flow.rst http2.rst workers.rst diff --git a/docs/how_to_guides/api_usage.rst b/docs/how_to_guides/api_usage.rst index abdb8d5..6947dcd 100644 --- a/docs/how_to_guides/api_usage.rst +++ b/docs/how_to_guides/api_usage.rst @@ -7,9 +7,8 @@ Most usage of Hypercorn is expected to be via the command line, as explained in the :ref:`usage` documentation. Alternatively it is possible to use Hypercorn programmatically via the ``serve`` function available for either the asyncio or trio :ref:`workers` (note the -asyncio ``serve`` can be used with uvloop). In Python 3.7, or better, -this can be done as follows, first you need to create a Hypercorn -Config instance, +asyncio ``serve`` can be used with uvloop). This can be done as +follows, first you need to create a Hypercorn Config instance, .. code-block:: python @@ -18,8 +17,8 @@ Config instance, config = Config() config.bind = ["localhost:8080"] # As an example configuration setting -Then assuming you have an ASGI framework instance called ``app``, -using asyncio, +Then assuming you have an ASGI or WSGI framework instance called +``app``, using asyncio, .. code-block:: python @@ -115,3 +114,10 @@ exception handler, loop.default_exception_handler(context) loop.set_exception_handler(_exception_handler) + +Forcing ASGI or WSGI mode +------------------------- + +The ``serve`` function takes a ``mode`` argument that can be +``"asgi"`` or ``"wsgi"`` to force the app to be considered ASGI or +WSGI as required. diff --git a/docs/how_to_guides/configuring.rst b/docs/how_to_guides/configuring.rst index c47308a..ab3a07c 100644 --- a/docs/how_to_guides/configuring.rst +++ b/docs/how_to_guides/configuring.rst @@ -71,20 +71,20 @@ can be used, Configuration options ===================== -========================== ============================= ========================================== -Attribute Command line Purpose --------------------------- ----------------------------- ------------------------------------------ +========================== ============================= =============================================== ======================== +Attribute Command line Purpose Default +-------------------------- ----------------------------- ----------------------------------------------- ------------------------ access_log_format ``--access-logformat`` The log format for the access log, see :ref:`how_to_log`. accesslog ``--access-logfile`` The target logger for access logs, use ``-`` for stdout. -alpn_protocols N/A The HTTP protocols to advertise over +alpn_protocols N/A The HTTP protocols to advertise over ``h2`` and ``http/1.1`` ALPN. alt_svc_headers N/A List of header values to return as Alt-Svc headers. -application_path N/A The path location of the ASGI - application, defaults to cwd. -backlog ``--backlog`` The maximum number of pending +application_path N/A The path location of the ASGI cwd + application. +backlog ``--backlog`` The maximum number of pending 100 connections. bind ``-b``, ``--bind`` The TCP host/address to bind to. Should be either host:port, host, @@ -94,8 +94,8 @@ bind ``-b``, ``--bind`` The TCP host/address to respectively. ca_certs ``--ca-certs`` Path to the SSL CA certificate file. certfile ``--certfile`` Path to the SSL certificate file. -ciphers ``--ciphers`` Ciphers to use for the SSL setup. -debug ``--debug`` Enable debug mode, i.e. extra logging +ciphers ``--ciphers`` Ciphers to use for the SSL setup. ``ECDHE+AESGCM`` +debug ``--debug`` Enable debug mode, i.e. extra logging ``False`` and checks. dogstatsd_tags N/A DogStatsd format tag, see :ref:`using_statsd`. @@ -103,30 +103,48 @@ errorlog ``--error-logfile`` The target location for ``--log-file`` use `-` for stderr. graceful_timeout ``--graceful-timeout`` Time to wait after SIGTERM or Ctrl-C for any remaining requests (tasks) to - complete. +read_timeout ``--read-timeout`` Seconds to wait before timing out reads No timeout. + on TCP sockets. group ``-g``, ``--group`` Group to own any unix sockets. -h11_max_incomplete_size N/A The max HTTP/2 request line + headers size - in bytes. -h2_max_concurrent_streams N/A Maximum number of HTTP/2 concurrent +h11_max_incomplete_size N/A The max HTTP/1.1 request line + headers 16KiB + size in bytes. +h11_pass_raw_headers N/A Pass the raw headers from h11 to the ``False`` + Request object, which preserves header + casing. +h2_max_concurrent_streams N/A Maximum number of HTTP/2 concurrent 100 streams. -h2_max_header_list_size N/A Maximum number of HTTP/2 headers. -h2_max_inbound_frame_size N/A Maximum size of a HTTP/2 frame. -include_server_header N/A Include the ``Server: Hypercorn`` header, - default True. +h2_max_header_list_size N/A Maximum number of HTTP/2 headers. 65536 +h2_max_inbound_frame_size N/A Maximum size of a HTTP/2 frame. 16KiB +include_date_header N/A Include the ``True`` + ``Date: Tue, 15 Nov 1994 08:12:31 GMT`` + header. +include_server_header N/A Include the ``Server: Hypercorn`` header. ``True`` insecure_bind ``--insecure-bind`` The TCP host/address to bind to. SSL options will not apply to these binds. See *bind* for formatting options. Care must be taken! See HTTP -> HTTPS redirection docs. -keep_alive_timeout ``--keep-alive`` Seconds to keep inactive connections alive +keep_alive_max_requests N/A Maximum number of requests before connection 1000 + is closed. HTTP/1 & HTTP/2 only. +keep_alive_timeout ``--keep-alive`` Seconds to keep inactive connections alive 5s before closing. keyfile ``--keyfile`` Path to the SSL key file. -logconfig ``--log-config`` A Python logging configuration file. +keyfile_password ``--keyfile-password`` Password for the keyfile if the keyfile is + password-protected. +logconfig ``--log-config`` A Python logging configuration file. This The logging ini format. + can be prefixed with 'json:' or 'toml:' to + load the configuration from a file in that + format. logconfig_dict N/A A Python logging configuration dictionary. logger_class N/A Type of class to use for logging. -loglevel ``--log-level`` The (error) log level. -max_app_queue_size N/A The maximum number of events to queue up +loglevel ``--log-level`` The (error) log level. ``INFO`` +max_app_queue_size N/A The maximum number of events to queue up 10 sending to the ASGI application. +max_requests ``--max-requests`` Maximum number of requests a worker will + process before restarting. +max_requests_jitter ``--max-requests-jitter`` This jitter causes the max-requests per worker 0 + to be randomized by + ``randint(0, max_requests_jitter)`` pid_path ``-p``, ``--pid`` Location to write the PID (Program ID) to. quic_bind ``--quic-bind`` The UDP/QUIC host/address to bind to. See *bind* for formatting options. @@ -135,11 +153,11 @@ root_path ``--root-path`` The setting for the ASG server_names ``--server-name`` The hostnames that can be served, requests to different hosts will be responded to with 404s. -shutdown_timeout N/A Timeout when waiting for Lifespan +shutdown_timeout N/A Timeout when waiting for Lifespan 60s shutdowns to complete. -ssl_handshake_timeout N/A Timeout when waiting for SSL handshakes to +ssl_handshake_timeout N/A Timeout when waiting for SSL handshakes to 60s complete. -startup_timeout N/A Timeout when waiting for Lifespan +startup_timeout N/A Timeout when waiting for Lifespan 60s startups to complete. statsd_host ``--statsd-host`` The host:port of the statsd server. statsd_prefix ``--statsd-prefix`` Prefix for all statsd messages. @@ -151,7 +169,7 @@ verify_flags N/A SSL context verify flag verify_mode ``--verify-mode`` SSL verify mode for peer's certificate, see ssl.VerifyMode enum for possible values. -websocket_max_message_size N/A Maximum size of a WebSocket frame. +websocket_max_message_size N/A Maximum size of a WebSocket frame. 16MiB websocket_ping_interval ``--websocket-ping-interval`` If set this is the time in seconds between pings sent to the client. This can be used to keep the websocket connection alive. @@ -159,5 +177,7 @@ worker_class ``-k``, ``--worker-class`` The type of worker to u asyncio, uvloop (pip install hypercorn[uvloop]), and trio (pip install hypercorn[trio]). -workers ``-w``, ``--workers`` The number of workers to spawn and use. -========================== ============================= ========================================== +workers ``-w``, ``--workers`` The number of workers to spawn and use. 1 +wsgi_max_body_size N/A The maximum size of a body that will be 16MiB + accepted in WSGI mode. +========================== ============================= =============================================== ======================== diff --git a/docs/how_to_guides/index.rst b/docs/how_to_guides/index.rst index bccdd54..9f4bf2d 100644 --- a/docs/how_to_guides/index.rst +++ b/docs/how_to_guides/index.rst @@ -11,6 +11,7 @@ How to guides dispatch_apps.rst http_https_redirect.rst logging.rst + proxy_fix.rst server_names.rst statsd.rst wsgi_apps.rst diff --git a/docs/how_to_guides/logging.rst b/docs/how_to_guides/logging.rst index f4070ab..9e16e07 100644 --- a/docs/how_to_guides/logging.rst +++ b/docs/how_to_guides/logging.rst @@ -8,6 +8,19 @@ default neither will actively log. The special value of ``-`` can be used as the logging target in order to log to stdout and stderr respectively. Any other value is considered a filepath to target. +Configuring the Python logger +----------------------------- + +The Python logger can be configured using the ``logconfig`` or +``logconfig_dict`` configuration attributes. The latter, +``logconfig_dict`` will be passed to ``dictConfig`` after the loggers +have been created. + +The ``logconfig`` variable should point at a file to be used by the +``fileConfig`` function. Alternatively it can point to a JSON or TOML +formatted file which will be loaded and passed to the ``dictConfig`` +function. To use a JSON formatted file prefix the filepath with +``json:`` and for TOML use ``toml:``. Configuring access logs ----------------------- @@ -55,10 +68,10 @@ p process ID {Variable}e environment variable =========== =========== -Customising the access logger ------------------------------ +Customising the logger +---------------------- -The acces logger class can be customised by changing the -``access_logger_class`` attribute of the ``Config`` class. This is -only possible when using the python based configuration file. The -``hypercorn.logging.AccessLogger`` class is used by default. +The logger class can be customised by changing the ``logger_class`` +attribute of the ``Config`` class. This is only possible when using +the python based configuration file. The +``hypercorn.logging.Logger`` class is used by default. diff --git a/docs/how_to_guides/proxy_fix.rst b/docs/how_to_guides/proxy_fix.rst new file mode 100644 index 0000000..ca7e6f7 --- /dev/null +++ b/docs/how_to_guides/proxy_fix.rst @@ -0,0 +1,38 @@ +Fixing proxy headers +==================== + +If you are serving Hypercorn behind a proxy e.g. a load balancer the +client-address, scheme, and host-header will match that of the +connection between the proxy and Hypercorn rather than the user-agent +(client). However, most proxies provide headers with the original +user-agent (client) values which can be used to "fix" the headers to +these values. + +Modern proxies should provide this information via a ``Forwarded`` +header from `RFC 7239 +`_. However, this is +rare in practice with legacy proxies using a combination of +``X-Forwarded-For``, ``X-Forwarded-Proto`` and +``X-Forwarded-Host``. It is important that you chose the correct mode +(legacy, or modern) based on the proxy you use. + +To use the proxy fix middleware behind a single legacy proxy simply +wrap your app and serve the wrapped app, + +.. code-block:: python + + from hypercorn.middleware import ProxyFixMiddleware + + fixed_app = ProxyFixMiddleware(app, mode="legacy", trusted_hops=1) + +.. warning:: + + The mode and number of trusted hops must match your setup or the + user-agent (client) may be trusted and hence able to set + alternative for, proto, and host values. This can, depending on + your usage in the app, lead to security vulnerabilities. + +The ``trusted_hops`` argument should be set to the number of proxies +that are chained in front of Hypercorn. You should set this to how +many proxies are setting the headers so the middleware knows what to +trust. diff --git a/docs/how_to_guides/wsgi_apps.rst b/docs/how_to_guides/wsgi_apps.rst index 65d2167..df6b72b 100644 --- a/docs/how_to_guides/wsgi_apps.rst +++ b/docs/how_to_guides/wsgi_apps.rst @@ -3,38 +3,38 @@ Serve WSGI applications ======================= -Hypercorn directly serves ASGI applications, but it can be used to -serve WSGI applications by using ``AsyncioWSGIMiddleware`` or -``TrioWSGIMiddleware`` middleware. To do so simply wrap the WSGI -app with the appropriate middleware for the hypercorn worker, +Hypercorn directly serves WSGI applications: -.. code-block:: python +.. code-block:: shell - from hypercorn.middleware import AsyncioWSGIMiddleware, TrioWSGIMiddleware + $ hypercorn module:wsgi_app - asyncio_app = AsyncioWSGIMiddleware(wsgi_app) - trio_app = TrioWSGIMiddleware(wsgi_app) +WSGI Middleware +--------------- -which can then be served by hypercorn, +If a WSGI application is being combined with ASGI middleware it is +best to use either ``AsyncioWSGIMiddleware`` or ``TrioWSGIMiddleware`` +middleware. To do so simply wrap the WSGI app with the appropriate +middleware for the hypercorn worker, -.. code-block:: shell +.. code-block:: python - $ hypercorn module:asyncio_app - $ hypercorn --worker-class trio module:trio_app + from hypercorn.middleware import AsyncioWSGIMiddleware, TrioWSGIMiddleware -.. warning:: + asyncio_app = AsyncioWSGIMiddleware(wsgi_app) + trio_app = TrioWSGIMiddleware(wsgi_app) - The full response from the WSGI app will be stored in memory - before being sent. This prevents the WSGI app from streaming a - response. +which can then be passed to other middleware served by hypercorn, Limiting the request body size ------------------------------ As the request body is stored in memory before being processed it is -important to limit the max size. Both the ``AsyncioWSGIMiddleware`` -and ``TrioWSGIMiddleware`` have a default max size that can be -configured, +important to limit the max size. This is configured by the +``wsgi_max_body_size`` configuration attribute. + +When using middleware the ``AsyncioWSGIMiddleware`` and +``TrioWSGIMiddleware`` have a default max size that can be configured, .. code-block:: python diff --git a/docs/index.rst b/docs/index.rst index 1973c4c..36cfdeb 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -17,14 +17,14 @@ and HTTP/2), ASGI/2, and ASGI/3 specifications. Hypercorn can utilise asyncio, uvloop, or trio worker types. Hypercorn was initially part of `Quart -`_ before being separated out into a +`_ before being separated out into a standalone ASGI server. Hypercorn forked from version 0.5.0 of Quart. -Hypercorn is developed on `GitLab -`_. You are very welcome to -open `issues `_ or -propose `merge requests -`_. +Hypercorn is developed on `Github +`_. You are very welcome to +open `issues `_ or +propose `pull requests +`_. Contents -------- diff --git a/docs/tutorials/installation.rst b/docs/tutorials/installation.rst index d052ca7..6a7b8db 100644 --- a/docs/tutorials/installation.rst +++ b/docs/tutorials/installation.rst @@ -3,24 +3,9 @@ Installation ============ -Hypercorn is only compatible with Python 3.7 or higher and can be -installed using pipenv or your favorite python package manager. +Hypercorn is only compatible with Python 3.8 or higher and can be +installed using pip or your favorite python package manager. .. code-block:: sh - pipenv install hypercorn - -It is sufficient to run this single command in your working directory. Besides -installing dependency, it will also create a Pipfile if one doesn't exist yet -along with a linked virtualenv. Now you'll be able to activate your virtualenv -using: - -.. code-block:: sh - - pipenv shell - -To learn more about it visit `pipenv docs -`_ - -If you do not have Python 3.7 or better an error message ``Python 3.7 -is the minimum required version`` will be displayed. + pip install hypercorn diff --git a/docs/tutorials/usage.rst b/docs/tutorials/usage.rst index 7e835ae..c6c0fc4 100644 --- a/docs/tutorials/usage.rst +++ b/docs/tutorials/usage.rst @@ -9,10 +9,13 @@ Hypercorn is invoked via the command line script ``hypercorn`` $ hypercorn [OPTIONS] MODULE_APP -with ``MODULE_APP`` has the pattern +where ``MODULE_APP`` has the pattern ``$(MODULE_NAME):$(VARIABLE_NAME)`` with the module name as a full (dotted) path to a python module containing a named variable that -conforms to the ASGI framework specification. +conforms to the ASGI or WSGI framework specifications. -See :ref:`how_to_configure` for the fill list of command line +The ``MODULE_APP`` can be prefixed with ``asgi:`` or ``wsgi:`` to +ensure that the loaded app is treated as either an asgi or wsgi app. + +See :ref:`how_to_configure` for the full list of command line arguments. diff --git a/pyproject.toml b/pyproject.toml index 7e353ac..7a6b6a8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "Hypercorn" -version = "0.13.2+dev" +version = "0.17.3" description = "A ASGI Server based on Hyper libraries and inspired by Gunicorn" authors = ["pgjones "] classifiers = [ @@ -11,31 +11,38 @@ classifiers = [ "Operating System :: OS Independent", "Programming Language :: Python", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", "Topic :: Internet :: WWW/HTTP :: Dynamic Content", "Topic :: Software Development :: Libraries :: Python Modules", ] include = ["src/hypercorn/py.typed"] license = "MIT" readme = "README.rst" -repository = "https://gitlab.com/pgjones/hypercorn/" -documentation = "https://pgjones.gitlab.io/hypercorn/" +repository = "https://github.com/pgjones/hypercorn/" +documentation = "https://hypercorn.readthedocs.io" [tool.poetry.dependencies] -python = ">=3.7" +python = ">=3.8" aioquic = { version = ">= 0.9.0, < 1.0", optional = true } +exceptiongroup = { version = ">= 1.1.0", python = "<3.11" } h11 = "*" h2 = ">=3.1.0" priority = "*" -toml = "*" -trio = { version = ">=0.11.0", optional = true } -typing_extensions = { version = ">=3.7.4", python = "<3.8" } -uvloop = { version = "*", markers = "platform_system != 'Windows'", optional = true } +pydata_sphinx_theme = { version = "*", optional = true } +sphinxcontrib_mermaid = { version = "*", optional = true } +taskgroup = { version = "*", python = "<3.11", allow-prereleases = true } +tomli = { version = "*", python = "<3.11" } +trio = { version = ">=0.22.0", optional = true } +typing_extensions = { version = "*", python = "<3.11" } +uvloop = { version = ">=0.18", markers = "platform_system != 'Windows'", optional = true } wsproto = ">=0.14.0" [tool.poetry.dev-dependencies] +httpx = "*" hypothesis = "*" mock = "*" pytest = "*" @@ -47,13 +54,14 @@ trio = "*" hypercorn = "hypercorn.__main__:main" [tool.poetry.extras] +docs = ["pydata_sphinx_theme", "sphinxcontrib_mermaid"] h3 = ["aioquic"] trio = ["trio"] uvloop = ["uvloop"] [tool.black] line-length = 100 -target-version = ["py37"] +target-version = ["py38"] [tool.isort] combine_as_imports = true @@ -84,11 +92,19 @@ warn_unused_configs = true warn_unused_ignores = true [[tool.mypy.overrides]] -module =["aioquic.*", "cryptography.*", "h11.*", "h2.*", "priority.*", "trio.*", "uvloop.*"] +module =["aioquic.*", "cryptography.*", "h11.*", "h2.*", "priority.*", "pytest_asyncio.*", "uvloop.*"] ignore_missing_imports = true +[[tool.mypy.overrides]] +module = ["trio.*", "tests.trio.*"] +disallow_any_generics = true +disallow_untyped_calls = true +strict_optional = true +warn_return_any = true + [tool.pytest.ini_options] addopts = "--no-cov-on-fail --showlocals --strict-markers" +asyncio_mode = "strict" testpaths = ["tests"] [build-system] diff --git a/setup.cfg b/setup.cfg index 0f81513..3423d8f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,5 @@ [flake8] ignore = E203, E252, FI58, W503, W504 max_line_length = 100 -min_version = 3.7 +min_version = 3.8 require_code = True diff --git a/src/hypercorn/__main__.py b/src/hypercorn/__main__.py index 6e19ee9..bcc59c8 100644 --- a/src/hypercorn/__main__.py +++ b/src/hypercorn/__main__.py @@ -23,11 +23,16 @@ def _load_config(config_path: Optional[str]) -> Config: return Config.from_toml(config_path) -def main(sys_args: Optional[List[str]] = None) -> None: +def main(sys_args: Optional[List[str]] = None) -> int: parser = argparse.ArgumentParser() parser.add_argument( "application", help="The application to dispatch to as path.to.module:instance.path" ) + parser.add_argument( + "--worker-type", + help="The worker type to use, process or thread, useful for free-threading python build", + default=sentinel, + ) parser.add_argument("--access-log", help="Deprecated, see access-logfile", default=sentinel) parser.add_argument( "--access-logfile", @@ -89,6 +94,19 @@ def main(sys_args: Optional[List[str]] = None) -> None: default=sentinel, type=int, ) + parser.add_argument( + "--max-requests", + help="""Maximum number of requests a worker will process before restarting""", + default=sentinel, + type=int, + ) + parser.add_argument( + "--max-requests-jitter", + help="This jitter causes the max-requests per worker to be " + "randomized by randint(0, max_requests_jitter)", + default=sentinel, + type=int, + ) parser.add_argument( "-g", "--group", help="Group to own any unix sockets.", default=sentinel, type=int ) @@ -121,10 +139,14 @@ def main(sys_args: Optional[List[str]] = None) -> None: action="append", ) parser.add_argument( - "--log-config", help="A Python logging configuration file.", default=sentinel + "--log-config", + help=""""A Python logging configuration file. This can be prefixed with + 'json:' or 'toml:' to load the configuration from a file in + that format. Default is the logging ini format.""", + default=sentinel, ) parser.add_argument( - "--log-level", help="The (error) log level, defaults to info", default="INFO" + "--log-level", help="The (error) log level, defaults to info", default=sentinel ) parser.add_argument( "-p", "--pid", help="Location to write the PID (Program ID) to.", default=sentinel @@ -201,8 +223,11 @@ def _convert_verify_mode(value: str) -> ssl.VerifyMode: args = parser.parse_args(sys_args or sys.argv[1:]) config = _load_config(args.config) config.application_path = args.application - config.loglevel = args.log_level - + + if args.worker_type is not sentinel: + config.worker_type = args.worker_type + if args.log_level is not sentinel: + config.loglevel = args.log_level if args.access_logformat is not sentinel: config.access_log_format = args.access_logformat if args.access_log is not sentinel: @@ -247,6 +272,10 @@ def _convert_verify_mode(value: str) -> ssl.VerifyMode: config.keyfile_password = args.keyfile_password if args.log_config is not sentinel: config.logconfig = args.log_config + if args.max_requests is not sentinel: + config.max_requests = args.max_requests + if args.max_requests_jitter is not sentinel: + config.max_requests_jitter = args.max_requests if args.pid is not sentinel: config.pid_path = args.pid if args.root_path is not sentinel: @@ -279,8 +308,8 @@ def _convert_verify_mode(value: str) -> ssl.VerifyMode: if len(args.server_names) > 0: config.server_names = args.server_names - run(config) + return run(config) if __name__ == "__main__": - main() + sys.exit(main()) diff --git a/src/hypercorn/app_wrappers.py b/src/hypercorn/app_wrappers.py new file mode 100644 index 0000000..56c1bfa --- /dev/null +++ b/src/hypercorn/app_wrappers.py @@ -0,0 +1,163 @@ +from __future__ import annotations + +import sys +from functools import partial +from io import BytesIO +from typing import Callable, List, Optional, Tuple + +from .typing import ( + ASGIFramework, + ASGIReceiveCallable, + ASGISendCallable, + HTTPScope, + Scope, + WSGIFramework, +) + + +class InvalidPathError(Exception): + pass + + +class ASGIWrapper: + def __init__(self, app: ASGIFramework) -> None: + self.app = app + + async def __call__( + self, + scope: Scope, + receive: ASGIReceiveCallable, + send: ASGISendCallable, + sync_spawn: Callable, + call_soon: Callable, + ) -> None: + await self.app(scope, receive, send) + + +class WSGIWrapper: + def __init__(self, app: WSGIFramework, max_body_size: int) -> None: + self.app = app + self.max_body_size = max_body_size + + async def __call__( + self, + scope: Scope, + receive: ASGIReceiveCallable, + send: ASGISendCallable, + sync_spawn: Callable, + call_soon: Callable, + ) -> None: + if scope["type"] == "http": + await self.handle_http(scope, receive, send, sync_spawn, call_soon) + elif scope["type"] == "websocket": + await send({"type": "websocket.close"}) # type: ignore + elif scope["type"] == "lifespan": + return + else: + raise Exception(f"Unknown scope type, {scope['type']}") + + async def handle_http( + self, + scope: HTTPScope, + receive: ASGIReceiveCallable, + send: ASGISendCallable, + sync_spawn: Callable, + call_soon: Callable, + ) -> None: + body = bytearray() + while True: + message = await receive() + body.extend(message.get("body", b"")) # type: ignore + if len(body) > self.max_body_size: + await send({"type": "http.response.start", "status": 400, "headers": []}) + await send({"type": "http.response.body", "body": b"", "more_body": False}) + return + if not message.get("more_body"): + break + + try: + environ = _build_environ(scope, body) + except InvalidPathError: + await send({"type": "http.response.start", "status": 404, "headers": []}) + else: + await sync_spawn(self.run_app, environ, partial(call_soon, send)) + await send({"type": "http.response.body", "body": b"", "more_body": False}) + + def run_app(self, environ: dict, send: Callable) -> None: + headers: List[Tuple[bytes, bytes]] + response_started = False + status_code: Optional[int] = None + + def start_response( + status: str, + response_headers: List[Tuple[str, str]], + exc_info: Optional[Exception] = None, + ) -> None: + nonlocal headers, response_started, status_code + + raw, _ = status.split(" ", 1) + status_code = int(raw) + headers = [ + (name.lower().encode("latin-1"), value.encode("latin-1")) + for name, value in response_headers + ] + response_started = True + + response_body = self.app(environ, start_response) + + if not response_started: + raise RuntimeError("WSGI app did not call start_response") + + send({"type": "http.response.start", "status": status_code, "headers": headers}) + try: + for output in response_body: + send({"type": "http.response.body", "body": output, "more_body": True}) + finally: + if hasattr(response_body, "close"): + response_body.close() + + +def _build_environ(scope: HTTPScope, body: bytes) -> dict: + server = scope.get("server") or ("localhost", 80) + path = scope["path"] + script_name = scope.get("root_path", "") + if path.startswith(script_name): + path = path[len(script_name) :] + path = path if path != "" else "/" + else: + raise InvalidPathError() + + environ = { + "REQUEST_METHOD": scope["method"], + "SCRIPT_NAME": script_name.encode("utf8").decode("latin1"), + "PATH_INFO": path.encode("utf8").decode("latin1"), + "QUERY_STRING": scope["query_string"].decode("ascii"), + "SERVER_NAME": server[0], + "SERVER_PORT": server[1], + "SERVER_PROTOCOL": "HTTP/%s" % scope["http_version"], + "wsgi.version": (1, 0), + "wsgi.url_scheme": scope.get("scheme", "http"), + "wsgi.input": BytesIO(body), + "wsgi.errors": sys.stdout, + "wsgi.multithread": True, + "wsgi.multiprocess": True, + "wsgi.run_once": False, + } + + if scope.get("client") is not None: + environ["REMOTE_ADDR"] = scope["client"][0] + + for raw_name, raw_value in scope.get("headers", []): + name = raw_name.decode("latin1") + if name == "content-length": + corrected_name = "CONTENT_LENGTH" + elif name == "content-type": + corrected_name = "CONTENT_TYPE" + else: + corrected_name = "HTTP_%s" % name.upper().replace("-", "_") + # HTTPbis say only ASCII chars are allowed in headers, but we latin1 just in case + value = raw_value.decode("latin1") + if corrected_name in environ: + value = environ[corrected_name] + "," + value # type: ignore + environ[corrected_name] = value + return environ diff --git a/src/hypercorn/asyncio/__init__.py b/src/hypercorn/asyncio/__init__.py index 91035c7..3755da0 100644 --- a/src/hypercorn/asyncio/__init__.py +++ b/src/hypercorn/asyncio/__init__.py @@ -1,23 +1,25 @@ from __future__ import annotations import warnings -from typing import Awaitable, Callable, Optional +from typing import Awaitable, Callable, Literal, Optional from .run import worker_serve from ..config import Config -from ..typing import ASGIFramework +from ..typing import Framework +from ..utils import wrap_app async def serve( - app: ASGIFramework, + app: Framework, config: Config, *, - shutdown_trigger: Optional[Callable[..., Awaitable[None]]] = None, + shutdown_trigger: Optional[Callable[..., Awaitable]] = None, + mode: Optional[Literal["asgi", "wsgi"]] = None, ) -> None: - """Serve an ASGI framework app given the config. + """Serve an ASGI or WSGI framework app given the config. - This allows for a programmatic way to serve an ASGI framework, it - can be used via, + This allows for a programmatic way to serve an ASGI or WSGI + framework, it can be used via, .. code-block:: python @@ -28,14 +30,17 @@ async def serve( setup or process setup are ignored. Arguments: - app: The ASGI application to serve. + app: The ASGI or WSGI application to serve. config: A Hypercorn configuration object. shutdown_trigger: This should return to trigger a graceful shutdown. + mode: Specify if the app is WSGI or ASGI. """ if config.debug: warnings.warn("The config `debug` has no affect when using serve", Warning) if config.workers != 1: warnings.warn("The config `workers` has no affect when using serve", Warning) - await worker_serve(app, config, shutdown_trigger=shutdown_trigger) + await worker_serve( + wrap_app(app, config.wsgi_max_body_size, mode), config, shutdown_trigger=shutdown_trigger + ) diff --git a/src/hypercorn/asyncio/lifespan.py b/src/hypercorn/asyncio/lifespan.py index f21b762..3980345 100644 --- a/src/hypercorn/asyncio/lifespan.py +++ b/src/hypercorn/asyncio/lifespan.py @@ -1,10 +1,16 @@ from __future__ import annotations import asyncio +import sys +from functools import partial +from typing import Any, Callable from ..config import Config -from ..typing import ASGIFramework, ASGIReceiveEvent, ASGISendEvent, LifespanScope -from ..utils import invoke_asgi, LifespanFailureError, LifespanTimeoutError +from ..typing import AppWrapper, ASGIReceiveEvent, ASGISendEvent, LifespanScope, LifespanState +from ..utils import LifespanFailureError, LifespanTimeoutError + +if sys.version_info < (3, 11): + from exceptiongroup import BaseExceptionGroup class UnexpectedMessageError(Exception): @@ -12,13 +18,21 @@ class UnexpectedMessageError(Exception): class Lifespan: - def __init__(self, app: ASGIFramework, config: Config) -> None: + def __init__( + self, + app: AppWrapper, + config: Config, + loop: asyncio.AbstractEventLoop, + lifespan_state: LifespanState, + ) -> None: self.app = app self.config = config self.startup = asyncio.Event() self.shutdown = asyncio.Event() self.app_queue: asyncio.Queue = asyncio.Queue(config.max_app_queue_size) self.supported = True + self.loop = loop + self.state = lifespan_state # This mimics the Trio nursery.start task_status and is # required to ensure the support has been checked before @@ -27,13 +41,32 @@ def __init__(self, app: ASGIFramework, config: Config) -> None: async def handle_lifespan(self) -> None: self._started.set() - scope: LifespanScope = {"type": "lifespan", "asgi": {"spec_version": "2.0"}} + scope: LifespanScope = { + "type": "lifespan", + "asgi": {"spec_version": "2.0", "version": "3.0"}, + "state": self.state, + } + + def _call_soon(func: Callable, *args: Any) -> Any: + future = asyncio.run_coroutine_threadsafe(func(*args), self.loop) + return future.result() + try: - await invoke_asgi(self.app, scope, self.asgi_receive, self.asgi_send) - except LifespanFailureError: - # Lifespan failures should crash the server + await self.app( + scope, + self.asgi_receive, + self.asgi_send, + partial(self.loop.run_in_executor, None), + _call_soon, + ) + except (LifespanFailureError, asyncio.CancelledError): raise - except Exception: + except (BaseExceptionGroup, Exception) as error: + if isinstance(error, BaseExceptionGroup): + reraise_error = error.subgroup((LifespanFailureError, asyncio.CancelledError)) + if reraise_error is not None: + raise reraise_error + self.supported = False if not self.startup.is_set(): await self.config.log.warning( @@ -81,9 +114,9 @@ async def asgi_send(self, message: ASGISendEvent) -> None: self.shutdown.set() elif message["type"] == "lifespan.startup.failed": self.startup.set() - raise LifespanFailureError("startup", message["message"]) + raise LifespanFailureError("startup", message.get("message", "")) elif message["type"] == "lifespan.shutdown.failed": self.shutdown.set() - raise LifespanFailureError("shutdown", message["message"]) + raise LifespanFailureError("shutdown", message.get("message", "")) else: raise UnexpectedMessageError(message["type"]) diff --git a/src/hypercorn/asyncio/run.py b/src/hypercorn/asyncio/run.py index cf5004e..be4a22a 100644 --- a/src/hypercorn/asyncio/run.py +++ b/src/hypercorn/asyncio/run.py @@ -4,12 +4,13 @@ import platform import signal import ssl +import sys from functools import partial from multiprocessing.synchronize import Event as EventType from os import getpid +from random import randint from socket import socket -from typing import Any, Awaitable, Callable, Optional -from weakref import WeakSet +from typing import Any, Awaitable, Callable, Optional, Set from .lifespan import Lifespan from .statsd import StatsdLogger @@ -17,25 +18,27 @@ from .udp_server import UDPServer from .worker_context import WorkerContext from ..config import Config, Sockets -from ..typing import ASGIFramework +from ..typing import AppWrapper, LifespanState from ..utils import ( check_multiprocess_shutdown_event, load_application, - MustReloadError, - observe_changes, raise_shutdown, repr_socket_addr, - restart, ShutdownError, ) +try: + from asyncio import Runner +except ImportError: + from taskgroup import Runner # type: ignore -async def _windows_signal_support() -> None: - # See https://bugs.python.org/issue23057, to catch signals on - # Windows it is necessary for an IO event to happen periodically. - # Fixed by Python 3.8 - while True: - await asyncio.sleep(1) +try: + from asyncio import TaskGroup +except ImportError: + from taskgroup import TaskGroup # type: ignore + +if sys.version_info < (3, 11): + from exceptiongroup import BaseExceptionGroup def _share_socket(sock: socket) -> socket: @@ -48,11 +51,11 @@ def _share_socket(sock: socket) -> socket: async def worker_serve( - app: ASGIFramework, + app: AppWrapper, config: Config, *, sockets: Optional[Sockets] = None, - shutdown_trigger: Optional[Callable[..., Awaitable[None]]] = None, + shutdown_trigger: Optional[Callable[..., Awaitable]] = None, ) -> None: config.set_statsd_logger_class(StatsdLogger) @@ -72,10 +75,10 @@ def _signal_handler(*_: Any) -> None: # noqa: N803 # Add signal handler may not be implemented on Windows signal.signal(getattr(signal, signal_name), _signal_handler) - shutdown_trigger = signal_event.wait # type: ignore + shutdown_trigger = signal_event.wait - lifespan = Lifespan(app, config) - reload_ = False + lifespan_state: LifespanState = {} + lifespan = Lifespan(app, config, loop, lifespan_state) lifespan_task = loop.create_task(lifespan.handle_lifespan()) await lifespan.wait_for_startup() @@ -92,16 +95,23 @@ def _signal_handler(*_: Any) -> None: # noqa: N803 ssl_context = config.create_ssl_context() ssl_handshake_timeout = config.ssl_handshake_timeout - context = WorkerContext() - server_tasks: WeakSet = WeakSet() + max_requests = None + if config.max_requests is not None: + max_requests = config.max_requests + randint(0, config.max_requests_jitter) + context = WorkerContext(max_requests) + server_tasks: Set[asyncio.Task] = set() async def _server_callback(reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None: - server_tasks.add(asyncio.current_task(loop)) - await TCPServer(app, loop, config, context, reader, writer) + nonlocal server_tasks + + task = asyncio.current_task(loop) + server_tasks.add(task) + task.add_done_callback(server_tasks.discard) + await TCPServer(app, loop, config, context, lifespan_state, reader, writer) servers = [] for sock in sockets.secure_sockets: - if config.workers > 1 and platform.system() == "Windows": + if config.workers > 1 and platform.system() == "Windows" and config.worker_class == "process": sock = _share_socket(sock) servers.append( @@ -117,7 +127,7 @@ async def _server_callback(reader: asyncio.StreamReader, writer: asyncio.StreamW await config.log.info(f"Running on https://{bind} (CTRL + C to quit)") for sock in sockets.insecure_sockets: - if config.workers > 1 and platform.system() == "Windows": + if config.workers > 1 and platform.system() == "Windows" and config.worker_class == "process": sock = _share_socket(sock) servers.append( @@ -127,68 +137,54 @@ async def _server_callback(reader: asyncio.StreamReader, writer: asyncio.StreamW await config.log.info(f"Running on http://{bind} (CTRL + C to quit)") for sock in sockets.quic_sockets: - if config.workers > 1 and platform.system() == "Windows": + if config.workers > 1 and platform.system() == "Windows" and config.worker_class == "process": sock = _share_socket(sock) _, protocol = await loop.create_datagram_endpoint( - lambda: UDPServer(app, loop, config, context), sock=sock + lambda: UDPServer(app, loop, config, context, lifespan_state), sock=sock ) - server_tasks.add(loop.create_task(protocol.run())) # type: ignore + task = loop.create_task(protocol.run()) + server_tasks.add(task) + task.add_done_callback(server_tasks.discard) bind = repr_socket_addr(sock.family, sock.getsockname()) await config.log.info(f"Running on https://{bind} (QUIC) (CTRL + C to quit)") - tasks = [] - if platform.system() == "Windows": - tasks.append(loop.create_task(_windows_signal_support())) - - tasks.append(loop.create_task(raise_shutdown(shutdown_trigger))) - - if config.use_reloader: - tasks.append(loop.create_task(observe_changes(asyncio.sleep))) - try: - if len(tasks): - gathered_tasks = asyncio.gather(*tasks) - await gathered_tasks - else: - loop.run_forever() - except MustReloadError: - reload_ = True + async with TaskGroup() as task_group: + task_group.create_task(raise_shutdown(shutdown_trigger)) + task_group.create_task(raise_shutdown(context.terminate.wait)) + except BaseExceptionGroup as error: + _, other_errors = error.split((ShutdownError, KeyboardInterrupt)) + if other_errors is not None: + raise other_errors except (ShutdownError, KeyboardInterrupt): pass finally: - context.terminated = True + await context.terminated.set() for server in servers: server.close() await server.wait_closed() - # Retrieve the Gathered Tasks Cancelled Exception, to - # prevent a warning that this hasn't been done. - gathered_tasks.exception() - try: gathered_server_tasks = asyncio.gather(*server_tasks) await asyncio.wait_for(gathered_server_tasks, config.graceful_timeout) except asyncio.TimeoutError: pass + finally: + # Retrieve the Gathered Tasks Cancelled Exception, to + # prevent a warning that this hasn't been done. + gathered_server_tasks.exception() - # Retrieve the Gathered Tasks Cancelled Exception, to - # prevent a warning that this hasn't been done. - gathered_server_tasks.exception() - - await lifespan.wait_for_shutdown() - lifespan_task.cancel() - await lifespan_task - - if reload_: - restart() + await lifespan.wait_for_shutdown() + lifespan_task.cancel() + await lifespan_task def asyncio_worker( config: Config, sockets: Optional[Sockets] = None, shutdown_event: Optional[EventType] = None ) -> None: - app = load_application(config.application_path) + app = load_application(config.application_path, config.wsgi_max_body_size) shutdown_trigger = None if shutdown_event is not None: @@ -211,10 +207,8 @@ def uvloop_worker( import uvloop except ImportError as error: raise Exception("uvloop is not installed") from error - else: - asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) - app = load_application(config.application_path) + app = load_application(config.application_path, config.wsgi_max_body_size) shutdown_trigger = None if shutdown_event is not None: @@ -224,6 +218,7 @@ def uvloop_worker( partial(worker_serve, app, config, sockets=sockets), debug=config.debug, shutdown_trigger=shutdown_trigger, + loop_factory=uvloop.new_event_loop, ) @@ -232,49 +227,11 @@ def _run( *, debug: bool = False, shutdown_trigger: Optional[Callable[..., Awaitable[None]]] = None, + loop_factory: Callable[[], asyncio.AbstractEventLoop] | None = None, ) -> None: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - loop.set_debug(debug) - loop.set_exception_handler(_exception_handler) - - try: - loop.run_until_complete(main(shutdown_trigger=shutdown_trigger)) - except KeyboardInterrupt: - pass - finally: - try: - _cancel_all_tasks(loop) - loop.run_until_complete(loop.shutdown_asyncgens()) - - try: - loop.run_until_complete(loop.shutdown_default_executor()) - except AttributeError: - pass # shutdown_default_executor is new to Python 3.9 - - finally: - asyncio.set_event_loop(None) - loop.close() - - -def _cancel_all_tasks(loop: asyncio.AbstractEventLoop) -> None: - tasks = [task for task in asyncio.all_tasks(loop) if not task.done()] - if not tasks: - return - - for task in tasks: - task.cancel() - loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True)) - - for task in tasks: - if not task.cancelled() and task.exception() is not None: - loop.call_exception_handler( - { - "message": "unhandled exception during shutdown", - "exception": task.exception(), - "task": task, - } - ) + with Runner(debug=debug, loop_factory=loop_factory) as runner: + runner.get_loop().set_exception_handler(_exception_handler) + runner.run(main(shutdown_trigger=shutdown_trigger)) def _exception_handler(loop: asyncio.AbstractEventLoop, context: dict) -> None: diff --git a/src/hypercorn/asyncio/task_group.py b/src/hypercorn/asyncio/task_group.py index 42867fe..2e58903 100644 --- a/src/hypercorn/asyncio/task_group.py +++ b/src/hypercorn/asyncio/task_group.py @@ -1,24 +1,30 @@ from __future__ import annotations import asyncio -import weakref +from functools import partial from types import TracebackType from typing import Any, Awaitable, Callable, Optional from ..config import Config -from ..typing import ASGIFramework, ASGIReceiveCallable, ASGIReceiveEvent, ASGISendEvent, Scope -from ..utils import invoke_asgi +from ..typing import AppWrapper, ASGIReceiveCallable, ASGIReceiveEvent, ASGISendEvent, Scope + +try: + from asyncio import TaskGroup as AsyncioTaskGroup +except ImportError: + from taskgroup import TaskGroup as AsyncioTaskGroup # type: ignore async def _handle( - app: ASGIFramework, + app: AppWrapper, config: Config, scope: Scope, receive: ASGIReceiveCallable, send: Callable[[Optional[ASGISendEvent]], Awaitable[None]], + sync_spawn: Callable, + call_soon: Callable, ) -> None: try: - await invoke_asgi(app, scope, receive, send) + await app(scope, receive, send, sync_spawn, call_soon) except asyncio.CancelledError: raise except Exception: @@ -30,43 +36,39 @@ async def _handle( class TaskGroup: def __init__(self, loop: asyncio.AbstractEventLoop) -> None: self._loop = loop - self._tasks: weakref.WeakSet = weakref.WeakSet() - self._exiting = False + self._task_group = AsyncioTaskGroup() async def spawn_app( self, - app: ASGIFramework, + app: AppWrapper, config: Config, scope: Scope, send: Callable[[Optional[ASGISendEvent]], Awaitable[None]], ) -> Callable[[ASGIReceiveEvent], Awaitable[None]]: app_queue: asyncio.Queue[ASGIReceiveEvent] = asyncio.Queue(config.max_app_queue_size) - self.spawn(_handle, app, config, scope, app_queue.get, send) + + def _call_soon(func: Callable, *args: Any) -> Any: + future = asyncio.run_coroutine_threadsafe(func(*args), self._loop) + return future.result() + + self.spawn( + _handle, + app, + config, + scope, + app_queue.get, + send, + partial(self._loop.run_in_executor, None), + _call_soon, + ) return app_queue.put def spawn(self, func: Callable, *args: Any) -> None: - if self._exiting: - raise RuntimeError("Spawning whilst exiting") - self._tasks.add(self._loop.create_task(func(*args))) + self._task_group.create_task(func(*args)) async def __aenter__(self) -> "TaskGroup": + await self._task_group.__aenter__() return self async def __aexit__(self, exc_type: type, exc_value: BaseException, tb: TracebackType) -> None: - self._exiting = True - if exc_type is not None: - self._cancel_tasks() - - try: - task = asyncio.gather(*self._tasks) - await task - finally: - task.cancel() - try: - await task - except asyncio.CancelledError: - pass - - def _cancel_tasks(self) -> None: - for task in self._tasks: - task.cancel() + await self._task_group.__aexit__(exc_type, exc_value, tb) diff --git a/src/hypercorn/asyncio/tcp_server.py b/src/hypercorn/asyncio/tcp_server.py index b143a29..bf9d9fe 100644 --- a/src/hypercorn/asyncio/tcp_server.py +++ b/src/hypercorn/asyncio/tcp_server.py @@ -2,40 +2,27 @@ import asyncio from ssl import SSLError -from typing import Any, Callable, Generator, Optional +from typing import Any, Generator from .task_group import TaskGroup -from .worker_context import WorkerContext +from .worker_context import AsyncioSingleTask, WorkerContext from ..config import Config from ..events import Closed, Event, RawData, Updated from ..protocol import ProtocolWrapper -from ..typing import ASGIFramework +from ..typing import AppWrapper, ConnectionState, LifespanState from ..utils import parse_socket_addr -MAX_RECV = 2 ** 16 - - -class EventWrapper: - def __init__(self) -> None: - self._event = asyncio.Event() - - async def clear(self) -> None: - self._event.clear() - - async def wait(self) -> None: - await self._event.wait() - - async def set(self) -> None: - self._event.set() +MAX_RECV = 2**16 class TCPServer: def __init__( self, - app: ASGIFramework, + app: AppWrapper, loop: asyncio.AbstractEventLoop, config: Config, context: WorkerContext, + state: LifespanState, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, ) -> None: @@ -47,9 +34,8 @@ def __init__( self.reader = reader self.writer = writer self.send_lock = asyncio.Lock() - self.timeout_lock = asyncio.Lock() - - self._keep_alive_timeout_handle: Optional[asyncio.Task] = None + self.state = state + self.idle_task = AsyncioSingleTask() def __await__(self) -> Generator[Any, None, None]: return self.run().__await__() @@ -68,11 +54,13 @@ async def run(self) -> None: alpn_protocol = "http/1.1" async with TaskGroup(self.loop) as task_group: + self._task_group = task_group self.protocol = ProtocolWrapper( self.app, self.config, self.context, task_group, + ConnectionState(self.state.copy()), ssl, client, server, @@ -80,7 +68,7 @@ async def run(self) -> None: alpn_protocol, ) await self.protocol.initiate() - await self._start_keep_alive_timeout() + await self.idle_task.restart(task_group, self._idle_timeout) await self._read_data() except OSError: pass @@ -97,12 +85,11 @@ async def protocol_send(self, event: Event) -> None: await self.protocol.handle(Closed()) elif isinstance(event, Closed): await self._close() - await self.protocol.handle(Closed()) elif isinstance(event, Updated): if event.idle: - await self._start_keep_alive_timeout() + await self.idle_task.restart(self._task_group, self._idle_timeout) else: - await self._stop_keep_alive_timeout() + await self.idle_task.stop() async def _read_data(self) -> None: while not self.reader.at_eof(): @@ -115,11 +102,12 @@ async def _read_data(self) -> None: TimeoutError, SSLError, ): - await self.protocol.handle(Closed()) break else: await self.protocol.handle(RawData(data)) + await self.protocol.handle(Closed()) + async def _close(self) -> None: try: self.writer.write_eof() @@ -129,32 +117,24 @@ async def _close(self) -> None: try: self.writer.close() await self.writer.wait_closed() - except (BrokenPipeError, ConnectionResetError, RuntimeError): + except ( + BrokenPipeError, + ConnectionAbortedError, + ConnectionResetError, + RuntimeError, + asyncio.CancelledError, + ): pass # Already closed + finally: + await self.idle_task.stop() - await self._stop_keep_alive_timeout() - - async def _start_keep_alive_timeout(self) -> None: - async with self.timeout_lock: - if self._keep_alive_timeout_handle is None: - self._keep_alive_timeout_handle = self.loop.create_task( - _call_later(self.config.keep_alive_timeout, self._timeout) - ) - - async def _timeout(self) -> None: + async def _initiate_server_close(self) -> None: await self.protocol.handle(Closed()) self.writer.close() - async def _stop_keep_alive_timeout(self) -> None: - async with self.timeout_lock: - if self._keep_alive_timeout_handle is not None: - self._keep_alive_timeout_handle.cancel() - try: - await self._keep_alive_timeout_handle - except asyncio.CancelledError: - pass - - -async def _call_later(timeout: float, callback: Callable) -> None: - await asyncio.sleep(timeout) - await asyncio.shield(callback()) + async def _idle_timeout(self) -> None: + try: + await asyncio.wait_for(self.context.terminated.wait(), self.config.keep_alive_timeout) + except asyncio.TimeoutError: + pass + await asyncio.shield(self._initiate_server_close()) diff --git a/src/hypercorn/asyncio/udp_server.py b/src/hypercorn/asyncio/udp_server.py index 02329ec..32857cc 100644 --- a/src/hypercorn/asyncio/udp_server.py +++ b/src/hypercorn/asyncio/udp_server.py @@ -7,7 +7,7 @@ from .worker_context import WorkerContext from ..config import Config from ..events import Event, RawData -from ..typing import ASGIFramework +from ..typing import AppWrapper, ConnectionState, LifespanState from ..utils import parse_socket_addr if TYPE_CHECKING: @@ -18,10 +18,11 @@ class UDPServer(asyncio.DatagramProtocol): def __init__( self, - app: ASGIFramework, + app: AppWrapper, loop: asyncio.AbstractEventLoop, config: Config, context: WorkerContext, + state: LifespanState, ) -> None: self.app = app self.config = config @@ -30,6 +31,7 @@ def __init__( self.protocol: "QuicProtocol" self.protocol_queue: asyncio.Queue = asyncio.Queue(10) self.transport: Optional[asyncio.DatagramTransport] = None + self.state = state def connection_made(self, transport: asyncio.DatagramTransport) -> None: # type: ignore self.transport = transport @@ -48,10 +50,16 @@ async def run(self) -> None: server = parse_socket_addr(socket.family, socket.getsockname()) async with TaskGroup(self.loop) as task_group: self.protocol = QuicProtocol( - self.app, self.config, self.context, task_group, server, self.protocol_send + self.app, + self.config, + self.context, + task_group, + ConnectionState(self.state.copy()), + server, + self.protocol_send, ) - while not self.context.terminated or not self.protocol.idle: + while not self.context.terminated.is_set() or not self.protocol.idle: event = await self.protocol_queue.get() await self.protocol.handle(event) diff --git a/src/hypercorn/asyncio/worker_context.py b/src/hypercorn/asyncio/worker_context.py index fe3fd1b..31e9877 100644 --- a/src/hypercorn/asyncio/worker_context.py +++ b/src/hypercorn/asyncio/worker_context.py @@ -1,9 +1,37 @@ from __future__ import annotations import asyncio -from typing import Type, Union +from typing import Callable, Optional, Type, Union -from ..typing import Event +from ..typing import Event, SingleTask, TaskGroup + + +class AsyncioSingleTask: + def __init__(self) -> None: + self._handle: Optional[asyncio.Task] = None + self._lock = asyncio.Lock() + + async def restart(self, task_group: TaskGroup, action: Callable) -> None: + async with self._lock: + if self._handle is not None: + self._handle.cancel() + try: + await self._handle + except asyncio.CancelledError: + pass + + self._handle = task_group._task_group.create_task(action()) # type: ignore + + async def stop(self) -> None: + async with self._lock: + if self._handle is not None: + self._handle.cancel() + try: + await self._handle + except asyncio.CancelledError: + pass + + self._handle = None class EventWrapper: @@ -19,12 +47,27 @@ async def wait(self) -> None: async def set(self) -> None: self._event.set() + def is_set(self) -> bool: + return self._event.is_set() + class WorkerContext: event_class: Type[Event] = EventWrapper + single_task_class: Type[SingleTask] = AsyncioSingleTask - def __init__(self) -> None: - self.terminated = False + def __init__(self, max_requests: Optional[int]) -> None: + self.max_requests = max_requests + self.requests = 0 + self.terminate = self.event_class() + self.terminated = self.event_class() + + async def mark_request(self) -> None: + if self.max_requests is None: + return + + self.requests += 1 + if self.requests > self.max_requests: + await self.terminate.set() @staticmethod async def sleep(wait: Union[float, int]) -> None: diff --git a/src/hypercorn/config.py b/src/hypercorn/config.py index dd2224b..79d3a5f 100644 --- a/src/hypercorn/config.py +++ b/src/hypercorn/config.py @@ -6,6 +6,7 @@ import os import socket import stat +import sys import types import warnings from dataclasses import dataclass @@ -19,10 +20,13 @@ VerifyMode, ) from time import time -from typing import Any, AnyStr, Dict, List, Mapping, Optional, Tuple, Type, Union +from typing import Any, AnyStr, Dict, List, Mapping, Optional, Tuple, Type, Union, Literal from wsgiref.handlers import format_date_time -import toml +if sys.version_info >= (3, 11): + import tomllib +else: + import tomli as tomllib from .logging import Logger @@ -56,7 +60,8 @@ class Config: _quic_addresses: List[Tuple] = [] _log: Optional[Logger] = None _root_path: str = "" - + + worker_type: Literal["thread", "process"] = "process" access_log_format = '%(h)s %(l)s %(l)s %(t)s "%(r)s" %(s)s %(b)s "%(f)s" "%(a)s"' accesslog: Union[logging.Logger, str, None] = None alpn_protocols = ["h2", "http/1.1"] @@ -73,11 +78,14 @@ class Config: read_timeout: Optional[int] = None group: Optional[int] = None h11_max_incomplete_size = 16 * 1024 * BYTES + h11_pass_raw_headers = False h2_max_concurrent_streams = 100 - h2_max_header_list_size = 2 ** 16 - h2_max_inbound_frame_size = 2 ** 14 * OCTETS + h2_max_header_list_size = 2**16 + h2_max_inbound_frame_size = 2**14 * OCTETS + include_date_header = True include_server_header = True keep_alive_timeout = 5 * SECONDS + keep_alive_max_requests = 1000 keyfile: Optional[str] = None keyfile_password: Optional[str] = None logconfig: Optional[str] = None @@ -85,6 +93,8 @@ class Config: logger_class = Logger loglevel: str = "INFO" max_app_queue_size: int = 10 + max_requests: Optional[int] = None + max_requests_jitter: int = 0 pid_path: Optional[str] = None server_names: List[str] = [] shutdown_timeout = 60 * SECONDS @@ -101,6 +111,7 @@ class Config: websocket_ping_interval: Optional[float] = None worker_class = "asyncio" workers = 1 + wsgi_max_body_size = 16 * 1024 * 1024 * BYTES def set_cert_reqs(self, value: int) -> None: warnings.warn("Please use verify_mode instead", Warning) @@ -236,6 +247,10 @@ def _create_sockets( except (ValueError, IndexError): host, port = bind, 8000 sock = socket.socket(socket.AF_INET6 if ":" in host else socket.AF_INET, type_) + + if type_ == socket.SOCK_STREAM: + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + if self.workers > 1: try: sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) @@ -267,7 +282,9 @@ def _create_sockets( return sockets def response_headers(self, protocol: str) -> List[Tuple[bytes, bytes]]: - headers = [(b"date", format_date_time(time()).encode("ascii"))] + headers = [] + if self.include_date_header: + headers.append((b"date", format_date_time(time()).encode("ascii"))) if self.include_server_header: headers.append((b"server", f"hypercorn-{protocol}".encode("ascii"))) @@ -351,8 +368,8 @@ def from_toml(cls: Type["Config"], filename: FilePath) -> "Config": filename: The filename which gives the path to the file. """ file_path = os.fspath(filename) - with open(file_path) as file_: - data = toml.load(file_) + with open(file_path, "rb") as file_: + data = tomllib.load(file_) return cls.from_mapping(data) @classmethod @@ -387,6 +404,6 @@ def from_object(cls: Type["Config"], instance: Union[object, str]) -> "Config": mapping = { key: getattr(instance, key) for key in dir(instance) - if not isinstance(getattr(instance, key), types.ModuleType) + if not isinstance(getattr(instance, key), types.ModuleType) and not key.startswith("__") } return cls.from_mapping(mapping) diff --git a/src/hypercorn/logging.py b/src/hypercorn/logging.py index e583f96..d9b8901 100644 --- a/src/hypercorn/logging.py +++ b/src/hypercorn/logging.py @@ -1,5 +1,6 @@ from __future__ import annotations +import json import logging import os import sys @@ -8,6 +9,12 @@ from logging.config import dictConfig, fileConfig from typing import Any, IO, Mapping, Optional, TYPE_CHECKING, Union +if sys.version_info >= (3, 11): + import tomllib +else: + import tomli as tomllib + + if TYPE_CHECKING: from .config import Config from .typing import ResponseSummary, WWWScope @@ -58,11 +65,18 @@ def __init__(self, config: "Config") -> None: ) if config.logconfig is not None: - log_config = { - "__file__": config.logconfig, - "here": os.path.dirname(config.logconfig), - } - fileConfig(config.logconfig, defaults=log_config, disable_existing_loggers=False) + if config.logconfig.startswith("json:"): + with open(config.logconfig[5:]) as file_: + dictConfig(json.load(file_)) + elif config.logconfig.startswith("toml:"): + with open(config.logconfig[5:], "rb") as file_: + dictConfig(tomllib.load(file_)) + else: + log_config = { + "__file__": config.logconfig, + "here": os.path.dirname(config.logconfig), + } + fileConfig(config.logconfig, defaults=log_config, disable_existing_loggers=False) else: if config.logconfig_dict is not None: dictConfig(config.logconfig_dict) @@ -104,7 +118,7 @@ async def log(self, level: int, message: str, *args: Any, **kwargs: Any) -> None self.error_logger.log(level, message, *args, **kwargs) def atoms( - self, request: "WWWScope", response: "ResponseSummary", request_time: float + self, request: "WWWScope", response: Optional["ResponseSummary"], request_time: float ) -> Mapping[str, str]: """Create and return an access log atoms dictionary. @@ -119,12 +133,10 @@ def __getattr__(self, name: str) -> Any: class AccessLogAtoms(dict): def __init__( - self, request: "WWWScope", response: "ResponseSummary", request_time: float + self, request: "WWWScope", response: Optional["ResponseSummary"], request_time: float ) -> None: for name, value in request["headers"]: self[f"{{{name.decode('latin1').lower()}}}i"] = value.decode("latin1") - for name, value in response.get("headers", []): - self[f"{{{name.decode('latin1').lower()}}}o"] = value.decode("latin1") for name, value in os.environ.items(): self[f"{{{name.lower()}}}e"] = value protocol = request.get("http_version", "ws") @@ -143,11 +155,17 @@ def __init__( method = "GET" query_string = request["query_string"].decode() path_with_qs = request["path"] + ("?" + query_string if query_string else "") - status_code = response["status"] - try: - status_phrase = HTTPStatus(status_code).phrase - except ValueError: - status_phrase = f"" + + status_code = "-" + status_phrase = "-" + if response is not None: + for name, value in response.get("headers", []): # type: ignore + self[f"{{{name.decode('latin1').lower()}}}o"] = value.decode("latin1") # type: ignore # noqa: E501 + status_code = str(response["status"]) + try: + status_phrase = HTTPStatus(response["status"]).phrase + except ValueError: + status_phrase = f"" self.update( { "h": remote_addr, @@ -155,7 +173,7 @@ def __init__( "t": time.strftime("[%d/%b/%Y:%H:%M:%S %z]"), "r": f"{method} {request['path']} {protocol}", "R": f"{method} {path_with_qs} {protocol}", - "s": response["status"], + "s": status_code, "st": status_phrase, "S": request["scheme"], "m": method, diff --git a/src/hypercorn/middleware/__init__.py b/src/hypercorn/middleware/__init__.py index 83ea29c..e7f017c 100644 --- a/src/hypercorn/middleware/__init__.py +++ b/src/hypercorn/middleware/__init__.py @@ -2,11 +2,13 @@ from .dispatcher import DispatcherMiddleware from .http_to_https import HTTPToHTTPSRedirectMiddleware +from .proxy_fix import ProxyFixMiddleware from .wsgi import AsyncioWSGIMiddleware, TrioWSGIMiddleware __all__ = ( "AsyncioWSGIMiddleware", "DispatcherMiddleware", "HTTPToHTTPSRedirectMiddleware", + "ProxyFixMiddleware", "TrioWSGIMiddleware", ) diff --git a/src/hypercorn/middleware/dispatcher.py b/src/hypercorn/middleware/dispatcher.py index 40832b3..abe0e7e 100644 --- a/src/hypercorn/middleware/dispatcher.py +++ b/src/hypercorn/middleware/dispatcher.py @@ -5,8 +5,7 @@ from typing import Callable, Dict from ..asyncio.task_group import TaskGroup -from ..typing import ASGIFramework, Scope -from ..utils import invoke_asgi +from ..typing import ASGIFramework, ASGIReceiveEvent, Scope MAX_QUEUE_SIZE = 10 @@ -21,8 +20,9 @@ async def __call__(self, scope: Scope, receive: Callable, send: Callable) -> Non else: for path, app in self.mounts.items(): if scope["path"].startswith(path): - scope["path"] = scope["path"][len(path) :] or "/" - return await invoke_asgi(app, scope, receive, send) + local_scope = scope.copy() + local_scope["root_path"] += path + return await app(local_scope, receive, send) await send( { "type": "http.response.start", @@ -47,7 +47,6 @@ async def _handle_lifespan(self, scope: Scope, receive: Callable, send: Callable async with TaskGroup(asyncio.get_event_loop()) as task_group: for path, app in self.mounts.items(): task_group.spawn( - invoke_asgi, app, scope, self.app_queues[path].get, @@ -76,14 +75,15 @@ class TrioDispatcherMiddleware(_DispatcherMiddleware): async def _handle_lifespan(self, scope: Scope, receive: Callable, send: Callable) -> None: import trio - self.app_queues = {path: trio.open_memory_channel(MAX_QUEUE_SIZE) for path in self.mounts} + self.app_queues = { + path: trio.open_memory_channel[ASGIReceiveEvent](MAX_QUEUE_SIZE) for path in self.mounts + } self.startup_complete = {path: False for path in self.mounts} self.shutdown_complete = {path: False for path in self.mounts} async with trio.open_nursery() as nursery: for path, app in self.mounts.items(): nursery.start_soon( - invoke_asgi, app, scope, self.app_queues[path][1].receive, diff --git a/src/hypercorn/middleware/http_to_https.py b/src/hypercorn/middleware/http_to_https.py index 200b84d..542b28f 100644 --- a/src/hypercorn/middleware/http_to_https.py +++ b/src/hypercorn/middleware/http_to_https.py @@ -4,7 +4,6 @@ from urllib.parse import urlunsplit from ..typing import ASGIFramework, HTTPScope, Scope, WebsocketScope, WWWScope -from ..utils import invoke_asgi class HTTPToHTTPSRedirectMiddleware: @@ -24,7 +23,7 @@ async def __call__(self, scope: Scope, receive: Callable, send: Callable) -> Non else: await send({"type": "websocket.close"}) else: - return await invoke_asgi(self.app, scope, receive, send) + return await self.app(scope, receive, send) async def _send_http_redirect(self, scope: HTTPScope, send: Callable) -> None: new_url = self._new_url("https", scope) diff --git a/src/hypercorn/middleware/proxy_fix.py b/src/hypercorn/middleware/proxy_fix.py new file mode 100644 index 0000000..bd3dc4c --- /dev/null +++ b/src/hypercorn/middleware/proxy_fix.py @@ -0,0 +1,79 @@ +from __future__ import annotations + +from copy import deepcopy +from typing import Callable, Iterable, Literal, Optional, Tuple + +from ..typing import ASGIFramework, Scope + + +class ProxyFixMiddleware: + def __init__( + self, + app: ASGIFramework, + mode: Literal["legacy", "modern"] = "legacy", + trusted_hops: int = 1, + ) -> None: + self.app = app + self.mode = mode + self.trusted_hops = trusted_hops + + async def __call__(self, scope: Scope, receive: Callable, send: Callable) -> None: + # Keep the `or` instead of `in {'http' …}` to allow type narrowing + if scope["type"] == "http" or scope["type"] == "websocket": + scope = deepcopy(scope) + headers = scope["headers"] + client: Optional[str] = None + scheme: Optional[str] = None + host: Optional[str] = None + + if ( + self.mode == "modern" + and (value := _get_trusted_value(b"forwarded", headers, self.trusted_hops)) + is not None + ): + for part in value.split(";"): + if part.startswith("for="): + client = part[4:].strip() + elif part.startswith("host="): + host = part[5:].strip() + elif part.startswith("proto="): + scheme = part[6:].strip() + + else: + client = _get_trusted_value(b"x-forwarded-for", headers, self.trusted_hops) + scheme = _get_trusted_value(b"x-forwarded-proto", headers, self.trusted_hops) + host = _get_trusted_value(b"x-forwarded-host", headers, self.trusted_hops) + + if client is not None: + scope["client"] = (client, 0) + + if scheme is not None: + scope["scheme"] = scheme + + if host is not None: + headers = [ + (name, header_value) + for name, header_value in headers + if name.lower() != b"host" + ] + headers.append((b"host", host.encode())) + scope["headers"] = headers + + await self.app(scope, receive, send) + + +def _get_trusted_value( + name: bytes, headers: Iterable[Tuple[bytes, bytes]], trusted_hops: int +) -> Optional[str]: + if trusted_hops == 0: + return None + + values = [] + for header_name, header_value in headers: + if header_name.lower() == name: + values.extend([value.decode("latin1").strip() for value in header_value.split(b",")]) + + if len(values) >= trusted_hops: + return values[-trusted_hops] + + return None diff --git a/src/hypercorn/middleware/wsgi.py b/src/hypercorn/middleware/wsgi.py index 9ed74c9..8e4f61b 100644 --- a/src/hypercorn/middleware/wsgi.py +++ b/src/hypercorn/middleware/wsgi.py @@ -2,12 +2,12 @@ import asyncio from functools import partial -from io import BytesIO -from typing import Callable, Iterable, List, Optional, Tuple +from typing import Any, Callable, Iterable -from ..typing import HTTPScope, Scope +from ..app_wrappers import WSGIWrapper +from ..typing import ASGIReceiveCallable, ASGISendCallable, Scope, WSGIFramework -MAX_BODY_SIZE = 2 ** 16 +MAX_BODY_SIZE = 2**16 WSGICallable = Callable[[dict, Callable], Iterable[bytes]] @@ -17,134 +17,33 @@ class InvalidPathError(Exception): class _WSGIMiddleware: - def __init__(self, wsgi_app: WSGICallable, max_body_size: int = MAX_BODY_SIZE) -> None: - self.wsgi_app = wsgi_app + def __init__(self, wsgi_app: WSGIFramework, max_body_size: int = MAX_BODY_SIZE) -> None: + self.wsgi_app = WSGIWrapper(wsgi_app, max_body_size) self.max_body_size = max_body_size - async def __call__(self, scope: Scope, receive: Callable, send: Callable) -> None: - if scope["type"] == "http": - status_code, headers, body = await self._handle_http(scope, receive, send) - await send({"type": "http.response.start", "status": status_code, "headers": headers}) - await send({"type": "http.response.body", "body": body}) - elif scope["type"] == "websocket": - await send({"type": "websocket.close"}) - elif scope["type"] == "lifespan": - return - else: - raise Exception(f"Unknown scope type, {scope['type']}") - - async def _handle_http( - self, scope: HTTPScope, receive: Callable, send: Callable - ) -> Tuple[int, list, bytes]: + async def __call__( + self, scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable + ) -> None: pass class AsyncioWSGIMiddleware(_WSGIMiddleware): - async def _handle_http( - self, scope: HTTPScope, receive: Callable, send: Callable - ) -> Tuple[int, list, bytes]: + async def __call__( + self, scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable + ) -> None: loop = asyncio.get_event_loop() - instance = _WSGIInstance(self.wsgi_app, self.max_body_size) - return await instance.handle_http(scope, receive, partial(loop.run_in_executor, None)) - -class TrioWSGIMiddleware(_WSGIMiddleware): - async def _handle_http( - self, scope: HTTPScope, receive: Callable, send: Callable - ) -> Tuple[int, list, bytes]: - import trio + def _call_soon(func: Callable, *args: Any) -> Any: + future = asyncio.run_coroutine_threadsafe(func(*args), loop) + return future.result() - instance = _WSGIInstance(self.wsgi_app, self.max_body_size) - return await instance.handle_http(scope, receive, trio.to_thread.run_sync) + await self.wsgi_app(scope, receive, send, partial(loop.run_in_executor, None), _call_soon) -class _WSGIInstance: - def __init__(self, wsgi_app: WSGICallable, max_body_size: int = MAX_BODY_SIZE) -> None: - self.wsgi_app = wsgi_app - self.max_body_size = max_body_size - self.status_code = 500 - self.headers: list = [] - - async def handle_http( - self, scope: HTTPScope, receive: Callable, spawn: Callable - ) -> Tuple[int, list, bytes]: - self.scope = scope - body = bytearray() - while True: - message = await receive() - body.extend(message.get("body", b"")) - if len(body) > self.max_body_size: - return 400, [], b"" - if not message.get("more_body"): - break - return await spawn(self.run_wsgi_app, body) - - def _start_response( - self, - status: str, - response_headers: List[Tuple[str, str]], - exc_info: Optional[Exception] = None, +class TrioWSGIMiddleware(_WSGIMiddleware): + async def __call__( + self, scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable ) -> None: - raw, _ = status.split(" ", 1) - self.status_code = int(raw) - self.headers = [ - (name.lower().encode("ascii"), value.encode("ascii")) - for name, value in response_headers - ] - - def run_wsgi_app(self, body: bytes) -> Tuple[int, list, bytes]: - try: - environ = _build_environ(self.scope, body) - except InvalidPathError: - return 404, self.headers, b"" - else: - body = bytearray() - for output in self.wsgi_app(environ, self._start_response): - body.extend(output) - return self.status_code, self.headers, body - - -def _build_environ(scope: HTTPScope, body: bytes) -> dict: - server = scope.get("server") or ("localhost", 80) - path = scope["path"] - script_name = scope.get("root_path", "") - if path.startswith(script_name): - path = path[len(script_name) :] - path = path if path != "" else "/" - else: - raise InvalidPathError() - - environ = { - "REQUEST_METHOD": scope["method"], - "SCRIPT_NAME": script_name.encode("utf8").decode("latin1"), - "PATH_INFO": path.encode("utf8").decode("latin1"), - "QUERY_STRING": scope["query_string"].decode("ascii"), - "SERVER_NAME": server[0], - "SERVER_PORT": server[1], - "SERVER_PROTOCOL": "HTTP/%s" % scope["http_version"], - "wsgi.version": (1, 0), - "wsgi.url_scheme": scope.get("scheme", "http"), - "wsgi.input": BytesIO(body), - "wsgi.errors": BytesIO(), - "wsgi.multithread": True, - "wsgi.multiprocess": True, - "wsgi.run_once": False, - } - - if "client" in scope: - environ["REMOTE_ADDR"] = scope["client"][0] - - for raw_name, raw_value in scope.get("headers", []): - name = raw_name.decode("latin1") - if name == "content-length": - corrected_name = "CONTENT_LENGTH" - elif name == "content-type": - corrected_name = "CONTENT_TYPE" - else: - corrected_name = "HTTP_%s" % name.upper().replace("-", "_") - # HTTPbis say only ASCII chars are allowed in headers, but we latin1 just in case - value = raw_value.decode("latin1") - if corrected_name in environ: - value = environ[corrected_name] + "," + value # type: ignore - environ[corrected_name] = value - return environ + import trio + + await self.wsgi_app(scope, receive, send, trio.to_thread.run_sync, trio.from_thread.run) diff --git a/src/hypercorn/protocol/__init__.py b/src/hypercorn/protocol/__init__.py old mode 100755 new mode 100644 index 794ad7e..4e8feae --- a/src/hypercorn/protocol/__init__.py +++ b/src/hypercorn/protocol/__init__.py @@ -6,16 +6,17 @@ from .h11 import H2CProtocolRequiredError, H2ProtocolAssumedError, H11Protocol from ..config import Config from ..events import Event, RawData -from ..typing import ASGIFramework, TaskGroup, WorkerContext +from ..typing import AppWrapper, ConnectionState, TaskGroup, WorkerContext class ProtocolWrapper: def __init__( self, - app: ASGIFramework, + app: AppWrapper, config: Config, context: WorkerContext, task_group: TaskGroup, + state: ConnectionState, ssl: bool, client: Optional[Tuple[str, int]], server: Optional[Tuple[str, int]], @@ -30,6 +31,7 @@ def __init__( self.client = client self.server = server self.send = send + self.state = state self.protocol: Union[H11Protocol, H2Protocol] if alpn_protocol == "h2": self.protocol = H2Protocol( @@ -37,6 +39,7 @@ def __init__( self.config, self.context, self.task_group, + self.state, self.ssl, self.client, self.server, @@ -48,6 +51,7 @@ def __init__( self.config, self.context, self.task_group, + self.state, self.ssl, self.client, self.server, @@ -66,6 +70,7 @@ async def handle(self, event: Event) -> None: self.config, self.context, self.task_group, + self.state, self.ssl, self.client, self.server, @@ -80,6 +85,7 @@ async def handle(self, event: Event) -> None: self.config, self.context, self.task_group, + self.state, self.ssl, self.client, self.server, diff --git a/src/hypercorn/protocol/events.py b/src/hypercorn/protocol/events.py index 7b39e9a..0ded34a 100644 --- a/src/hypercorn/protocol/events.py +++ b/src/hypercorn/protocol/events.py @@ -3,6 +3,8 @@ from dataclasses import dataclass from typing import List, Tuple +from hypercorn.typing import ConnectionState + @dataclass(frozen=True) class Event: @@ -15,6 +17,7 @@ class Request(Event): http_version: str method: str raw_path: bytes + state: ConnectionState @dataclass(frozen=True) @@ -27,6 +30,11 @@ class EndBody(Event): pass +@dataclass(frozen=True) +class Trailers(Event): + headers: List[Tuple[bytes, bytes]] + + @dataclass(frozen=True) class Data(Event): data: bytes @@ -43,6 +51,16 @@ class Response(Event): status_code: int +@dataclass(frozen=True) +class InformationalResponse(Event): + headers: List[Tuple[bytes, bytes]] + status_code: int + + def __post_init__(self) -> None: + if self.status_code >= 200 or self.status_code < 100: + raise ValueError(f"Status code must be 1XX not {self.status_code}") + + @dataclass(frozen=True) class StreamClosed(Event): pass diff --git a/src/hypercorn/protocol/h11.py b/src/hypercorn/protocol/h11.py old mode 100755 new mode 100644 index c8636bf..c3c6e0f --- a/src/hypercorn/protocol/h11.py +++ b/src/hypercorn/protocol/h11.py @@ -1,7 +1,7 @@ from __future__ import annotations from itertools import chain -from typing import Awaitable, Callable, Optional, Tuple, Union +from typing import Awaitable, Callable, cast, Optional, Tuple, Type, Union import h11 @@ -11,6 +11,7 @@ EndBody, EndData, Event as StreamEvent, + InformationalResponse, Request, Response, StreamClosed, @@ -19,7 +20,7 @@ from .ws_stream import WSStream from ..config import Config from ..events import Closed, Event, RawData, Updated -from ..typing import ASGIFramework, H11SendableEvent, TaskGroup, WorkerContext +from ..typing import AppWrapper, ConnectionState, H11SendableEvent, TaskGroup, WorkerContext STREAM_ID = 1 @@ -51,6 +52,8 @@ class H11WSConnection: # events (Response, Body, EndBody). our_state = None # Prevents recycling the connection they_are_waiting_for_100_continue = False + their_state = None + trailing_data = (b"", False) def __init__(self, h11_connection: h11.Connection) -> None: self.buffer = bytearray(h11_connection.trailing_data[0]) @@ -59,7 +62,7 @@ def __init__(self, h11_connection: h11.Connection) -> None: def receive_data(self, data: bytes) -> None: self.buffer.extend(data) - def next_event(self) -> Data: + def next_event(self) -> Union[Data, Type[h11.NEED_DATA]]: if self.buffer: event = Data(stream_id=STREAM_ID, data=bytes(self.buffer)) self.buffer = bytearray() @@ -70,14 +73,18 @@ def next_event(self) -> Data: def send(self, event: H11SendableEvent) -> bytes: return self.h11_connection.send(event) + def start_next_cycle(self) -> None: + pass + class H11Protocol: def __init__( self, - app: ASGIFramework, + app: AppWrapper, config: Config, context: WorkerContext, task_group: TaskGroup, + connection_state: ConnectionState, ssl: bool, client: Optional[Tuple[str, int]], server: Optional[Tuple[str, int]], @@ -87,15 +94,17 @@ def __init__( self.can_read = context.event_class() self.client = client self.config = config - self.connection = h11.Connection( + self.connection: Union[h11.Connection, H11WSConnection] = h11.Connection( h11.SERVER, max_incomplete_event_size=self.config.h11_max_incomplete_size ) self.context = context + self.keep_alive_requests = 0 self.send = send self.server = server self.ssl = ssl self.stream: Optional[Union[HTTPStream, WSStream]] = None self.task_group = task_group + self.connection_state = connection_state async def initiate(self) -> None: pass @@ -111,19 +120,24 @@ async def handle(self, event: Event) -> None: async def stream_send(self, event: StreamEvent) -> None: if isinstance(event, Response): if event.status_code >= 200: + headers = list(chain(event.headers, self.config.response_headers("h11"))) + if self.keep_alive_requests >= self.config.keep_alive_max_requests: + headers.append((b"connection", b"close")) await self._send_h11_event( h11.Response( - headers=chain(event.headers, self.config.response_headers("h11")), + headers=headers, status_code=event.status_code, ) ) else: await self._send_h11_event( h11.InformationalResponse( - headers=chain(event.headers, self.config.response_headers("h11")), + headers=list(chain(event.headers, self.config.response_headers("h11"))), status_code=event.status_code, ) ) + elif isinstance(event, InformationalResponse): + pass # Ignore for HTTP/1 elif isinstance(event, Body): await self._send_h11_event(h11.Data(data=event.data)) elif isinstance(event, EndBody): @@ -146,16 +160,16 @@ async def _handle_events(self) -> None: try: event = self.connection.next_event() - except h11.RemoteProtocolError: + except h11.RemoteProtocolError as error: if self.connection.our_state in {h11.IDLE, h11.SEND_RESPONSE}: - await self._send_error_response(400) + await self._send_error_response(error.error_status_hint) await self.send(Closed()) break else: if isinstance(event, h11.Request): + await self.send(Updated(idle=False)) await self._check_protocol(event) await self._create_stream(event) - await self.send(Updated(idle=False)) elif event is h11.PAUSED: await self.can_read.clear() await self.can_read.wait() @@ -198,7 +212,7 @@ async def _create_stream(self, request: h11.Request) -> None: self.stream_send, STREAM_ID, ) - self.connection = H11WSConnection(self.connection) + self.connection = H11WSConnection(cast(h11.Connection, self.connection)) else: self.stream = HTTPStream( self.app, @@ -211,15 +225,24 @@ async def _create_stream(self, request: h11.Request) -> None: self.stream_send, STREAM_ID, ) + + if self.config.h11_pass_raw_headers: + headers = request.headers.raw_items() + else: + headers = list(request.headers) + await self.stream.handle( Request( stream_id=STREAM_ID, - headers=request.headers, + headers=headers, http_version=request.http_version.decode(), method=request.method.decode("ascii").upper(), raw_path=request.target, + state=self.connection_state, ) ) + self.keep_alive_requests += 1 + await self.context.mark_request() async def _send_h11_event(self, event: H11SendableEvent) -> None: try: @@ -234,9 +257,11 @@ async def _send_error_response(self, status_code: int) -> None: await self._send_h11_event( h11.Response( status_code=status_code, - headers=chain( - [(b"content-length", b"0"), (b"connection", b"close")], - self.config.response_headers("h11"), + headers=list( + chain( + [(b"content-length", b"0"), (b"connection", b"close")], + self.config.response_headers("h11"), + ) ), ) ) @@ -245,7 +270,7 @@ async def _send_error_response(self, status_code: int) -> None: async def _maybe_recycle(self) -> None: await self._close_stream() if ( - not self.context.terminated + not self.context.terminated.is_set() and self.connection.our_state is h11.DONE and self.connection.their_state is h11.DONE ): diff --git a/src/hypercorn/protocol/h2.py b/src/hypercorn/protocol/h2.py old mode 100755 new mode 100644 index c743848..b19a2bc --- a/src/hypercorn/protocol/h2.py +++ b/src/hypercorn/protocol/h2.py @@ -14,18 +14,20 @@ EndBody, EndData, Event as StreamEvent, + InformationalResponse, Request, Response, StreamClosed, + Trailers, ) from .http_stream import HTTPStream from .ws_stream import WSStream from ..config import Config from ..events import Closed, Event, RawData, Updated -from ..typing import ASGIFramework, Event as IOEvent, TaskGroup, WorkerContext +from ..typing import AppWrapper, ConnectionState, Event as IOEvent, TaskGroup, WorkerContext from ..utils import filter_pseudo_headers -BUFFER_HIGH_WATER = 2 * 2 ** 14 # Twice the default max frame size (two frames worth) +BUFFER_HIGH_WATER = 2 * 2**14 # Twice the default max frame size (two frames worth) BUFFER_LOW_WATER = BUFFER_HIGH_WATER / 2 @@ -79,10 +81,11 @@ async def pop(self, max_length: int) -> bytes: class H2Protocol: def __init__( self, - app: ASGIFramework, + app: AppWrapper, config: Config, context: WorkerContext, task_group: TaskGroup, + connection_state: ConnectionState, ssl: bool, client: Optional[Tuple[str, int]], server: Optional[Tuple[str, int]], @@ -94,6 +97,7 @@ def __init__( self.config = config self.context = context self.task_group = task_group + self.connection_state = connection_state self.connection = h2.connection.H2Connection( config=h2.config.H2Configuration(client_side=False, header_encoding=None) @@ -108,6 +112,7 @@ def __init__( }, ) + self.keep_alive_requests = 0 self.send = send self.server = server self.ssl = ssl @@ -194,7 +199,7 @@ async def handle(self, event: Event) -> None: async def stream_send(self, event: StreamEvent) -> None: try: - if isinstance(event, Response): + if isinstance(event, (InformationalResponse, Response)): self.connection.send_headers( event.stream_id, [(b":status", b"%d" % event.status_code)] @@ -211,12 +216,15 @@ async def stream_send(self, event: StreamEvent) -> None: self.priority.unblock(event.stream_id) await self.has_data.set() await self.stream_buffers[event.stream_id].drain() + elif isinstance(event, Trailers): + self.connection.send_headers(event.stream_id, event.headers) + await self._flush() elif isinstance(event, StreamClosed): await self._close_stream(event.stream_id) idle = len(self.streams) == 0 or all( stream.idle for stream in self.streams.values() ) - if idle and self.context.terminated: + if idle and self.context.terminated.is_set(): self.connection.close_connection() await self._flush() await self.send(Updated(idle=idle)) @@ -235,7 +243,7 @@ async def stream_send(self, event: StreamEvent) -> None: async def _handle_events(self, events: List[h2.events.Event]) -> None: for event in events: if isinstance(event, h2.events.RequestReceived): - if self.context.terminated: + if self.context.terminated.is_set(): self.connection.reset_stream(event.stream_id) self.connection.update_settings( {h2.settings.SettingCodes.MAX_CONCURRENT_STREAMS: 0} @@ -243,6 +251,9 @@ async def _handle_events(self, events: List[h2.events.Event]) -> None: else: await self._create_stream(event) await self.send(Updated(idle=False)) + + if self.keep_alive_requests > self.config.keep_alive_max_requests: + self.connection.close_connection() elif isinstance(event, h2.events.DataReceived): await self.streams[event.stream_id].handle( Body(stream_id=event.stream_id, data=event.data) @@ -251,7 +262,12 @@ async def _handle_events(self, events: List[h2.events.Event]) -> None: event.flow_controlled_length, event.stream_id ) elif isinstance(event, h2.events.StreamEnded): - await self.streams[event.stream_id].handle(EndBody(stream_id=event.stream_id)) + try: + await self.streams[event.stream_id].handle(EndBody(stream_id=event.stream_id)) + except KeyError: + # Response sent before full request received, + # nothing to do already closed. + pass elif isinstance(event, h2.events.StreamReset): await self._close_stream(event.stream_id) await self._window_updated(event.stream_id) @@ -346,8 +362,11 @@ async def _create_stream(self, request: h2.events.RequestReceived) -> None: http_version="2", method=method, raw_path=raw_path, + state=self.connection_state, ) ) + self.keep_alive_requests += 1 + await self.context.mark_request() async def _create_server_push( self, stream_id: int, path: bytes, headers: List[Tuple[bytes, bytes]] @@ -373,6 +392,7 @@ async def _create_server_push( event.headers = request_headers await self._create_stream(event) await self.streams[event.stream_id].handle(EndBody(stream_id=event.stream_id)) + self.keep_alive_requests += 1 async def _close_stream(self, stream_id: int) -> None: if stream_id in self.streams: diff --git a/src/hypercorn/protocol/h3.py b/src/hypercorn/protocol/h3.py index c9e3157..ae2eb8f 100644 --- a/src/hypercorn/protocol/h3.py +++ b/src/hypercorn/protocol/h3.py @@ -14,24 +14,27 @@ EndBody, EndData, Event as StreamEvent, + InformationalResponse, Request, Response, StreamClosed, + Trailers, ) from .http_stream import HTTPStream from .ws_stream import WSStream from ..config import Config -from ..typing import ASGIFramework, TaskGroup, WorkerContext +from ..typing import AppWrapper, ConnectionState, TaskGroup, WorkerContext from ..utils import filter_pseudo_headers class H3Protocol: def __init__( self, - app: ASGIFramework, + app: AppWrapper, config: Config, context: WorkerContext, task_group: TaskGroup, + state: ConnectionState, client: Optional[Tuple[str, int]], server: Optional[Tuple[str, int]], quic: QuicConnection, @@ -46,11 +49,12 @@ def __init__( self.server = server self.streams: Dict[int, Union[HTTPStream, WSStream]] = {} self.task_group = task_group + self.state = state async def handle(self, quic_event: QuicEvent) -> None: for event in self.connection.handle_event(quic_event): if isinstance(event, HeadersReceived): - if not self.context.terminated: + if not self.context.terminated.is_set(): await self._create_stream(event) if event.stream_ended: await self.streams[event.stream_id].handle( @@ -64,7 +68,7 @@ async def handle(self, quic_event: QuicEvent) -> None: await self.streams[event.stream_id].handle(EndBody(stream_id=event.stream_id)) async def stream_send(self, event: StreamEvent) -> None: - if isinstance(event, Response): + if isinstance(event, (InformationalResponse, Response)): self.connection.send_headers( event.stream_id, [(b":status", b"%d" % event.status_code)] @@ -78,6 +82,9 @@ async def stream_send(self, event: StreamEvent) -> None: elif isinstance(event, (EndBody, EndData)): self.connection.send_data(event.stream_id, b"", True) await self.send() + elif isinstance(event, Trailers): + self.connection.send_headers(event.stream_id, event.headers) + await self.send() elif isinstance(event, StreamClosed): pass # ?? elif isinstance(event, Request): @@ -122,8 +129,10 @@ async def _create_stream(self, request: HeadersReceived) -> None: http_version="3", method=method, raw_path=raw_path, + state=self.state, ) ) + await self.context.mark_request() async def _create_server_push( self, stream_id: int, path: bytes, headers: List[Tuple[bytes, bytes]] diff --git a/src/hypercorn/protocol/http_stream.py b/src/hypercorn/protocol/http_stream.py index c6bdcb5..7ffac1d 100644 --- a/src/hypercorn/protocol/http_stream.py +++ b/src/hypercorn/protocol/http_stream.py @@ -5,10 +5,19 @@ from typing import Awaitable, Callable, Optional, Tuple from urllib.parse import unquote -from .events import Body, EndBody, Event, Request, Response, StreamClosed +from .events import ( + Body, + EndBody, + Event, + InformationalResponse, + Request, + Response, + StreamClosed, + Trailers, +) from ..config import Config from ..typing import ( - ASGIFramework, + AppWrapper, ASGISendEvent, HTTPResponseStartEvent, HTTPScope, @@ -22,7 +31,9 @@ valid_server_name, ) +TRAILERS_VERSIONS = {"2", "3"} PUSH_VERSIONS = {"2", "3"} +EARLY_HINTS_VERSIONS = {"2", "3"} class ASGIHTTPState(Enum): @@ -31,13 +42,14 @@ class ASGIHTTPState(Enum): # state tracking is required. REQUEST = auto() RESPONSE = auto() + TRAILERS = auto() CLOSED = auto() class HTTPStream: def __init__( self, - app: ASGIFramework, + app: AppWrapper, config: Config, context: WorkerContext, task_group: TaskGroup, @@ -75,7 +87,7 @@ async def handle(self, event: Event) -> None: self.scope = { "type": "http", "http_version": event.http_version, - "asgi": {"spec_version": "2.1"}, + "asgi": {"spec_version": "2.1", "version": "3.0"}, "method": event.method, "scheme": self.scheme, "path": unquote(path.decode("ascii")), @@ -85,11 +97,19 @@ async def handle(self, event: Event) -> None: "headers": event.headers, "client": self.client, "server": self.server, + "state": event.state, "extensions": {}, } + + if event.http_version in TRAILERS_VERSIONS: + self.scope["extensions"]["http.response.trailers"] = {} + if event.http_version in PUSH_VERSIONS: self.scope["extensions"]["http.response.push"] = {} + if event.http_version in EARLY_HINTS_VERSIONS: + self.scope["extensions"]["http.response.early_hint"] = {} + if valid_server_name(self.config, event): self.app_put = await self.task_group.spawn_app( self.app, self.config, self.scope, self.app_send @@ -106,22 +126,30 @@ async def handle(self, event: Event) -> None: await self.app_put({"type": "http.request", "body": b"", "more_body": False}) elif isinstance(event, StreamClosed): self.closed = True + if self.state != ASGIHTTPState.CLOSED: + await self.config.log.access(self.scope, None, time() - self.start_time) if self.app_put is not None: - await self.app_put({"type": "http.disconnect"}) # type: ignore + await self.app_put({"type": "http.disconnect"}) async def app_send(self, message: Optional[ASGISendEvent]) -> None: - if self.closed: - # Allow app to finish after close - return - if message is None: # ASGI App has finished sending messages - # Cleanup if required - if self.state == ASGIHTTPState.REQUEST: - await self._send_error_response(500) - await self.send(StreamClosed(stream_id=self.stream_id)) + if not self.closed: + # Cleanup if required + if self.state == ASGIHTTPState.REQUEST: + await self._send_error_response(500) + await self.send(StreamClosed(stream_id=self.stream_id)) else: if message["type"] == "http.response.start" and self.state == ASGIHTTPState.REQUEST: self.response = message + headers = build_and_validate_headers(self.response.get("headers", [])) + await self.send( + Response( + stream_id=self.stream_id, + headers=headers, + status_code=int(self.response["status"]), + ) + ) + self.state = ASGIHTTPState.RESPONSE elif ( message["type"] == "http.response.push" and self.scope["http_version"] in PUSH_VERSIONS @@ -140,23 +168,23 @@ async def app_send(self, message: Optional[ASGISendEvent]) -> None: http_version=self.scope["http_version"], method="GET", raw_path=message["path"].encode(), + state=self.scope["state"], ) ) - elif message["type"] == "http.response.body" and self.state in { - ASGIHTTPState.REQUEST, - ASGIHTTPState.RESPONSE, - }: - if self.state == ASGIHTTPState.REQUEST: - headers = build_and_validate_headers(self.response.get("headers", [])) - await self.send( - Response( - stream_id=self.stream_id, - headers=headers, - status_code=int(self.response["status"]), - ) + elif ( + message["type"] == "http.response.early_hint" + and self.scope["http_version"] in EARLY_HINTS_VERSIONS + and self.state == ASGIHTTPState.REQUEST + ): + headers = [(b"link", bytes(link).strip()) for link in message["links"]] + await self.send( + InformationalResponse( + stream_id=self.stream_id, + headers=headers, + status_code=103, ) - self.state = ASGIHTTPState.RESPONSE - + ) + elif message["type"] == "http.response.body" and self.state == ASGIHTTPState.RESPONSE: if ( not suppress_body(self.scope["method"], int(self.response["status"])) and message.get("body", b"") != b"" @@ -166,16 +194,58 @@ async def app_send(self, message: Optional[ASGISendEvent]) -> None: ) if not message.get("more_body", False): - if self.state != ASGIHTTPState.CLOSED: - self.state = ASGIHTTPState.CLOSED - await self.config.log.access( - self.scope, self.response, time() - self.start_time + if self.response.get("trailers", False): + self.state = ASGIHTTPState.TRAILERS + else: + await self._send_closed() + elif ( + message["type"] == "http.response.trailers" + and self.scope["http_version"] in TRAILERS_VERSIONS + and self.state == ASGIHTTPState.REQUEST + ): + for name, value in self.scope["headers"]: + if name == b"te" and value == b"trailers": + headers = build_and_validate_headers(message["headers"]) + self.response = { + "type": "http.response.start", + "status": 200, + "headers": headers, + } + await self.send( + Response( + stream_id=self.stream_id, + headers=headers, + status_code=200, + ) ) - await self.send(EndBody(stream_id=self.stream_id)) - await self.send(StreamClosed(stream_id=self.stream_id)) + self.state = ASGIHTTPState.TRAILERS + break + + if not message.get("more_trailers", False): + await self._send_closed() + + elif ( + message["type"] == "http.response.trailers" + and self.scope["http_version"] in TRAILERS_VERSIONS + and self.state == ASGIHTTPState.TRAILERS + ): + for name, value in self.scope["headers"]: + if name == b"te" and value == b"trailers": + headers = build_and_validate_headers(message["headers"]) + await self.send(Trailers(stream_id=self.stream_id, headers=headers)) + break + + if not message.get("more_trailers", False): + await self._send_closed() else: raise UnexpectedMessageError(self.state, message["type"]) + async def _send_closed(self) -> None: + await self.send(EndBody(stream_id=self.stream_id)) + self.state = ASGIHTTPState.CLOSED + await self.config.log.access(self.scope, self.response, time() - self.start_time) + await self.send(StreamClosed(stream_id=self.stream_id)) + async def _send_error_response(self, status_code: int) -> None: await self.send( Response( diff --git a/src/hypercorn/protocol/quic.py b/src/hypercorn/protocol/quic.py index b6676af..40625a6 100644 --- a/src/hypercorn/protocol/quic.py +++ b/src/hypercorn/protocol/quic.py @@ -1,7 +1,8 @@ from __future__ import annotations +from dataclasses import dataclass from functools import partial -from typing import Awaitable, Callable, Dict, Optional, Tuple +from typing import Awaitable, Callable, Dict, Optional, Set, Tuple from aioquic.buffer import Buffer from aioquic.h3.connection import H3_ALPN @@ -22,34 +23,43 @@ from .h3 import H3Protocol from ..config import Config from ..events import Closed, Event, RawData -from ..typing import ASGIFramework, TaskGroup, WorkerContext +from ..typing import AppWrapper, ConnectionState, SingleTask, TaskGroup, WorkerContext + + +@dataclass +class _Connection: + cids: Set[bytes] + quic: QuicConnection + task: SingleTask + h3: Optional[H3Protocol] = None class QuicProtocol: def __init__( self, - app: ASGIFramework, + app: AppWrapper, config: Config, context: WorkerContext, task_group: TaskGroup, + state: ConnectionState, server: Optional[Tuple[str, int]], send: Callable[[Event], Awaitable[None]], ) -> None: self.app = app self.config = config self.context = context - self.connections: Dict[bytes, QuicConnection] = {} - self.http_connections: Dict[QuicConnection, H3Protocol] = {} + self.connections: Dict[bytes, _Connection] = {} self.send = send self.server = server self.task_group = task_group + self.state = state self.quic_config = QuicConfiguration(alpn_protocols=H3_ALPN, is_client=False) self.quic_config.load_cert_chain(certfile=config.certfile, keyfile=config.keyfile) @property def idle(self) -> bool: - return len(self.connections) == 0 and len(self.http_connections) == 0 + return len(self.connections) == 0 async def handle(self, event: Event) -> None: if isinstance(event, RawData): @@ -74,61 +84,74 @@ async def handle(self, event: Event) -> None: connection is None and len(event.data) >= 1200 and header.packet_type == PACKET_TYPE_INITIAL + and not self.context.terminated.is_set() ): - connection = QuicConnection( + quic_connection = QuicConnection( configuration=self.quic_config, original_destination_connection_id=header.destination_cid, ) + connection = _Connection( + cids={header.destination_cid, quic_connection.host_cid}, + quic=quic_connection, + task=self.context.single_task_class(), + ) self.connections[header.destination_cid] = connection - self.connections[connection.host_cid] = connection + self.connections[quic_connection.host_cid] = connection if connection is not None: - connection.receive_datagram(event.data, event.address, now=self.context.time()) + connection.quic.receive_datagram(event.data, event.address, now=self.context.time()) await self._handle_events(connection, event.address) elif isinstance(event, Closed): pass - async def send_all(self, connection: QuicConnection) -> None: - for data, address in connection.datagrams_to_send(now=self.context.time()): + async def send_all(self, connection: _Connection) -> None: + for data, address in connection.quic.datagrams_to_send(now=self.context.time()): await self.send(RawData(data=data, address=address)) + timer = connection.quic.get_timer() + if timer is not None: + await connection.task.restart( + self.task_group, partial(self._handle_timer, timer, connection) + ) + async def _handle_events( - self, connection: QuicConnection, client: Optional[Tuple[str, int]] = None + self, connection: _Connection, client: Optional[Tuple[str, int]] = None ) -> None: - event = connection.next_event() + event = connection.quic.next_event() while event is not None: if isinstance(event, ConnectionTerminated): - pass + await connection.task.stop() + for cid in connection.cids: + del self.connections[cid] + connection.cids = set() elif isinstance(event, ProtocolNegotiated): - self.http_connections[connection] = H3Protocol( + connection.h3 = H3Protocol( self.app, self.config, self.context, self.task_group, + self.state, client, self.server, - connection, + connection.quic, partial(self.send_all, connection), ) elif isinstance(event, ConnectionIdIssued): + connection.cids.add(event.connection_id) self.connections[event.connection_id] = connection elif isinstance(event, ConnectionIdRetired): + connection.cids.remove(event.connection_id) del self.connections[event.connection_id] - if connection in self.http_connections: - await self.http_connections[connection].handle(event) + if connection.h3 is not None: + await connection.h3.handle(event) - event = connection.next_event() + event = connection.quic.next_event() await self.send_all(connection) - timer = connection.get_timer() - if timer is not None: - self.task_group.spawn(self._handle_timer, timer, connection) - - async def _handle_timer(self, timer: float, connection: QuicConnection) -> None: + async def _handle_timer(self, timer: float, connection: _Connection) -> None: wait = max(0, timer - self.context.time()) await self.context.sleep(wait) - if connection._close_at is not None: - connection.handle_timer(now=self.context.time()) - await self._handle_events(connection, None) + connection.quic.handle_timer(now=self.context.time()) + await self._handle_events(connection, None) diff --git a/src/hypercorn/protocol/ws_stream.py b/src/hypercorn/protocol/ws_stream.py index 5c670c4..7b39815 100644 --- a/src/hypercorn/protocol/ws_stream.py +++ b/src/hypercorn/protocol/ws_stream.py @@ -18,12 +18,12 @@ from wsproto.extensions import Extension, PerMessageDeflate from wsproto.frame_protocol import CloseReason from wsproto.handshake import server_extensions_handshake, WEBSOCKET_VERSION -from wsproto.utilities import generate_accept_token, split_comma_header +from wsproto.utilities import generate_accept_token, LocalProtocolError, split_comma_header from .events import Body, Data, EndBody, EndData, Event, Request, Response, StreamClosed from ..config import Config from ..typing import ( - ASGIFramework, + AppWrapper, ASGISendEvent, TaskGroup, WebsocketAcceptEvent, @@ -56,6 +56,7 @@ class FrameTooLargeError(Exception): class Handshake: def __init__(self, headers: List[Tuple[bytes, bytes]], http_version: str) -> None: + self.accepted = False self.http_version = http_version self.connection_tokens: Optional[List[str]] = None self.extensions: Optional[List[str]] = None @@ -102,7 +103,7 @@ def accept( ) -> Tuple[int, List[Tuple[bytes, bytes]], Connection]: headers = [] if subprotocol is not None: - if subprotocol not in self.subprotocols: + if self.subprotocols is None or subprotocol not in self.subprotocols: raise Exception("Invalid Subprotocol") else: headers.append((b"sec-websocket-protocol", subprotocol.encode())) @@ -129,6 +130,7 @@ def accept( headers.append((name, value)) + self.accepted = True return status_code, headers, Connection(ConnectionType.SERVER, extensions) @@ -163,7 +165,7 @@ def to_message(self) -> dict: class WSStream: def __init__( self, - app: ASGIFramework, + app: AppWrapper, config: Config, context: WorkerContext, task_group: TaskGroup, @@ -207,7 +209,7 @@ async def handle(self, event: Event) -> None: path, _, query_string = event.raw_path.partition(b"?") self.scope = { "type": "websocket", - "asgi": {"spec_version": "2.3"}, + "asgi": {"spec_version": "2.3", "version": "3.0"}, "scheme": self.scheme, "http_version": event.http_version, "path": unquote(path.decode("ascii")), @@ -217,6 +219,7 @@ async def handle(self, event: Event) -> None: "headers": event.headers, "client": self.client, "server": self.server, + "state": event.state, "subprotocols": self.handshake.subprotocols or [], "extensions": {"websocket.http.response": {}}, } @@ -231,7 +234,10 @@ async def handle(self, event: Event) -> None: self.app_put = await self.task_group.spawn_app( self.app, self.config, self.scope, self.app_send ) - await self.app_put({"type": "websocket.connect"}) # type: ignore + await self.app_put({"type": "websocket.connect"}) + elif isinstance(event, (Body, Data)) and not self.handshake.accepted: + await self._send_error_response(400) + self.closed = True elif isinstance(event, (Body, Data)): self.connection.receive_data(event.data) await self._handle_events() @@ -257,7 +263,7 @@ async def app_send(self, message: Optional[ASGISendEvent]) -> None: self.scope, {"status": 500, "headers": []}, time() - self.start_time ) elif self.state == ASGIWebsocketState.CONNECTED: - await self._send_wsproto_event(CloseConnection(code=CloseReason.ABNORMAL_CLOSURE)) + await self._send_wsproto_event(CloseConnection(code=CloseReason.INTERNAL_ERROR)) await self.send(StreamClosed(stream_id=self.stream_id)) else: if message["type"] == "websocket.accept" and self.state == ASGIWebsocketState.HANDSHAKE: @@ -333,8 +339,12 @@ async def _send_error_response(self, status_code: int) -> None: ) async def _send_wsproto_event(self, event: WSProtoEvent) -> None: - data = self.connection.send(event) - await self.send(Data(stream_id=self.stream_id, data=data)) + try: + data = self.connection.send(event) + except LocalProtocolError: + pass + else: + await self.send(Data(stream_id=self.stream_id, data=data)) async def _accept(self, message: WebsocketAcceptEvent) -> None: self.state = ASGIWebsocketState.CONNECTED diff --git a/src/hypercorn/run.py b/src/hypercorn/run.py index 4dd067e..2589c52 100644 --- a/src/hypercorn/run.py +++ b/src/hypercorn/run.py @@ -2,16 +2,22 @@ import platform import signal +import threading import time -from multiprocessing import Event, Process -from typing import Any - -from .config import Config +from multiprocessing import get_context +from multiprocessing.connection import wait +from multiprocessing.context import BaseContext +from multiprocessing.process import BaseProcess +from multiprocessing.synchronize import Event as EventType +from pickle import PicklingError +from typing import Any, List, Union + +from .config import Config, Sockets from .typing import WorkerFunc -from .utils import write_pid_file +from .utils import check_for_updates, files_to_watch, load_application, write_pid_file -def run(config: Config) -> None: +def run(config: Config) -> int: if config.pid_path is not None: write_pid_file(config.pid_path) @@ -31,51 +37,142 @@ def run(config: Config) -> None: else: raise ValueError(f"No worker of class {config.worker_class} exists") - if config.workers == 1: - worker_func(config) - else: - run_multiple(config, worker_func) - - -def run_multiple(config: Config, worker_func: WorkerFunc) -> None: - if config.use_reloader: - raise RuntimeError("Reloader can only be used with a single worker") - sockets = config.create_sockets() - processes = [] - - # Ignore SIGINT before creating the processes, so that they - # inherit the signal handling. This means that the shutdown - # function controls the shutdown. - signal.signal(signal.SIGINT, signal.SIG_IGN) - - shutdown_event = Event() - - for _ in range(config.workers): - process = Process( - target=worker_func, - kwargs={"config": config, "shutdown_event": shutdown_event, "sockets": sockets}, - ) - process.daemon = True - process.start() - processes.append(process) - if platform.system() == "Windows": - time.sleep(0.1) - - def shutdown(*args: Any) -> None: - shutdown_event.set() - - for signal_name in {"SIGINT", "SIGTERM", "SIGBREAK"}: - if hasattr(signal, signal_name): - signal.signal(getattr(signal, signal_name), shutdown) - - for process in processes: - process.join() - for process in processes: - process.terminate() - - for sock in sockets.secure_sockets: - sock.close() - for sock in sockets.insecure_sockets: - sock.close() + if config.use_reloader and config.workers == 0: + raise RuntimeError("Cannot reload without workers") + + exitcode = 0 + if config.workers == 0: + worker_func(config, sockets) + else: + if config.use_reloader: + # Load the application so that the correct paths are checked for + # changes, but only when the reloader is being used. + load_application(config.application_path, config.wsgi_max_body_size) + + active = True + if config.worker_type == "process": + ctx = get_context("spawn") + shutdown_event = ctx.Event() + def shutdown(*args: Any) -> None: + nonlocal active, shutdown_event + shutdown_event.set() + active = False + else: + ctx = None # multithreading mode does not need a context + shutdown_event = threading.Event() + def shutdown(*args: Any) -> None: + nonlocal active, shutdown_event + shutdown_event.set() + active = False + + processes: List[Union[BaseProcess, threading.Thread]] = [] + while active: + # Ignore SIGINT before creating the processes, so that they + # inherit the signal handling. This means that the shutdown + # function controls the shutdown. + signal.signal(signal.SIGINT, signal.SIG_IGN) + + _populate(processes, config, worker_func, sockets, shutdown_event, ctx) + + for signal_name in {"SIGINT", "SIGTERM", "SIGBREAK"}: + if hasattr(signal, signal_name): + signal.signal(getattr(signal, signal_name), shutdown) + + if config.use_reloader: + files = files_to_watch() + if config.worker_type == "process": + while True: + finished = wait((process.sentinel for process in processes), timeout=1) + updated = check_for_updates(files) + if updated: + shutdown_event.set() + for process in processes: + process.join() + shutdown_event.clear() + break + if len(finished) > 0: + break + else: + raise RuntimeError("Reloading not supported with threads") + else: + if config.worker_type == "process": + wait(process.sentinel for process in processes) + else: + while True: + time.sleep(0.1) + if any(not process.is_alive() for process in processes): + break + + exitcode = _join_exited(processes) + if exitcode != 0: + shutdown_event.set() + active = False + + for process in processes: + if isinstance(process, BaseProcess): + process.terminate() + + exitcode = _join_exited(processes) if exitcode != 0 else exitcode + + for sock in sockets.secure_sockets: + sock.close() + + for sock in sockets.insecure_sockets: + sock.close() + + return exitcode + + +def _populate( + processes: List[Union[BaseProcess, threading.Thread]], + config: Config, + worker_func: WorkerFunc, + sockets: Sockets, + shutdown_event: EventType, + ctx: BaseContext, +) -> None: + if config.worker_type == "process": + for _ in range(config.workers - len(processes)): + process = ctx.Process( # type: ignore + target=worker_func, + kwargs={"config": config, "shutdown_event": shutdown_event, "sockets": sockets}, + ) + process.daemon = True + try: + process.start() + except PicklingError as error: + raise RuntimeError( + "Cannot pickle the config, see https://docs.python.org/3/library/pickle.html#pickle-picklable" # noqa: E501 + ) from error + processes.append(process) + if platform.system() == "Windows": + time.sleep(0.1) + else: + for _ in range(config.workers - len(processes)): + thread = threading.Thread( + target=worker_func, + kwargs={"config": config, "shutdown_event": shutdown_event, "sockets": sockets}, + ) + thread.daemon = True + thread.start() + processes.append(thread) + if platform.system() == "Windows": + time.sleep(0.1) # let's simulate the same behavior as processes, in case something wrong happens + + +def _join_exited(processes: List[Union[BaseProcess, threading.Thread]]) -> int: + exitcode = 0 + for index in reversed(range(len(processes))): + worker = processes[index] + if isinstance(worker, BaseProcess): + if worker.exitcode is not None: + worker.join() + exitcode = worker.exitcode if exitcode == 0 else exitcode + del processes[index] + else: + if worker.is_alive(): + worker.join() + del processes[index] + return exitcode diff --git a/src/hypercorn/statsd.py b/src/hypercorn/statsd.py index 1418c1f..58e2cde 100644 --- a/src/hypercorn/statsd.py +++ b/src/hypercorn/statsd.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, TYPE_CHECKING +from typing import Any, Optional, TYPE_CHECKING from .logging import Logger @@ -30,11 +30,11 @@ async def critical(self, message: str, *args: Any, **kwargs: Any) -> None: async def error(self, message: str, *args: Any, **kwargs: Any) -> None: await super().error(message, *args, **kwargs) - self.increment("hypercorn.log.error", 1) + await self.increment("hypercorn.log.error", 1) async def warning(self, message: str, *args: Any, **kwargs: Any) -> None: await super().warning(message, *args, **kwargs) - self.increment("hypercorn.log.warning", 1) + await self.increment("hypercorn.log.warning", 1) async def info(self, message: str, *args: Any, **kwargs: Any) -> None: await super().info(message, *args, **kwargs) @@ -67,12 +67,13 @@ async def log(self, level: int, message: str, *args: Any, **kwargs: Any) -> None await super().warning("Failed to log to statsd", exc_info=True) async def access( - self, request: "WWWScope", response: "ResponseSummary", request_time: float + self, request: "WWWScope", response: Optional["ResponseSummary"], request_time: float ) -> None: await super().access(request, response, request_time) await self.histogram("hypercorn.request.duration", request_time * 1_000) await self.increment("hypercorn.requests", 1) - await self.increment(f"hypercorn.request.status.{response['status']}", 1) + if response is not None: + await self.increment(f"hypercorn.request.status.{response['status']}", 1) async def gauge(self, name: str, value: int) -> None: await self._send(f"{self.prefix}{name}:{value}|g") diff --git a/src/hypercorn/trio/__init__.py b/src/hypercorn/trio/__init__.py index a0aa291..0795706 100644 --- a/src/hypercorn/trio/__init__.py +++ b/src/hypercorn/trio/__init__.py @@ -1,21 +1,23 @@ from __future__ import annotations import warnings -from typing import Awaitable, Callable, Optional +from typing import Awaitable, Callable, Literal, Optional import trio from .run import worker_serve from ..config import Config -from ..typing import ASGIFramework +from ..typing import Framework +from ..utils import wrap_app async def serve( - app: ASGIFramework, + app: Framework, config: Config, *, shutdown_trigger: Optional[Callable[..., Awaitable[None]]] = None, - task_status: trio._core._run._TaskStatus = trio.TASK_STATUS_IGNORED, + task_status: trio.TaskStatus = trio.TASK_STATUS_IGNORED, + mode: Optional[Literal["asgi", "wsgi"]] = None, ) -> None: """Serve an ASGI framework app given the config. @@ -35,10 +37,16 @@ async def serve( config: A Hypercorn configuration object. shutdown_trigger: This should return to trigger a graceful shutdown. + mode: Specify if the app is WSGI or ASGI. """ if config.debug: warnings.warn("The config `debug` has no affect when using serve", Warning) if config.workers != 1: warnings.warn("The config `workers` has no affect when using serve", Warning) - await worker_serve(app, config, shutdown_trigger=shutdown_trigger, task_status=task_status) + await worker_serve( + wrap_app(app, config.wsgi_max_body_size, mode), + config, + shutdown_trigger=shutdown_trigger, + task_status=task_status, + ) diff --git a/src/hypercorn/trio/lifespan.py b/src/hypercorn/trio/lifespan.py index 94d4780..087aa83 100644 --- a/src/hypercorn/trio/lifespan.py +++ b/src/hypercorn/trio/lifespan.py @@ -1,10 +1,15 @@ from __future__ import annotations +import sys + import trio from ..config import Config -from ..typing import ASGIFramework, ASGIReceiveEvent, ASGISendEvent, LifespanScope -from ..utils import invoke_asgi, LifespanFailureError, LifespanTimeoutError +from ..typing import AppWrapper, ASGIReceiveEvent, ASGISendEvent, LifespanScope, LifespanState +from ..utils import LifespanFailureError, LifespanTimeoutError + +if sys.version_info < (3, 11): + from exceptiongroup import BaseExceptionGroup class UnexpectedMessageError(Exception): @@ -12,27 +17,42 @@ class UnexpectedMessageError(Exception): class Lifespan: - def __init__(self, app: ASGIFramework, config: Config) -> None: + def __init__(self, app: AppWrapper, config: Config, state: LifespanState) -> None: self.app = app self.config = config self.startup = trio.Event() self.shutdown = trio.Event() - self.app_send_channel, self.app_receive_channel = trio.open_memory_channel( - config.max_app_queue_size - ) + self.app_send_channel, self.app_receive_channel = trio.open_memory_channel[ + ASGIReceiveEvent + ](config.max_app_queue_size) + self.state = state self.supported = True async def handle_lifespan( - self, *, task_status: trio._core._run._TaskStatus = trio.TASK_STATUS_IGNORED + self, *, task_status: trio.TaskStatus = trio.TASK_STATUS_IGNORED ) -> None: task_status.started() - scope: LifespanScope = {"type": "lifespan", "asgi": {"spec_version": "2.0"}} + scope: LifespanScope = { + "type": "lifespan", + "asgi": {"spec_version": "2.0", "version": "3.0"}, + "state": self.state, + } try: - await invoke_asgi(self.app, scope, self.asgi_receive, self.asgi_send) - except LifespanFailureError: - # Lifespan failures should crash the server + await self.app( + scope, + self.asgi_receive, + self.asgi_send, + trio.to_thread.run_sync, + trio.from_thread.run, + ) + except (LifespanFailureError, trio.Cancelled): raise - except Exception: + except (BaseExceptionGroup, Exception) as error: + if isinstance(error, BaseExceptionGroup): + reraise_error = error.subgroup((LifespanFailureError, trio.Cancelled)) + if reraise_error is not None: + raise reraise_error + self.supported = False if not self.startup.is_set(): await self.config.log.warning( @@ -81,8 +101,8 @@ async def asgi_send(self, message: ASGISendEvent) -> None: elif message["type"] == "lifespan.shutdown.complete": self.shutdown.set() elif message["type"] == "lifespan.startup.failed": - raise LifespanFailureError("startup", message["message"]) + raise LifespanFailureError("startup", message.get("message", "")) elif message["type"] == "lifespan.shutdown.failed": - raise LifespanFailureError("shutdown", message["message"]) + raise LifespanFailureError("shutdown", message.get("message", "")) else: raise UnexpectedMessageError(message["type"]) diff --git a/src/hypercorn/trio/run.py b/src/hypercorn/trio/run.py index 4f7a8c2..7c55df1 100644 --- a/src/hypercorn/trio/run.py +++ b/src/hypercorn/trio/run.py @@ -1,7 +1,9 @@ from __future__ import annotations +import sys from functools import partial from multiprocessing.synchronize import Event as EventType +from random import randint from typing import Awaitable, Callable, Optional import trio @@ -12,32 +14,35 @@ from .udp_server import UDPServer from .worker_context import WorkerContext from ..config import Config, Sockets -from ..typing import ASGIFramework +from ..typing import AppWrapper, ConnectionState, LifespanState from ..utils import ( check_multiprocess_shutdown_event, load_application, - MustReloadError, - observe_changes, raise_shutdown, repr_socket_addr, - restart, ShutdownError, ) +if sys.version_info < (3, 11): + from exceptiongroup import BaseExceptionGroup + async def worker_serve( - app: ASGIFramework, + app: AppWrapper, config: Config, *, sockets: Optional[Sockets] = None, shutdown_trigger: Optional[Callable[..., Awaitable[None]]] = None, - task_status: trio._core._run._TaskStatus = trio.TASK_STATUS_IGNORED, + task_status: trio.TaskStatus = trio.TASK_STATUS_IGNORED, ) -> None: config.set_statsd_logger_class(StatsdLogger) - lifespan = Lifespan(app, config) - reload_ = False - context = WorkerContext() + lifespan_state: LifespanState = {} + lifespan = Lifespan(app, config, lifespan_state) + max_requests = None + if config.max_requests is not None: + max_requests = config.max_requests + randint(0, config.max_requests_jitter) + context = WorkerContext(max_requests) async with trio.open_nursery() as lifespan_nursery: await lifespan_nursery.start(lifespan.handle_lifespan) @@ -52,7 +57,7 @@ async def worker_serve( sock.listen(config.backlog) ssl_context = config.create_ssl_context() - listeners = [] + listeners: list[trio.SSLListener[trio.SocketStream] | trio.SocketListener] = [] binds = [] for sock in sockets.secure_sockets: listeners.append( @@ -73,45 +78,48 @@ async def worker_serve( await config.log.info(f"Running on http://{bind} (CTRL + C to quit)") for sock in sockets.quic_sockets: - await server_nursery.start(UDPServer(app, config, context, sock).run) + await server_nursery.start( + UDPServer( + app, config, context, ConnectionState(lifespan_state.copy()), sock + ).run + ) bind = repr_socket_addr(sock.family, sock.getsockname()) await config.log.info(f"Running on https://{bind} (QUIC) (CTRL + C to quit)") task_status.started(binds) try: - async with trio.open_nursery() as nursery: - if config.use_reloader: - nursery.start_soon(observe_changes, trio.sleep) - + async with trio.open_nursery(strict_exception_groups=True) as nursery: if shutdown_trigger is not None: nursery.start_soon(raise_shutdown, shutdown_trigger) + nursery.start_soon(raise_shutdown, context.terminate.wait) nursery.start_soon( partial( trio.serve_listeners, - partial(TCPServer, app, config, context), + partial( + TCPServer, + app, + config, + context, + ConnectionState(lifespan_state.copy()), + ), listeners, handler_nursery=server_nursery, ), ) await trio.sleep_forever() - except trio.MultiError as error: - reload_ = any(isinstance(exc, MustReloadError) for exc in error.exceptions) - except MustReloadError: - reload_ = True - except (ShutdownError, KeyboardInterrupt): - pass + except BaseExceptionGroup as error: + _, other_errors = error.split((ShutdownError, KeyboardInterrupt)) + if other_errors is not None: + raise other_errors finally: - context.terminated = True + await context.terminated.set() server_nursery.cancel_scope.deadline = trio.current_time() + config.graceful_timeout await lifespan.wait_for_shutdown() lifespan_nursery.cancel_scope.cancel() - if reload_: - restart() - def trio_worker( config: Config, sockets: Optional[Sockets] = None, shutdown_event: Optional[EventType] = None @@ -121,7 +129,7 @@ def trio_worker( sock.listen(config.backlog) for sock in sockets.insecure_sockets: sock.listen(config.backlog) - app = load_application(config.application_path) + app = load_application(config.application_path, config.wsgi_max_body_size) shutdown_trigger = None if shutdown_event is not None: diff --git a/src/hypercorn/trio/task_group.py b/src/hypercorn/trio/task_group.py index 35e9932..7fad871 100644 --- a/src/hypercorn/trio/task_group.py +++ b/src/hypercorn/trio/task_group.py @@ -1,31 +1,35 @@ from __future__ import annotations +import sys +from contextlib import AbstractAsyncContextManager from types import TracebackType from typing import Any, Awaitable, Callable, Optional import trio from ..config import Config -from ..typing import ASGIFramework, ASGIReceiveCallable, ASGIReceiveEvent, ASGISendEvent, Scope -from ..utils import invoke_asgi +from ..typing import AppWrapper, ASGIReceiveCallable, ASGIReceiveEvent, ASGISendEvent, Scope + +if sys.version_info < (3, 11): + from exceptiongroup import BaseExceptionGroup async def _handle( - app: ASGIFramework, + app: AppWrapper, config: Config, scope: Scope, receive: ASGIReceiveCallable, send: Callable[[Optional[ASGISendEvent]], Awaitable[None]], + sync_spawn: Callable, + call_soon: Callable, ) -> None: try: - await invoke_asgi(app, scope, receive, send) + await app(scope, receive, send, sync_spawn, call_soon) except trio.Cancelled: raise - except trio.MultiError as error: - errors = trio.MultiError.filter( - lambda exc: None if isinstance(exc, trio.Cancelled) else exc, root_exc=error - ) - if errors is not None: + except BaseExceptionGroup as error: + _, other_errors = error.split(trio.Cancelled) + if other_errors is not None: await config.log.exception("Error in ASGI Framework") await send(None) else: @@ -38,18 +42,29 @@ async def _handle( class TaskGroup: def __init__(self) -> None: - self._nursery: Optional[trio._core._run.Nursery] = None - self._nursery_manager: Optional[trio._core._run.NurseryManager] = None + self._nursery: trio.Nursery | None = None + self._nursery_manager: AbstractAsyncContextManager[trio.Nursery] | None = None async def spawn_app( self, - app: ASGIFramework, + app: AppWrapper, config: Config, scope: Scope, send: Callable[[Optional[ASGISendEvent]], Awaitable[None]], ) -> Callable[[ASGIReceiveEvent], Awaitable[None]]: - app_send_channel, app_receive_channel = trio.open_memory_channel(config.max_app_queue_size) - self._nursery.start_soon(_handle, app, config, scope, app_receive_channel.receive, send) + app_send_channel, app_receive_channel = trio.open_memory_channel[ASGIReceiveEvent]( + config.max_app_queue_size + ) + self._nursery.start_soon( + _handle, + app, + config, + scope, + app_receive_channel.receive, + send, + trio.to_thread.run_sync, + trio.from_thread.run, + ) return app_send_channel.send def spawn(self, func: Callable, *args: Any) -> None: diff --git a/src/hypercorn/trio/tcp_server.py b/src/hypercorn/trio/tcp_server.py index 069f3b7..5e2b633 100644 --- a/src/hypercorn/trio/tcp_server.py +++ b/src/hypercorn/trio/tcp_server.py @@ -1,48 +1,38 @@ from __future__ import annotations from math import inf -from typing import Any, Callable, Generator, Optional +from typing import Any, Generator import trio from .task_group import TaskGroup -from .worker_context import WorkerContext +from .worker_context import TrioSingleTask, WorkerContext from ..config import Config from ..events import Closed, Event, RawData, Updated from ..protocol import ProtocolWrapper -from ..typing import ASGIFramework +from ..typing import AppWrapper, ConnectionState, LifespanState from ..utils import parse_socket_addr -MAX_RECV = 2 ** 16 - - -class EventWrapper: - def __init__(self) -> None: - self._event = trio.Event() - - async def clear(self) -> None: - self._event = trio.Event() - - async def wait(self) -> None: - await self._event.wait() - - async def set(self) -> None: - self._event.set() +MAX_RECV = 2**16 class TCPServer: def __init__( - self, app: ASGIFramework, config: Config, context: WorkerContext, stream: trio.abc.Stream + self, + app: AppWrapper, + config: Config, + context: WorkerContext, + state: LifespanState, + stream: trio.SSLStream[trio.SocketStream], ) -> None: self.app = app self.config = config self.context = context self.protocol: ProtocolWrapper self.send_lock = trio.Lock() - self.timeout_lock = trio.Lock() + self.idle_task = TrioSingleTask() self.stream = stream - - self._keep_alive_timeout_handle: Optional[trio.CancelScope] = None + self.state = state def __await__(self) -> Generator[Any, None, None]: return self.run().__await__() @@ -73,6 +63,7 @@ async def run(self) -> None: self.config, self.context, task_group, + ConnectionState(self.state.copy()), ssl, client, server, @@ -80,9 +71,9 @@ async def run(self) -> None: alpn_protocol, ) await self.protocol.initiate() - await self._start_keep_alive_timeout() + await self.idle_task.restart(self._task_group, self._idle_timeout) await self._read_data() - except (trio.MultiError, OSError): + except OSError: pass finally: await self._close() @@ -101,22 +92,26 @@ async def protocol_send(self, event: Event) -> None: await self.protocol.handle(Closed()) elif isinstance(event, Updated): if event.idle: - await self._start_keep_alive_timeout() + await self.idle_task.restart(self._task_group, self._idle_timeout) else: - await self._stop_keep_alive_timeout() + await self.idle_task.stop() async def _read_data(self) -> None: while True: try: with trio.fail_after(self.config.read_timeout or inf): data = await self.stream.receive_some(MAX_RECV) - except (trio.ClosedResourceError, trio.BrokenResourceError): - await self.protocol.handle(Closed()) + except ( + trio.ClosedResourceError, + trio.BrokenResourceError, + trio.TooSlowError, + ): break else: await self.protocol.handle(RawData(data)) if data == b"": break + await self.protocol.handle(Closed()) async def _close(self) -> None: try: @@ -132,32 +127,13 @@ async def _close(self) -> None: pass await self.stream.aclose() - async def _start_keep_alive_timeout(self) -> None: - async with self.timeout_lock: - if self._keep_alive_timeout_handle is None: - self._keep_alive_timeout_handle = await self._task_group._nursery.start( - _call_later, self.config.keep_alive_timeout, self._timeout - ) + async def _idle_timeout(self) -> None: + with trio.move_on_after(self.config.keep_alive_timeout): + await self.context.terminated.wait() - async def _timeout(self) -> None: + with trio.CancelScope(shield=True): + await self._initiate_server_close() + + async def _initiate_server_close(self) -> None: await self.protocol.handle(Closed()) await self.stream.aclose() - - async def _stop_keep_alive_timeout(self) -> None: - async with self.timeout_lock: - if self._keep_alive_timeout_handle is not None: - self._keep_alive_timeout_handle.cancel() - self._keep_alive_timeout_handle = None - - -async def _call_later( - timeout: float, - callback: Callable, - task_status: trio._core._run._TaskStatus = trio.TASK_STATUS_IGNORED, -) -> None: - cancel_scope = trio.CancelScope() - task_status.started(cancel_scope) - with cancel_scope: - await trio.sleep(timeout) - cancel_scope.shield = True - await callback() diff --git a/src/hypercorn/trio/udp_server.py b/src/hypercorn/trio/udp_server.py index 667f12b..566c082 100644 --- a/src/hypercorn/trio/udp_server.py +++ b/src/hypercorn/trio/udp_server.py @@ -1,43 +1,51 @@ from __future__ import annotations +import socket + import trio from .task_group import TaskGroup from .worker_context import WorkerContext from ..config import Config from ..events import Event, RawData -from ..typing import ASGIFramework +from ..typing import AppWrapper, ConnectionState, LifespanState from ..utils import parse_socket_addr -MAX_RECV = 2 ** 16 +MAX_RECV = 2**16 class UDPServer: def __init__( self, - app: ASGIFramework, + app: AppWrapper, config: Config, context: WorkerContext, - socket: trio.socket.socket, + state: LifespanState, + socket: socket.socket, ) -> None: self.app = app self.config = config self.context = context self.socket = trio.socket.from_stdlib_socket(socket) + self.state = state - async def run( - self, task_status: trio._core._run._TaskStatus = trio.TASK_STATUS_IGNORED - ) -> None: + async def run(self, task_status: trio.TaskStatus = trio.TASK_STATUS_IGNORED) -> None: from ..protocol.quic import QuicProtocol # h3/Quic is an optional part of Hypercorn task_status.started() server = parse_socket_addr(self.socket.family, self.socket.getsockname()) async with TaskGroup() as task_group: self.protocol = QuicProtocol( - self.app, self.config, self.context, task_group, server, self.protocol_send + self.app, + self.config, + self.context, + task_group, + ConnectionState(self.state.copy()), + server, + self.protocol_send, ) - while not self.context.terminated or not self.protocol.idle: + while not self.context.terminated.is_set() or not self.protocol.idle: data, address = await self.socket.recvfrom(MAX_RECV) await self.protocol.handle(RawData(data=data, address=address)) diff --git a/src/hypercorn/trio/worker_context.py b/src/hypercorn/trio/worker_context.py index c6c91e2..1cac17e 100644 --- a/src/hypercorn/trio/worker_context.py +++ b/src/hypercorn/trio/worker_context.py @@ -1,10 +1,42 @@ from __future__ import annotations -from typing import Type, Union +from functools import wraps +from typing import Awaitable, Callable, Optional, Type, Union import trio -from ..typing import Event +from ..typing import Event, SingleTask, TaskGroup + + +def _cancel_wrapper(func: Callable[[], Awaitable[None]]) -> Callable[[], Awaitable[None]]: + @wraps(func) + async def wrapper( + task_status: trio.TaskStatus = trio.TASK_STATUS_IGNORED, + ) -> None: + cancel_scope = trio.CancelScope() + task_status.started(cancel_scope) + with cancel_scope: + await func() + + return wrapper + + +class TrioSingleTask: + def __init__(self) -> None: + self._handle: Optional[trio.CancelScope] = None + self._lock = trio.Lock() + + async def restart(self, task_group: TaskGroup, action: Callable) -> None: + async with self._lock: + if self._handle is not None: + self._handle.cancel() + self._handle = await task_group._nursery.start(_cancel_wrapper(action)) # type: ignore + + async def stop(self) -> None: + async with self._lock: + if self._handle is not None: + self._handle.cancel() + self._handle = None class EventWrapper: @@ -20,12 +52,27 @@ async def wait(self) -> None: async def set(self) -> None: self._event.set() + def is_set(self) -> bool: + return self._event.is_set() + class WorkerContext: event_class: Type[Event] = EventWrapper + single_task_class: Type[SingleTask] = TrioSingleTask - def __init__(self) -> None: - self.terminated = False + def __init__(self, max_requests: Optional[int]) -> None: + self.max_requests = max_requests + self.requests = 0 + self.terminate = self.event_class() + self.terminated = self.event_class() + + async def mark_request(self) -> None: + if self.max_requests is None: + return + + self.requests += 1 + if self.requests > self.max_requests: + await self.terminate.set() @staticmethod async def sleep(wait: Union[float, int]) -> None: diff --git a/src/hypercorn/typing.py b/src/hypercorn/typing.py index 64fa0e0..cba7d5d 100644 --- a/src/hypercorn/typing.py +++ b/src/hypercorn/typing.py @@ -2,23 +2,40 @@ from multiprocessing.synchronize import Event as EventType from types import TracebackType -from typing import Any, Awaitable, Callable, Dict, Iterable, Optional, Tuple, Type, Union +from typing import ( + Any, + Awaitable, + Callable, + Dict, + Iterable, + Literal, + NewType, + Optional, + Protocol, + Tuple, + Type, + TypedDict, + Union, +) import h2.events import h11 -# Till PEP 544 is accepted +from .config import Config, Sockets + try: - from typing import Literal, Protocol, TypedDict + from typing import NotRequired except ImportError: - from typing_extensions import Literal, Protocol, TypedDict # type: ignore - -from .config import Config, Sockets + from typing_extensions import NotRequired H11SendableEvent = Union[h11.Data, h11.EndOfMessage, h11.InformationalResponse, h11.Response] WorkerFunc = Callable[[Config, Optional[Sockets], Optional[EventType]], None] +LifespanState = Dict[str, Any] + +ConnectionState = NewType("ConnectionState", Dict[str, Any]) + class ASGIVersions(TypedDict, total=False): spec_version: str @@ -38,6 +55,7 @@ class HTTPScope(TypedDict): headers: Iterable[Tuple[bytes, bytes]] client: Optional[Tuple[str, int]] server: Optional[Tuple[str, Optional[int]]] + state: ConnectionState extensions: Dict[str, dict] @@ -54,12 +72,14 @@ class WebsocketScope(TypedDict): client: Optional[Tuple[str, int]] server: Optional[Tuple[str, Optional[int]]] subprotocols: Iterable[str] + state: ConnectionState extensions: Dict[str, dict] class LifespanScope(TypedDict): type: Literal["lifespan"] asgi: ASGIVersions + state: LifespanState WWWScope = Union[HTTPScope, WebsocketScope] @@ -76,6 +96,7 @@ class HTTPResponseStartEvent(TypedDict): type: Literal["http.response.start"] status: int headers: Iterable[Tuple[bytes, bytes]] + trailers: NotRequired[bool] class HTTPResponseBodyEvent(TypedDict): @@ -84,12 +105,23 @@ class HTTPResponseBodyEvent(TypedDict): more_body: bool +class HTTPResponseTrailersEvent(TypedDict): + type: Literal["http.response.trailers"] + headers: Iterable[Tuple[bytes, bytes]] + more_trailers: NotRequired[bool] + + class HTTPServerPushEvent(TypedDict): type: Literal["http.response.push"] path: str headers: Iterable[Tuple[bytes, bytes]] +class HTTPEarlyHintEvent(TypedDict): + type: Literal["http.response.early_hint"] + links: Iterable[bytes] + + class HTTPDisconnectEvent(TypedDict): type: Literal["http.disconnect"] @@ -179,7 +211,9 @@ class LifespanShutdownFailedEvent(TypedDict): ASGISendEvent = Union[ HTTPResponseStartEvent, HTTPResponseBodyEvent, + HTTPResponseTrailersEvent, HTTPServerPushEvent, + HTTPEarlyHintEvent, HTTPDisconnectEvent, WebsocketAcceptEvent, WebsocketSendEvent, @@ -196,19 +230,7 @@ class LifespanShutdownFailedEvent(TypedDict): ASGIReceiveCallable = Callable[[], Awaitable[ASGIReceiveEvent]] ASGISendCallable = Callable[[ASGISendEvent], Awaitable[None]] - -class ASGI2Protocol(Protocol): - # Should replace with a Protocol when PEP 544 is accepted. - - def __init__(self, scope: Scope) -> None: - ... - - async def __call__(self, receive: ASGIReceiveCallable, send: ASGISendCallable) -> None: - ... - - -ASGI2Framework = Type[ASGI2Protocol] -ASGI3Framework = Callable[ +ASGIFramework = Callable[ [ Scope, ASGIReceiveCallable, @@ -216,23 +238,24 @@ async def __call__(self, receive: ASGIReceiveCallable, send: ASGISendCallable) - ], Awaitable[None], ] -ASGIFramework = Union[ASGI2Framework, ASGI3Framework] +WSGIFramework = Callable[[dict, Callable], Iterable[bytes]] +Framework = Union[ASGIFramework, WSGIFramework] class H2SyncStream(Protocol): scope: dict def data_received(self, data: bytes) -> None: - ... + pass def ended(self) -> None: - ... + pass def reset(self) -> None: - ... + pass def close(self) -> None: - ... + pass async def handle_request( self, @@ -241,23 +264,23 @@ async def handle_request( client: Tuple[str, int], server: Tuple[str, int], ) -> None: - ... + pass class H2AsyncStream(Protocol): scope: dict async def data_received(self, data: bytes) -> None: - ... + pass async def ended(self) -> None: - ... + pass async def reset(self) -> None: - ... + pass async def close(self) -> None: - ... + pass async def handle_request( self, @@ -266,56 +289,87 @@ async def handle_request( client: Tuple[str, int], server: Tuple[str, int], ) -> None: - ... + pass class Event(Protocol): def __init__(self) -> None: - ... + pass async def clear(self) -> None: - ... + pass async def set(self) -> None: - ... + pass async def wait(self) -> None: - ... + pass + + def is_set(self) -> bool: + pass class WorkerContext(Protocol): event_class: Type[Event] - terminated: bool + single_task_class: Type[SingleTask] + terminate: Event + terminated: Event + + async def mark_request(self) -> None: + pass @staticmethod async def sleep(wait: Union[float, int]) -> None: - ... + pass @staticmethod def time() -> float: - ... + pass class TaskGroup(Protocol): async def spawn_app( self, - app: ASGIFramework, + app: AppWrapper, config: Config, scope: Scope, send: Callable[[Optional[ASGISendEvent]], Awaitable[None]], ) -> Callable[[ASGIReceiveEvent], Awaitable[None]]: - ... + pass def spawn(self, func: Callable, *args: Any) -> None: - ... + pass async def __aenter__(self) -> TaskGroup: - ... + pass async def __aexit__(self, exc_type: type, exc_value: BaseException, tb: TracebackType) -> None: - ... + pass class ResponseSummary(TypedDict): status: int headers: Iterable[Tuple[bytes, bytes]] + + +class AppWrapper(Protocol): + async def __call__( + self, + scope: Scope, + receive: ASGIReceiveCallable, + send: ASGISendCallable, + sync_spawn: Callable, + call_soon: Callable, + ) -> None: + pass + + +class SingleTask(Protocol): + def __init__(self) -> None: + pass + + async def restart(self, task_group: TaskGroup, action: Callable) -> None: + pass + + async def stop(self) -> None: + pass diff --git a/src/hypercorn/utils.py b/src/hypercorn/utils.py index 4fce714..39249c5 100644 --- a/src/hypercorn/utils.py +++ b/src/hypercorn/utils.py @@ -2,10 +2,8 @@ import inspect import os -import platform import socket import sys -from dataclasses import dataclass from enum import Enum from importlib import import_module from multiprocessing.synchronize import Event as EventType @@ -18,20 +16,15 @@ Dict, Iterable, List, + Literal, Optional, Tuple, TYPE_CHECKING, ) +from .app_wrappers import ASGIWrapper, WSGIWrapper from .config import Config -from .typing import ( - ASGI2Framework, - ASGI3Framework, - ASGIFramework, - ASGIReceiveCallable, - ASGISendCallable, - Scope, -) +from .typing import AppWrapper, ASGIFramework, Framework, WSGIFramework if TYPE_CHECKING: from .protocol.events import Request @@ -41,10 +34,6 @@ class ShutdownError(Exception): pass -class MustReloadError(Exception): - pass - - class NoAppError(Exception): pass @@ -72,7 +61,7 @@ class FrameTooLargeError(Exception): def suppress_body(method: str, status_code: int) -> bool: - return method == "HEAD" or 100 <= status_code < 200 or status_code in {204, 304, 412} + return method == "HEAD" or 100 <= status_code < 200 or status_code in {204, 304} def build_and_validate_headers(headers: Iterable[Tuple[bytes, bytes]]) -> List[Tuple[bytes, bytes]]: @@ -81,7 +70,7 @@ def build_and_validate_headers(headers: Iterable[Tuple[bytes, bytes]]) -> List[T for name, value in headers: if name[0] == b":"[0]: raise ValueError("Pseudo headers are not valid") - validated_headers.append((bytes(name).lower().strip(), bytes(value).strip())) + validated_headers.append((bytes(name).strip(), bytes(value).strip())) return validated_headers @@ -100,13 +89,16 @@ def filter_pseudo_headers(headers: List[Tuple[bytes, bytes]]) -> List[Tuple[byte return filtered_headers -def load_application(path: str) -> ASGIFramework: - try: - module_name, app_name = path.split(":", 1) - except ValueError: +def load_application(path: str, wsgi_max_body_size: int) -> AppWrapper: + mode: Optional[Literal["asgi", "wsgi"]] = None + if ":" not in path: module_name, app_name = path, "app" - except AttributeError: - raise NoAppError() + elif path.count(":") == 2: + mode, module_name, app_name = path.split(":", 2) # type: ignore + if mode not in {"asgi", "wsgi"}: + raise ValueError("Invalid mode, must be 'asgi', or 'wsgi'") + else: + module_name, app_name = path.split(":", 1) module_path = Path(module_name).resolve() sys.path.insert(0, str(module_path.parent)) @@ -118,17 +110,29 @@ def load_application(path: str) -> ASGIFramework: module = import_module(import_name) except ModuleNotFoundError as error: if error.name == import_name: - raise NoAppError() + raise NoAppError(f"Cannot load application from '{path}', module not found.") else: raise - try: - return eval(app_name, vars(module)) + app = eval(app_name, vars(module)) except NameError: - raise NoAppError() + raise NoAppError(f"Cannot load application from '{path}', application not found.") + else: + return wrap_app(app, wsgi_max_body_size, mode) -async def observe_changes(sleep: Callable[[float], Awaitable[Any]]) -> None: +def wrap_app( + app: Framework, wsgi_max_body_size: int, mode: Optional[Literal["asgi", "wsgi"]] +) -> AppWrapper: + if mode is None: + mode = "asgi" if is_asgi(app) else "wsgi" + if mode == "asgi": + return ASGIWrapper(cast(ASGIFramework, app)) + else: + return WSGIWrapper(cast(WSGIFramework, app), wsgi_max_body_size) + + +def files_to_watch() -> Dict[Path, float]: last_updates: Dict[Path, float] = {} for module in list(sys.modules.values()): filename = getattr(module, "__file__", None) @@ -139,62 +143,24 @@ async def observe_changes(sleep: Callable[[float], Awaitable[Any]]) -> None: last_updates[Path(filename)] = path.stat().st_mtime except (FileNotFoundError, NotADirectoryError): pass + return last_updates - while True: - await sleep(1) - - for index, (path, last_mtime) in enumerate(last_updates.items()): - if index % 10 == 0: - # Yield to the event loop - await sleep(0) - - try: - mtime = path.stat().st_mtime - except FileNotFoundError: - # File deleted - raise MustReloadError() - else: - if mtime > last_mtime: - raise MustReloadError() - else: - last_updates[path] = mtime - - -def restart() -> None: - # Restart this process (only safe for dev/debug) - executable = sys.executable - script_path = Path(sys.argv[0]).resolve() - args = sys.argv[1:] - main_package = sys.modules["__main__"].__package__ - - if main_package is None: - # Executed by filename - if platform.system() == "Windows": - if not script_path.exists() and script_path.with_suffix(".exe").exists(): - # quart run - executable = str(script_path.with_suffix(".exe")) - else: - # python run.py - args.append(str(script_path)) + +def check_for_updates(files: Dict[Path, float]) -> bool: + for path, last_mtime in files.items(): + try: + mtime = path.stat().st_mtime + except FileNotFoundError: + return True else: - if script_path.is_file() and os.access(script_path, os.X_OK): - # hypercorn run:app --reload - executable = str(script_path) + if mtime > last_mtime: + return True else: - # python run.py - args.append(str(script_path)) - else: - # Executed as a module e.g. python -m run - module = script_path.stem - import_name = main_package - if module != "__main__": - import_name = f"{main_package}.{module}" - args[:0] = ["-m", import_name.lstrip(".")] + files[path] = mtime + return False - os.execv(executable, [executable] + args) - -async def raise_shutdown(shutdown_event: Callable[..., Awaitable[None]]) -> None: +async def raise_shutdown(shutdown_event: Callable[..., Awaitable]) -> None: await shutdown_event() raise ShutdownError() @@ -215,7 +181,7 @@ def write_pid_file(pid_path: str) -> None: def parse_socket_addr(family: int, address: tuple) -> Optional[Tuple[str, int]]: if family == socket.AF_INET: - return address # type: ignore + return address elif family == socket.AF_INET6: return (address[0], address[1]) else: @@ -233,30 +199,6 @@ def repr_socket_addr(family: int, address: tuple) -> str: return f"{address}" -async def invoke_asgi( - app: ASGIFramework, scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable -) -> None: - if _is_asgi_2(app): - scope["asgi"]["version"] = "2.0" - app = cast(ASGI2Framework, app) - asgi_instance = app(scope) - await asgi_instance(receive, send) - else: - scope["asgi"]["version"] = "3.0" - app = cast(ASGI3Framework, app) - await app(scope, receive, send) - - -def _is_asgi_2(app: ASGIFramework) -> bool: - if inspect.isclass(app): - return True - - if hasattr(app, "__call__") and inspect.iscoroutinefunction(app.__call__): # type: ignore - return False - - return not inspect.iscoroutinefunction(app) - - def valid_server_name(config: Config, request: "Request") -> bool: if len(config.server_names) == 0: return True @@ -269,6 +211,9 @@ def valid_server_name(config: Config, request: "Request") -> bool: return host in config.server_names -@dataclass -class WorkerState: - terminated: bool = False +def is_asgi(app: Any) -> bool: + if inspect.iscoroutinefunction(app): + return True + elif hasattr(app, "__call__"): + return inspect.iscoroutinefunction(app.__call__) + return False diff --git a/tests/asyncio/test_keep_alive.py b/tests/asyncio/test_keep_alive.py index 274802f..9ed4cf6 100644 --- a/tests/asyncio/test_keep_alive.py +++ b/tests/asyncio/test_keep_alive.py @@ -1,29 +1,34 @@ from __future__ import annotations import asyncio -from typing import AsyncGenerator, Callable +from typing import AsyncGenerator import h11 import pytest +import pytest_asyncio +from hypercorn.app_wrappers import ASGIWrapper from hypercorn.asyncio.tcp_server import TCPServer from hypercorn.asyncio.worker_context import WorkerContext from hypercorn.config import Config +from hypercorn.typing import ASGIReceiveCallable, ASGISendCallable, Scope from .helpers import MemoryReader, MemoryWriter KEEP_ALIVE_TIMEOUT = 0.01 REQUEST = h11.Request(method="GET", target="/", headers=[(b"host", b"hypercorn")]) -async def slow_framework(scope: dict, receive: Callable, send: Callable) -> None: +async def slow_framework( + scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable +) -> None: while True: event = await receive() if event["type"] == "http.disconnect": break elif event["type"] == "lifespan.startup": - await send({"type": "lifspan.startup.complete"}) + await send({"type": "lifspan.startup.complete"}) # type: ignore elif event["type"] == "lifespan.shutdown": - await send({"type": "lifspan.shutdown.complete"}) + await send({"type": "lifspan.shutdown.complete"}) # type: ignore elif event["type"] == "http.request" and not event.get("more_body", False): await asyncio.sleep(2 * KEEP_ALIVE_TIMEOUT) await send( @@ -37,12 +42,20 @@ async def slow_framework(scope: dict, receive: Callable, send: Callable) -> None break -@pytest.fixture(name="server", scope="function") -async def _server(event_loop: asyncio.AbstractEventLoop) -> AsyncGenerator[TCPServer, None]: +@pytest_asyncio.fixture(name="server", scope="function") # type: ignore[misc] +async def _server() -> AsyncGenerator[TCPServer, None]: + event_loop: asyncio.AbstractEventLoop = asyncio.get_running_loop() + config = Config() config.keep_alive_timeout = KEEP_ALIVE_TIMEOUT server = TCPServer( - slow_framework, event_loop, config, WorkerContext(), MemoryReader(), MemoryWriter() # type: ignore # noqa: E501 + ASGIWrapper(slow_framework), + event_loop, + config, + WorkerContext(None), + {}, + MemoryReader(), # type: ignore + MemoryWriter(), # type: ignore ) task = event_loop.create_task(server.run()) yield server diff --git a/tests/asyncio/test_lifespan.py b/tests/asyncio/test_lifespan.py index 64217b4..e79d173 100644 --- a/tests/asyncio/test_lifespan.py +++ b/tests/asyncio/test_lifespan.py @@ -6,11 +6,17 @@ import pytest +from hypercorn.app_wrappers import ASGIWrapper from hypercorn.asyncio.lifespan import Lifespan from hypercorn.config import Config -from hypercorn.typing import Scope +from hypercorn.typing import ASGIReceiveCallable, ASGISendCallable, Scope from hypercorn.utils import LifespanFailureError, LifespanTimeoutError -from ..helpers import lifespan_failure, SlowLifespanFramework +from ..helpers import SlowLifespanFramework + +try: + from asyncio import TaskGroup +except ImportError: + from taskgroup import TaskGroup # type: ignore async def no_lifespan_app(scope: Scope, receive: Callable, send: Callable) -> None: @@ -19,20 +25,26 @@ async def no_lifespan_app(scope: Scope, receive: Callable, send: Callable) -> No @pytest.mark.asyncio -async def test_ensure_no_race_condition(event_loop: asyncio.AbstractEventLoop) -> None: +async def test_ensure_no_race_condition() -> None: + event_loop: asyncio.AbstractEventLoop = asyncio.get_running_loop() + config = Config() config.startup_timeout = 0.2 - lifespan = Lifespan(no_lifespan_app, config) + lifespan = Lifespan(ASGIWrapper(no_lifespan_app), config, event_loop, {}) task = event_loop.create_task(lifespan.handle_lifespan()) await lifespan.wait_for_startup() # Raises if there is a race condition await task @pytest.mark.asyncio -async def test_startup_timeout_error(event_loop: asyncio.AbstractEventLoop) -> None: +async def test_startup_timeout_error() -> None: + event_loop: asyncio.AbstractEventLoop = asyncio.get_running_loop() + config = Config() config.startup_timeout = 0.01 - lifespan = Lifespan(SlowLifespanFramework(0.02, asyncio.sleep), config) # type: ignore + lifespan = Lifespan( + ASGIWrapper(SlowLifespanFramework(0.02, asyncio.sleep)), config, event_loop, {} + ) task = event_loop.create_task(lifespan.handle_lifespan()) with pytest.raises(LifespanTimeoutError) as exc_info: await lifespan.wait_for_startup() @@ -40,15 +52,27 @@ async def test_startup_timeout_error(event_loop: asyncio.AbstractEventLoop) -> N await task +async def _lifespan_failure( + scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable +) -> None: + async with TaskGroup(): + while True: + message = await receive() + if message["type"] == "lifespan.startup": + await send({"type": "lifespan.startup.failed", "message": "Failure"}) + break + + @pytest.mark.asyncio -async def test_startup_failure(event_loop: asyncio.AbstractEventLoop) -> None: - lifespan = Lifespan(lifespan_failure, Config()) +async def test_startup_failure() -> None: + event_loop: asyncio.AbstractEventLoop = asyncio.get_running_loop() + + lifespan = Lifespan(ASGIWrapper(_lifespan_failure), Config(), event_loop, {}) lifespan_task = event_loop.create_task(lifespan.handle_lifespan()) await lifespan.wait_for_startup() assert lifespan_task.done() exception = lifespan_task.exception() - assert isinstance(exception, LifespanFailureError) - assert str(exception) == "Lifespan failure in startup. 'Failure'" + assert exception.subgroup(LifespanFailureError) is not None # type: ignore async def return_app(scope: Scope, receive: Callable, send: Callable) -> None: @@ -56,8 +80,10 @@ async def return_app(scope: Scope, receive: Callable, send: Callable) -> None: @pytest.mark.asyncio -async def test_lifespan_return(event_loop: asyncio.AbstractEventLoop) -> None: - lifespan = Lifespan(return_app, Config()) +async def test_lifespan_return() -> None: + event_loop: asyncio.AbstractEventLoop = asyncio.get_running_loop() + + lifespan = Lifespan(ASGIWrapper(return_app), Config(), event_loop, {}) lifespan_task = event_loop.create_task(lifespan.handle_lifespan()) await lifespan.wait_for_startup() await lifespan.wait_for_shutdown() diff --git a/tests/asyncio/test_sanity.py b/tests/asyncio/test_sanity.py index 8632177..c4c87f9 100644 --- a/tests/asyncio/test_sanity.py +++ b/tests/asyncio/test_sanity.py @@ -7,6 +7,7 @@ import pytest import wsproto +from hypercorn.app_wrappers import ASGIWrapper from hypercorn.asyncio.tcp_server import TCPServer from hypercorn.asyncio.worker_context import WorkerContext from hypercorn.config import Config @@ -15,9 +16,17 @@ @pytest.mark.asyncio -async def test_http1_request(event_loop: asyncio.AbstractEventLoop) -> None: +async def test_http1_request() -> None: + event_loop: asyncio.AbstractEventLoop = asyncio.get_running_loop() + server = TCPServer( - sanity_framework, event_loop, Config(), WorkerContext(), MemoryReader(), MemoryWriter() # type: ignore # noqa: E501 + ASGIWrapper(sanity_framework), + event_loop, + Config(), + WorkerContext(None), + {}, + MemoryReader(), # type: ignore + MemoryWriter(), # type: ignore ) task = event_loop.create_task(server.run()) client = h11.Connection(h11.CLIENT) @@ -67,9 +76,17 @@ async def test_http1_request(event_loop: asyncio.AbstractEventLoop) -> None: @pytest.mark.asyncio -async def test_http1_websocket(event_loop: asyncio.AbstractEventLoop) -> None: +async def test_http1_websocket() -> None: + event_loop: asyncio.AbstractEventLoop = asyncio.get_running_loop() + server = TCPServer( - sanity_framework, event_loop, Config(), WorkerContext(), MemoryReader(), MemoryWriter() # type: ignore # noqa: E501 + ASGIWrapper(sanity_framework), + event_loop, + Config(), + WorkerContext(None), + {}, + MemoryReader(), # type: ignore + MemoryWriter(), # type: ignore ) task = event_loop.create_task(server.run()) client = wsproto.WSConnection(wsproto.ConnectionType.CLIENT) @@ -99,12 +116,15 @@ async def test_http1_websocket(event_loop: asyncio.AbstractEventLoop) -> None: @pytest.mark.asyncio -async def test_http2_request(event_loop: asyncio.AbstractEventLoop) -> None: +async def test_http2_request() -> None: + event_loop: asyncio.AbstractEventLoop = asyncio.get_running_loop() + server = TCPServer( - sanity_framework, + ASGIWrapper(sanity_framework), event_loop, Config(), - WorkerContext(), + WorkerContext(None), + {}, MemoryReader(), # type: ignore MemoryWriter(http2=True), # type: ignore ) @@ -162,12 +182,15 @@ async def test_http2_request(event_loop: asyncio.AbstractEventLoop) -> None: @pytest.mark.asyncio -async def test_http2_websocket(event_loop: asyncio.AbstractEventLoop) -> None: +async def test_http2_websocket() -> None: + event_loop: asyncio.AbstractEventLoop = asyncio.get_running_loop() + server = TCPServer( - sanity_framework, + ASGIWrapper(sanity_framework), event_loop, Config(), - WorkerContext(), + WorkerContext(None), + {}, MemoryReader(), # type: ignore MemoryWriter(http2=True), # type: ignore ) diff --git a/tests/asyncio/test_task_group.py b/tests/asyncio/test_task_group.py index c11d75f..0e64efd 100644 --- a/tests/asyncio/test_task_group.py +++ b/tests/asyncio/test_task_group.py @@ -5,13 +5,16 @@ import pytest +from hypercorn.app_wrappers import ASGIWrapper from hypercorn.asyncio.task_group import TaskGroup from hypercorn.config import Config from hypercorn.typing import HTTPScope, Scope @pytest.mark.asyncio -async def test_spawn_app(event_loop: asyncio.AbstractEventLoop, http_scope: HTTPScope) -> None: +async def test_spawn_app(http_scope: HTTPScope) -> None: + event_loop: asyncio.AbstractEventLoop = asyncio.get_running_loop() + async def _echo_app(scope: Scope, receive: Callable, send: Callable) -> None: while True: message = await receive() @@ -21,34 +24,22 @@ async def _echo_app(scope: Scope, receive: Callable, send: Callable) -> None: app_queue: asyncio.Queue = asyncio.Queue() async with TaskGroup(event_loop) as task_group: - put = await task_group.spawn_app(_echo_app, Config(), http_scope, app_queue.put) - await put({"type": "http.disconnect"}) # type: ignore + put = await task_group.spawn_app( + ASGIWrapper(_echo_app), Config(), http_scope, app_queue.put + ) + await put({"type": "http.disconnect"}) assert (await app_queue.get()) == {"type": "http.disconnect"} await put(None) @pytest.mark.asyncio -async def test_spawn_app_error( - event_loop: asyncio.AbstractEventLoop, http_scope: HTTPScope -) -> None: +async def test_spawn_app_error(http_scope: HTTPScope) -> None: + event_loop: asyncio.AbstractEventLoop = asyncio.get_running_loop() + async def _error_app(scope: Scope, receive: Callable, send: Callable) -> None: raise Exception() app_queue: asyncio.Queue = asyncio.Queue() async with TaskGroup(event_loop) as task_group: - await task_group.spawn_app(_error_app, Config(), http_scope, app_queue.put) - assert (await app_queue.get()) is None - - -@pytest.mark.asyncio -async def test_spawn_app_cancelled( - event_loop: asyncio.AbstractEventLoop, http_scope: HTTPScope -) -> None: - async def _error_app(scope: Scope, receive: Callable, send: Callable) -> None: - raise asyncio.CancelledError() - - app_queue: asyncio.Queue = asyncio.Queue() - with pytest.raises(asyncio.CancelledError): - async with TaskGroup(event_loop) as task_group: - await task_group.spawn_app(_error_app, Config(), http_scope, app_queue.put) + await task_group.spawn_app(ASGIWrapper(_error_app), Config(), http_scope, app_queue.put) assert (await app_queue.get()) is None diff --git a/tests/asyncio/test_tcp_server.py b/tests/asyncio/test_tcp_server.py index 959daac..1aa2898 100644 --- a/tests/asyncio/test_tcp_server.py +++ b/tests/asyncio/test_tcp_server.py @@ -4,6 +4,7 @@ import pytest +from hypercorn.app_wrappers import ASGIWrapper from hypercorn.asyncio.tcp_server import TCPServer from hypercorn.asyncio.worker_context import WorkerContext from hypercorn.config import Config @@ -12,9 +13,17 @@ @pytest.mark.asyncio -async def test_completes_on_closed(event_loop: asyncio.AbstractEventLoop) -> None: +async def test_completes_on_closed() -> None: + event_loop: asyncio.AbstractEventLoop = asyncio.get_running_loop() + server = TCPServer( - echo_framework, event_loop, Config(), WorkerContext(), MemoryReader(), MemoryWriter() # type: ignore # noqa: E501 + ASGIWrapper(echo_framework), + event_loop, + Config(), + WorkerContext(None), + {}, + MemoryReader(), # type: ignore + MemoryWriter(), # type: ignore ) server.reader.close() # type: ignore await server.run() @@ -23,17 +32,24 @@ async def test_completes_on_closed(event_loop: asyncio.AbstractEventLoop) -> Non @pytest.mark.asyncio -async def test_complets_on_half_close(event_loop: asyncio.AbstractEventLoop) -> None: +async def test_complets_on_half_close() -> None: + event_loop: asyncio.AbstractEventLoop = asyncio.get_running_loop() + server = TCPServer( - echo_framework, event_loop, Config(), WorkerContext(), MemoryReader(), MemoryWriter() # type: ignore # noqa: E501 + ASGIWrapper(echo_framework), + event_loop, + Config(), + WorkerContext(None), + {}, + MemoryReader(), # type: ignore + MemoryWriter(), # type: ignore ) task = event_loop.create_task(server.run()) await server.reader.send(b"GET / HTTP/1.1\r\nHost: hypercorn\r\n\r\n") # type: ignore server.reader.close() # type: ignore - await asyncio.sleep(0) + await task data = await server.writer.receive() # type: ignore assert ( data - == b"HTTP/1.1 200 \r\ncontent-length: 335\r\ndate: Thu, 01 Jan 1970 01:23:20 GMT\r\nserver: hypercorn-h11\r\n\r\n" # noqa: E501 + == b"HTTP/1.1 200 \r\ncontent-length: 348\r\ndate: Thu, 01 Jan 1970 01:23:20 GMT\r\nserver: hypercorn-h11\r\n\r\n" # noqa: E501 ) - await task diff --git a/tests/conftest.py b/tests/conftest.py index f25c3f1..be84f59 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,7 +4,7 @@ from _pytest.monkeypatch import MonkeyPatch import hypercorn.config -from hypercorn.typing import HTTPScope +from hypercorn.typing import ConnectionState, HTTPScope @pytest.fixture(autouse=True) @@ -32,4 +32,5 @@ def _http_scope() -> HTTPScope: "client": ("127.0.0.1", 80), "server": None, "extensions": {}, + "state": ConnectionState({}), } diff --git a/tests/e2e/__init__.py b/tests/e2e/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/e2e/test_httpx.py b/tests/e2e/test_httpx.py new file mode 100644 index 0000000..9ce2cc7 --- /dev/null +++ b/tests/e2e/test_httpx.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +import httpx # type: ignore +import pytest +import trio + +import hypercorn.trio +from hypercorn.config import Config + + +async def app(scope, receive, send) -> None: # type: ignore + assert scope["type"] == "http" + + await send( + { + "type": "http.response.start", + "status": 200, + "headers": [ + [b"content-type", b"text/plain"], + ], + } + ) + await send( + { + "type": "http.response.body", + "body": b"Hello, world!", + } + ) + + +@pytest.mark.trio +async def test_keep_alive_max_requests_regression() -> None: + config = Config() + config.bind = ["0.0.0.0:1234"] + config.accesslog = "-" # Log to stdout/err + config.errorlog = "-" + config.keep_alive_max_requests = 2 + + async with trio.open_nursery() as nursery: + shutdown = trio.Event() + + async def serve() -> None: + await hypercorn.trio.serve(app, config, shutdown_trigger=shutdown.wait) + + nursery.start_soon(serve) + + await trio.testing.wait_all_tasks_blocked() + + client = httpx.AsyncClient() + + # Make sure that we properly clean up connections when `keep_alive_max_requests` + # is hit such that the client stays good over multiple hangups. + for _ in range(10): + result = await client.post("http://0.0.0.0:1234/test", json={"key": "value"}) + result.raise_for_status() + + shutdown.set() diff --git a/tests/helpers.py b/tests/helpers.py index 988464e..0e2d4d8 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -5,13 +5,12 @@ from socket import AF_INET from typing import Callable, cast, Tuple -from hypercorn.typing import Scope, WWWScope +from hypercorn.typing import ASGIReceiveCallable, ASGISendCallable, Scope, WWWScope SANITY_BODY = b"Hello Hypercorn" class MockSocket: - family = AF_INET def getsockname(self) -> Tuple[str, int]: @@ -26,15 +25,19 @@ async def empty_framework(scope: Scope, receive: Callable, send: Callable) -> No class SlowLifespanFramework: - def __init__(self, delay: int, sleep: Callable) -> None: + def __init__(self, delay: float, sleep: Callable) -> None: self.delay = delay self.sleep = sleep - async def __call__(self, scope: dict, receive: Callable, send: Callable) -> None: + async def __call__( + self, scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable + ) -> None: await self.sleep(self.delay) -async def echo_framework(input_scope: Scope, receive: Callable, send: Callable) -> None: +async def echo_framework( + input_scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable +) -> None: input_scope = cast(WWWScope, input_scope) scope = deepcopy(input_scope) scope["query_string"] = scope["query_string"].decode() # type: ignore @@ -63,32 +66,27 @@ async def echo_framework(input_scope: Scope, receive: Callable, send: Callable) await send({"type": "http.response.body", "body": response, "more_body": False}) break elif event["type"] == "websocket.connect": - await send({"type": "websocket.accept"}) + await send({"type": "websocket.accept"}) # type: ignore elif event["type"] == "websocket.receive": await send({"type": "websocket.send", "text": event["text"], "bytes": event["bytes"]}) -async def lifespan_failure(scope: Scope, receive: Callable, send: Callable) -> None: - while True: - message = await receive() - if message["type"] == "lifespan.startup": - await send({"type": "lifespan.startup.failed", "message": "Failure"}) - break - - -async def sanity_framework(scope: Scope, receive: Callable, send: Callable) -> None: +async def sanity_framework( + scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable +) -> None: body = b"" if scope["type"] == "websocket": - await send({"type": "websocket.accept"}) + await send({"type": "websocket.accept"}) # type: ignore while True: event = await receive() if event["type"] in {"http.disconnect", "websocket.disconnect"}: break elif event["type"] == "lifespan.startup": - await send({"type": "lifspan.startup.complete"}) + assert "state" in scope + await send({"type": "lifspan.startup.complete"}) # type: ignore elif event["type"] == "lifespan.shutdown": - await send({"type": "lifspan.shutdown.complete"}) + await send({"type": "lifspan.shutdown.complete"}) # type: ignore elif event["type"] == "http.request" and event.get("more_body", False): body += event["body"] elif event["type"] == "http.request" and not event.get("more_body", False): @@ -107,4 +105,4 @@ async def sanity_framework(scope: Scope, receive: Callable, send: Callable) -> N break elif event["type"] == "websocket.receive": assert event["bytes"] == SANITY_BODY - await send({"type": "websocket.send", "text": "Hello & Goodbye"}) + await send({"type": "websocket.send", "text": "Hello & Goodbye"}) # type: ignore diff --git a/tests/middleware/test_dispatcher.py b/tests/middleware/test_dispatcher.py index dbb3f43..2c6d8a1 100644 --- a/tests/middleware/test_dispatcher.py +++ b/tests/middleware/test_dispatcher.py @@ -40,10 +40,10 @@ async def send(message: dict) -> None: await app({**http_scope, **{"path": "/api/b"}}, None, send) # type: ignore await app({**http_scope, **{"path": "/"}}, None, send) # type: ignore assert sent_events == [ - {"type": "http.response.start", "status": 200, "headers": [(b"content-length", b"7")]}, - {"type": "http.response.body", "body": b"apix-/b"}, - {"type": "http.response.start", "status": 200, "headers": [(b"content-length", b"6")]}, - {"type": "http.response.body", "body": b"api-/b"}, + {"type": "http.response.start", "status": 200, "headers": [(b"content-length", b"13")]}, + {"type": "http.response.body", "body": b"apix-/api/x/b"}, + {"type": "http.response.start", "status": 200, "headers": [(b"content-length", b"10")]}, + {"type": "http.response.body", "body": b"api-/api/b"}, {"type": "http.response.start", "status": 404, "headers": [(b"content-length", b"0")]}, {"type": "http.response.body"}, ] @@ -72,7 +72,7 @@ async def send(message: dict) -> None: async def receive() -> dict: return {"type": "lifespan.shutdown"} - await app({"type": "lifespan", "asgi": {"version": "3.0"}}, receive, send) + await app({"type": "lifespan", "asgi": {"version": "3.0"}, "state": {}}, receive, send) assert sent_events == [{"type": "lifespan.startup.complete"}] @@ -89,5 +89,5 @@ async def send(message: dict) -> None: async def receive() -> dict: return {"type": "lifespan.shutdown"} - await app({"type": "lifespan", "asgi": {"version": "3.0"}}, receive, send) + await app({"type": "lifespan", "asgi": {"version": "3.0"}, "state": {}}, receive, send) assert sent_events == [{"type": "lifespan.startup.complete"}] diff --git a/tests/middleware/test_http_to_https.py b/tests/middleware/test_http_to_https.py index a4880c0..01583e2 100644 --- a/tests/middleware/test_http_to_https.py +++ b/tests/middleware/test_http_to_https.py @@ -3,7 +3,7 @@ import pytest from hypercorn.middleware import HTTPToHTTPSRedirectMiddleware -from hypercorn.typing import HTTPScope, WebsocketScope +from hypercorn.typing import ConnectionState, HTTPScope, WebsocketScope from ..helpers import empty_framework @@ -31,6 +31,7 @@ async def send(message: dict) -> None: "client": ("127.0.0.1", 80), "server": None, "extensions": {}, + "state": ConnectionState({}), } await app(scope, None, send) @@ -69,6 +70,7 @@ async def send(message: dict) -> None: "server": None, "subprotocols": [], "extensions": {"websocket.http.response": {}}, + "state": ConnectionState({}), } await app(scope, None, send) @@ -105,6 +107,7 @@ async def send(message: dict) -> None: "server": None, "subprotocols": [], "extensions": {"websocket.http.response": {}}, + "state": ConnectionState({}), } await app(scope, None, send) @@ -141,6 +144,7 @@ async def send(message: dict) -> None: "server": None, "subprotocols": [], "extensions": {}, + "state": ConnectionState({}), } await app(scope, None, send) @@ -165,6 +169,7 @@ def test_http_to_https_redirect_new_url_header() -> None: "client": None, "server": None, "extensions": {}, + "state": ConnectionState({}), }, ) assert new_url == "https://localhost/" diff --git a/tests/middleware/test_proxy_fix.py b/tests/middleware/test_proxy_fix.py new file mode 100644 index 0000000..5a9cf41 --- /dev/null +++ b/tests/middleware/test_proxy_fix.py @@ -0,0 +1,73 @@ +from __future__ import annotations + +from unittest.mock import AsyncMock + +import pytest + +from hypercorn.middleware import ProxyFixMiddleware +from hypercorn.typing import ConnectionState, HTTPScope + + +@pytest.mark.asyncio +async def test_proxy_fix_legacy() -> None: + mock = AsyncMock() + app = ProxyFixMiddleware(mock) + scope: HTTPScope = { + "type": "http", + "asgi": {}, + "http_version": "2", + "method": "GET", + "scheme": "http", + "path": "/", + "raw_path": b"/", + "query_string": b"", + "root_path": "", + "headers": [ + (b"x-forwarded-for", b"127.0.0.1"), + (b"x-forwarded-for", b"127.0.0.2"), + (b"x-forwarded-proto", b"http,https"), + (b"x-forwarded-host", b"example.com"), + ], + "client": ("127.0.0.3", 80), + "server": None, + "extensions": {}, + "state": ConnectionState({}), + } + await app(scope, None, None) + mock.assert_called() + scope = mock.call_args[0][0] + assert scope["client"] == ("127.0.0.2", 0) + assert scope["scheme"] == "https" + host_headers = [h for h in scope["headers"] if h[0].lower() == b"host"] + assert host_headers == [(b"host", b"example.com")] + + +@pytest.mark.asyncio +async def test_proxy_fix_modern() -> None: + mock = AsyncMock() + app = ProxyFixMiddleware(mock, mode="modern") + scope: HTTPScope = { + "type": "http", + "asgi": {}, + "http_version": "2", + "method": "GET", + "scheme": "http", + "path": "/", + "raw_path": b"/", + "query_string": b"", + "root_path": "", + "headers": [ + (b"forwarded", b"for=127.0.0.1;proto=http,for=127.0.0.2;proto=https;host=example.com"), + ], + "client": ("127.0.0.3", 80), + "server": None, + "extensions": {}, + "state": ConnectionState({}), + } + await app(scope, None, None) + mock.assert_called() + scope = mock.call_args[0][0] + assert scope["client"] == ("127.0.0.2", 0) + assert scope["scheme"] == "https" + host_headers = [h for h in scope["headers"] if h[0].lower() == b"host"] + assert host_headers == [(b"host", b"example.com")] diff --git a/tests/protocol/test_h11.py b/tests/protocol/test_h11.py index 43b16e7..aa3b0bd 100755 --- a/tests/protocol/test_h11.py +++ b/tests/protocol/test_h11.py @@ -6,16 +6,17 @@ import h11 import pytest +import pytest_asyncio from _pytest.monkeypatch import MonkeyPatch import hypercorn.protocol.h11 -from hypercorn.asyncio.tcp_server import EventWrapper +from hypercorn.asyncio.worker_context import EventWrapper from hypercorn.config import Config from hypercorn.events import Closed, RawData, Updated from hypercorn.protocol.events import Body, Data, EndBody, EndData, Request, Response, StreamClosed from hypercorn.protocol.h11 import H2CProtocolRequiredError, H2ProtocolAssumedError, H11Protocol from hypercorn.protocol.http_stream import HTTPStream -from hypercorn.typing import Event as IOEvent +from hypercorn.typing import ConnectionState, Event as IOEvent try: from unittest.mock import AsyncMock @@ -27,15 +28,28 @@ BASIC_HEADERS = [("Host", "hypercorn"), ("Connection", "close")] -@pytest.fixture(name="protocol") +@pytest_asyncio.fixture(name="protocol") # type: ignore[misc] async def _protocol(monkeypatch: MonkeyPatch) -> H11Protocol: MockHTTPStream = Mock() # noqa: N806 MockHTTPStream.return_value = AsyncMock(spec=HTTPStream) monkeypatch.setattr(hypercorn.protocol.h11, "HTTPStream", MockHTTPStream) context = Mock() - context.terminated = False context.event_class.return_value = AsyncMock(spec=IOEvent) - return H11Protocol(AsyncMock(), Config(), context, AsyncMock(), False, None, None, AsyncMock()) + context.mark_request = AsyncMock() + context.terminate = context.event_class() + context.terminated = context.event_class() + context.terminated.is_set.return_value = False + return H11Protocol( + AsyncMock(), + Config(), + context, + AsyncMock(), + ConnectionState({}), + False, + None, + None, + AsyncMock(), + ) @pytest.mark.asyncio @@ -54,6 +68,25 @@ async def test_protocol_send_response(protocol: H11Protocol) -> None: ] +@pytest.mark.asyncio +async def test_protocol_preserve_headers(protocol: H11Protocol) -> None: + await protocol.stream_send( + Response(stream_id=1, status_code=201, headers=[(b"X-Special", b"Value")]) + ) + protocol.send.assert_called() # type: ignore + assert protocol.send.call_args_list == [ # type: ignore + call( + RawData( + data=( + b"HTTP/1.1 201 \r\nX-Special: Value\r\n" + b"date: Thu, 01 Jan 1970 01:23:20 GMT\r\n" + b"server: hypercorn-h11\r\nConnection: close\r\n\r\n" + ) + ) + ) + ] + + @pytest.mark.asyncio async def test_protocol_send_data(protocol: H11Protocol) -> None: await protocol.stream_send(Data(stream_id=1, data=b"hello")) @@ -82,6 +115,18 @@ async def test_protocol_send_body(protocol: H11Protocol) -> None: ] +@pytest.mark.asyncio +async def test_protocol_keep_alive_max_requests(protocol: H11Protocol) -> None: + data = b"GET / HTTP/1.1\r\nHost: hypercorn\r\n\r\n" + protocol.config.keep_alive_max_requests = 0 + await protocol.handle(RawData(data=data)) + await protocol.stream_send(Response(stream_id=1, status_code=200, headers=[])) + await protocol.stream_send(EndBody(stream_id=1)) + await protocol.stream_send(StreamClosed(stream_id=1)) + protocol.send.assert_called() # type: ignore + assert protocol.send.call_args_list[3] == call(Closed()) # type: ignore + + @pytest.mark.asyncio @pytest.mark.parametrize("keep_alive, expected", [(True, Updated(idle=True)), (False, Closed())]) async def test_protocol_send_stream_closed( @@ -101,9 +146,9 @@ async def test_protocol_send_stream_closed( @pytest.mark.asyncio -async def test_protocol_instant_recycle( - protocol: H11Protocol, event_loop: asyncio.AbstractEventLoop -) -> None: +async def test_protocol_instant_recycle(protocol: H11Protocol) -> None: + event_loop: asyncio.AbstractEventLoop = asyncio.get_running_loop() + # This test task acts as the asgi app, spawned tasks act as the # server. data = b"GET / HTTP/1.1\r\nHost: hypercorn\r\n\r\n" @@ -148,6 +193,7 @@ async def test_protocol_handle_closed(protocol: H11Protocol) -> None: http_version="1.1", method="GET", raw_path=b"/", + state=ConnectionState({}), ) ), call(EndBody(stream_id=1)), @@ -170,6 +216,35 @@ async def test_protocol_handle_request(protocol: H11Protocol) -> None: http_version="1.1", method="GET", raw_path=b"/?a=b", + state=ConnectionState({}), + ) + ), + call(EndBody(stream_id=1)), + ] + + +@pytest.mark.asyncio +async def test_protocol_handle_request_with_raw_headers(protocol: H11Protocol) -> None: + protocol.config.h11_pass_raw_headers = True + client = h11.Connection(h11.CLIENT) + headers = BASIC_HEADERS + [("FOO_BAR", "foobar")] + await protocol.handle( + RawData(data=client.send(h11.Request(method="GET", target="/?a=b", headers=headers))) + ) + protocol.stream.handle.assert_called() # type: ignore + assert protocol.stream.handle.call_args_list == [ # type: ignore + call( + Request( + stream_id=1, + headers=[ + (b"Host", b"hypercorn"), + (b"Connection", b"close"), + (b"FOO_BAR", b"foobar"), + ], + http_version="1.1", + method="GET", + raw_path=b"/?a=b", + state=ConnectionState({}), ) ), call(EndBody(stream_id=1)), @@ -247,14 +322,22 @@ async def test_protocol_handle_max_incomplete(monkeypatch: MonkeyPatch) -> None: context = Mock() context.event_class.return_value = AsyncMock(spec=IOEvent) protocol = H11Protocol( - AsyncMock(), config, context, AsyncMock(), False, None, None, AsyncMock() + AsyncMock(), + config, + context, + AsyncMock(), + ConnectionState({}), + False, + None, + None, + AsyncMock(), ) await protocol.handle(RawData(data=b"GET / HTTP/1.1\r\nHost: hypercorn\r\n")) protocol.send.assert_called() # type: ignore assert protocol.send.call_args_list == [ # type: ignore call( RawData( - data=b"HTTP/1.1 400 \r\ncontent-length: 0\r\nconnection: close\r\n" + data=b"HTTP/1.1 431 \r\ncontent-length: 0\r\nconnection: close\r\n" b"date: Thu, 01 Jan 1970 01:23:20 GMT\r\nserver: hypercorn-h11\r\n\r\n" ) ), @@ -275,6 +358,7 @@ async def test_protocol_handle_h2c_upgrade(protocol: H11Protocol) -> None: ) ) assert protocol.send.call_args_list == [ # type: ignore + call(Updated(idle=False)), call( RawData( b"HTTP/1.1 101 \r\n" @@ -284,7 +368,7 @@ async def test_protocol_handle_h2c_upgrade(protocol: H11Protocol) -> None: b"upgrade: h2c\r\n" b"\r\n" ) - ) + ), ] assert exc_info.value.data == b"bbb" assert exc_info.value.headers == [ diff --git a/tests/protocol/test_h2.py b/tests/protocol/test_h2.py index 91d5ea2..a13c494 100644 --- a/tests/protocol/test_h2.py +++ b/tests/protocol/test_h2.py @@ -4,12 +4,14 @@ from unittest.mock import call, Mock import pytest +from h2.connection import H2Connection +from h2.events import ConnectionTerminated -from hypercorn.asyncio.tcp_server import EventWrapper -from hypercorn.asyncio.worker_context import WorkerContext +from hypercorn.asyncio.worker_context import EventWrapper, WorkerContext from hypercorn.config import Config from hypercorn.events import Closed, RawData from hypercorn.protocol.h2 import BUFFER_HIGH_WATER, BufferCompleteError, H2Protocol, StreamBuffer +from hypercorn.typing import ConnectionState try: from unittest.mock import AsyncMock @@ -19,7 +21,9 @@ @pytest.mark.asyncio -async def test_stream_buffer_push_and_pop(event_loop: asyncio.AbstractEventLoop) -> None: +async def test_stream_buffer_push_and_pop() -> None: + event_loop: asyncio.AbstractEventLoop = asyncio.get_running_loop() + stream_buffer = StreamBuffer(EventWrapper) async def _push_over_limit() -> bool: @@ -35,7 +39,9 @@ async def _push_over_limit() -> bool: @pytest.mark.asyncio -async def test_stream_buffer_drain(event_loop: asyncio.AbstractEventLoop) -> None: +async def test_stream_buffer_drain() -> None: + event_loop: asyncio.AbstractEventLoop = asyncio.get_running_loop() + stream_buffer = StreamBuffer(EventWrapper) await stream_buffer.push(b"a" * 10) @@ -50,7 +56,7 @@ async def _drain() -> bool: @pytest.mark.asyncio -async def test_stream_buffer_closed(event_loop: asyncio.AbstractEventLoop) -> None: +async def test_stream_buffer_closed() -> None: stream_buffer = StreamBuffer(EventWrapper) await stream_buffer.close() await stream_buffer._is_empty.wait() @@ -61,7 +67,7 @@ async def test_stream_buffer_closed(event_loop: asyncio.AbstractEventLoop) -> No @pytest.mark.asyncio -async def test_stream_buffer_complete(event_loop: asyncio.AbstractEventLoop) -> None: +async def test_stream_buffer_complete() -> None: stream_buffer = StreamBuffer(EventWrapper) await stream_buffer.push(b"a" * 10) assert not stream_buffer.complete @@ -74,8 +80,45 @@ async def test_stream_buffer_complete(event_loop: asyncio.AbstractEventLoop) -> @pytest.mark.asyncio async def test_protocol_handle_protocol_error() -> None: protocol = H2Protocol( - Mock(), Config(), WorkerContext(), AsyncMock(), False, None, None, AsyncMock() + Mock(), + Config(), + WorkerContext(None), + AsyncMock(), + ConnectionState({}), + False, + None, + None, + AsyncMock(), ) await protocol.handle(RawData(data=b"broken nonsense\r\n\r\n")) protocol.send.assert_awaited() # type: ignore assert protocol.send.call_args_list == [call(Closed())] # type: ignore + + +@pytest.mark.asyncio +async def test_protocol_keep_alive_max_requests() -> None: + protocol = H2Protocol( + Mock(), + Config(), + WorkerContext(None), + AsyncMock(), + ConnectionState({}), + False, + None, + None, + AsyncMock(), + ) + protocol.config.keep_alive_max_requests = 0 + client = H2Connection() + client.initiate_connection() + headers = [ + (":method", "GET"), + (":path", "/reqinfo"), + (":authority", "hypercorn"), + (":scheme", "https"), + ] + client.send_headers(1, headers, end_stream=True) + await protocol.handle(RawData(data=client.data_to_send())) + protocol.send.assert_awaited() # type: ignore + events = client.receive_data(protocol.send.call_args_list[1].args[0].data) # type: ignore + assert isinstance(events[-1], ConnectionTerminated) diff --git a/tests/protocol/test_http_stream.py b/tests/protocol/test_http_stream.py index f6550d3..b25cb2f 100644 --- a/tests/protocol/test_http_stream.py +++ b/tests/protocol/test_http_stream.py @@ -4,13 +4,28 @@ from unittest.mock import call import pytest +import pytest_asyncio +from hypercorn.asyncio.statsd import StatsdLogger from hypercorn.asyncio.worker_context import WorkerContext from hypercorn.config import Config from hypercorn.logging import Logger -from hypercorn.protocol.events import Body, EndBody, Request, Response, StreamClosed +from hypercorn.protocol.events import ( + Body, + EndBody, + InformationalResponse, + Request, + Response, + StreamClosed, + Trailers, +) from hypercorn.protocol.http_stream import ASGIHTTPState, HTTPStream -from hypercorn.typing import HTTPResponseBodyEvent, HTTPResponseStartEvent, HTTPScope +from hypercorn.typing import ( + ConnectionState, + HTTPResponseBodyEvent, + HTTPResponseStartEvent, + HTTPScope, +) from hypercorn.utils import UnexpectedMessageError try: @@ -20,10 +35,10 @@ from mock import AsyncMock # type: ignore -@pytest.fixture(name="stream") +@pytest_asyncio.fixture(name="stream") # type: ignore[misc] async def _stream() -> HTTPStream: stream = HTTPStream( - AsyncMock(), Config(), WorkerContext(), AsyncMock(), False, None, None, AsyncMock(), 1 + AsyncMock(), Config(), WorkerContext(None), AsyncMock(), False, None, None, AsyncMock(), 1 ) stream.app_put = AsyncMock() stream.config._log = AsyncMock(spec=Logger) @@ -34,14 +49,21 @@ async def _stream() -> HTTPStream: @pytest.mark.asyncio async def test_handle_request_http_1(stream: HTTPStream, http_version: str) -> None: await stream.handle( - Request(stream_id=1, http_version=http_version, headers=[], raw_path=b"/?a=b", method="GET") + Request( + stream_id=1, + http_version=http_version, + headers=[], + raw_path=b"/?a=b", + method="GET", + state=ConnectionState({}), + ) ) stream.task_group.spawn_app.assert_called() # type: ignore scope = stream.task_group.spawn_app.call_args[0][2] # type: ignore assert scope == { "type": "http", "http_version": http_version, - "asgi": {"spec_version": "2.1"}, + "asgi": {"spec_version": "2.1", "version": "3.0"}, "method": "GET", "scheme": "http", "path": "/", @@ -52,20 +74,28 @@ async def test_handle_request_http_1(stream: HTTPStream, http_version: str) -> N "client": None, "server": None, "extensions": {}, + "state": ConnectionState({}), } @pytest.mark.asyncio async def test_handle_request_http_2(stream: HTTPStream) -> None: await stream.handle( - Request(stream_id=1, http_version="2", headers=[], raw_path=b"/?a=b", method="GET") + Request( + stream_id=1, + http_version="2", + headers=[], + raw_path=b"/?a=b", + method="GET", + state=ConnectionState({}), + ) ) stream.task_group.spawn_app.assert_called() # type: ignore scope = stream.task_group.spawn_app.call_args[0][2] # type: ignore assert scope == { "type": "http", "http_version": "2", - "asgi": {"spec_version": "2.1"}, + "asgi": {"spec_version": "2.1", "version": "3.0"}, "method": "GET", "scheme": "http", "path": "/", @@ -75,7 +105,12 @@ async def test_handle_request_http_2(stream: HTTPStream) -> None: "headers": [], "client": None, "server": None, - "extensions": {"http.response.push": {}}, + "extensions": { + "http.response.trailers": {}, + "http.response.early_hint": {}, + "http.response.push": {}, + }, + "state": ConnectionState({}), } @@ -100,6 +135,16 @@ async def test_handle_end_body(stream: HTTPStream) -> None: @pytest.mark.asyncio async def test_handle_closed(stream: HTTPStream) -> None: + await stream.handle( + Request( + stream_id=1, + http_version="2", + headers=[], + raw_path=b"/?a=b", + method="GET", + state=ConnectionState({}), + ) + ) await stream.handle(StreamClosed(stream_id=1)) stream.app_put.assert_called() # type: ignore assert stream.app_put.call_args_list == [call({"type": "http.disconnect"})] # type: ignore @@ -108,26 +153,31 @@ async def test_handle_closed(stream: HTTPStream) -> None: @pytest.mark.asyncio async def test_send_response(stream: HTTPStream) -> None: await stream.handle( - Request(stream_id=1, http_version="2", headers=[], raw_path=b"/?a=b", method="GET") + Request( + stream_id=1, + http_version="2", + headers=[], + raw_path=b"/?a=b", + method="GET", + state=ConnectionState({}), + ) ) await stream.app_send( cast(HTTPResponseStartEvent, {"type": "http.response.start", "status": 200, "headers": []}) ) - assert stream.state == ASGIHTTPState.REQUEST - # Must wait for response before sending anything - stream.send.assert_not_called() # type: ignore + assert stream.state == ASGIHTTPState.RESPONSE await stream.app_send( cast(HTTPResponseBodyEvent, {"type": "http.response.body", "body": b"Body"}) ) - assert stream.state == ASGIHTTPState.CLOSED - stream.send.assert_called() # type: ignore - assert stream.send.call_args_list == [ # type: ignore + assert stream.state == ASGIHTTPState.CLOSED # type: ignore + stream.send.assert_called() + assert stream.send.call_args_list == [ call(Response(stream_id=1, headers=[], status_code=200)), call(Body(stream_id=1, data=b"Body")), call(EndBody(stream_id=1)), call(StreamClosed(stream_id=1)), ] - stream.config._log.access.assert_called() # type: ignore + stream.config._log.access.assert_called() @pytest.mark.asyncio @@ -140,6 +190,7 @@ async def test_invalid_server_name(stream: HTTPStream) -> None: headers=[(b"host", b"example.com")], raw_path=b"/", method="GET", + state=ConnectionState({}), ) ) assert stream.send.call_args_list == [ # type: ignore @@ -169,15 +220,102 @@ async def test_send_push(stream: HTTPStream, http_scope: HTTPScope) -> None: http_version="2", method="GET", raw_path=b"/push", + state=ConnectionState({}), + ) + ) + ] + + +@pytest.mark.asyncio +async def test_send_early_hint(stream: HTTPStream, http_scope: HTTPScope) -> None: + stream.scope = http_scope + stream.stream_id = 1 + await stream.app_send( + {"type": "http.response.early_hint", "links": [b'; rel="preload"; as="style"']} + ) + assert stream.send.call_args_list == [ # type: ignore + call( + InformationalResponse( + stream_id=1, + headers=[(b"link", b'; rel="preload"; as="style"')], + status_code=103, ) ) ] +@pytest.mark.asyncio +async def test_send_trailers(stream: HTTPStream) -> None: + await stream.handle( + Request( + stream_id=1, + http_version="2", + headers=[(b"te", b"trailers")], + raw_path=b"/?a=b", + method="GET", + state=ConnectionState({}), + ) + ) + await stream.app_send( + cast( + HTTPResponseStartEvent, + {"type": "http.response.start", "status": 200, "trailers": True}, + ) + ) + await stream.app_send( + cast(HTTPResponseBodyEvent, {"type": "http.response.body", "body": b"Body"}) + ) + await stream.app_send({"type": "http.response.trailers", "headers": [(b"X", b"V")]}) + assert stream.send.call_args_list == [ # type: ignore + call(Response(stream_id=1, headers=[], status_code=200)), + call(Body(stream_id=1, data=b"Body")), + call(Trailers(stream_id=1, headers=[(b"X", b"V")])), + call(EndBody(stream_id=1)), + call(StreamClosed(stream_id=1)), + ] + + +@pytest.mark.asyncio +async def test_send_trailers_ignored(stream: HTTPStream) -> None: + await stream.handle( + Request( + stream_id=1, + http_version="2", + headers=[], # no TE: trailers header + raw_path=b"/?a=b", + method="GET", + state=ConnectionState({}), + ) + ) + await stream.app_send( + cast( + HTTPResponseStartEvent, + {"type": "http.response.start", "status": 200, "trailers": True}, + ) + ) + await stream.app_send( + cast(HTTPResponseBodyEvent, {"type": "http.response.body", "body": b"Body"}) + ) + await stream.app_send({"type": "http.response.trailers", "headers": [(b"X", b"V")]}) + assert stream.send.call_args_list == [ # type: ignore + call(Response(stream_id=1, headers=[], status_code=200)), + call(Body(stream_id=1, data=b"Body")), + call(EndBody(stream_id=1)), + call(StreamClosed(stream_id=1)), + ] + + @pytest.mark.asyncio async def test_send_app_error(stream: HTTPStream) -> None: await stream.handle( - Request(stream_id=1, http_version="2", headers=[], raw_path=b"/?a=b", method="GET") + Request( + stream_id=1, + http_version="2", + headers=[], + raw_path=b"/?a=b", + method="GET", + state=ConnectionState({}), + ) ) await stream.app_send(None) stream.send.assert_called() # type: ignore @@ -200,15 +338,18 @@ async def test_send_app_error(stream: HTTPStream) -> None: [ (ASGIHTTPState.REQUEST, "not_a_real_type"), (ASGIHTTPState.RESPONSE, "http.response.start"), + (ASGIHTTPState.TRAILERS, "http.response.start"), (ASGIHTTPState.CLOSED, "http.response.start"), (ASGIHTTPState.CLOSED, "http.response.body"), + (ASGIHTTPState.CLOSED, "http.response.trailers"), ], ) @pytest.mark.asyncio async def test_send_invalid_message_given_state( - stream: HTTPStream, state: ASGIHTTPState, message_type: str + stream: HTTPStream, state: ASGIHTTPState, http_scope: HTTPScope, message_type: str ) -> None: stream.state = state + stream.scope = http_scope with pytest.raises(UnexpectedMessageError): await stream.app_send({"type": message_type}) # type: ignore @@ -249,6 +390,16 @@ def test_stream_idle(stream: HTTPStream) -> None: @pytest.mark.asyncio async def test_closure(stream: HTTPStream) -> None: + await stream.handle( + Request( + stream_id=1, + http_version="2", + headers=[], + raw_path=b"/?a=b", + method="GET", + state=ConnectionState({}), + ) + ) assert not stream.closed await stream.handle(StreamClosed(stream_id=1)) assert stream.closed @@ -260,9 +411,25 @@ async def test_closure(stream: HTTPStream) -> None: @pytest.mark.asyncio -async def test_closed_app_send_noop(stream: HTTPStream) -> None: - stream.closed = True - await stream.app_send( - cast(HTTPResponseStartEvent, {"type": "http.response.start", "status": 200, "headers": []}) +async def test_abnormal_close_logging() -> None: + config = Config() + config.accesslog = "-" + config.statsd_host = "localhost:9125" + # This exercises an issue where `HTTPStream` at one point called the statsd logger + # with `response=None` when the statsd logger failed to handle it. + config.set_statsd_logger_class(StatsdLogger) + stream = HTTPStream( + AsyncMock(), config, WorkerContext(None), AsyncMock(), False, None, None, AsyncMock(), 1 ) - stream.send.assert_not_called() # type: ignore + + await stream.handle( + Request( + stream_id=1, + http_version="2", + headers=[], + raw_path=b"/?a=b", + method="GET", + state=ConnectionState({}), + ) + ) + await stream.handle(StreamClosed(stream_id=1)) diff --git a/tests/protocol/test_ws_stream.py b/tests/protocol/test_ws_stream.py index cd66622..7b5ee98 100644 --- a/tests/protocol/test_ws_stream.py +++ b/tests/protocol/test_ws_stream.py @@ -5,6 +5,7 @@ from unittest.mock import call, Mock import pytest +import pytest_asyncio from wsproto.events import BytesMessage, TextMessage from hypercorn.asyncio.task_group import TaskGroup @@ -20,6 +21,7 @@ WSStream, ) from hypercorn.typing import ( + ConnectionState, WebsocketAcceptEvent, WebsocketCloseEvent, WebsocketResponseBodyEvent, @@ -161,10 +163,10 @@ def test_handshake_accept_additional_headers() -> None: ] -@pytest.fixture(name="stream") +@pytest_asyncio.fixture(name="stream") # type: ignore[misc] async def _stream() -> WSStream: stream = WSStream( - AsyncMock(), Config(), WorkerContext(), AsyncMock(), False, None, None, AsyncMock(), 1 + AsyncMock(), Config(), WorkerContext(None), AsyncMock(), False, None, None, AsyncMock(), 1 ) stream.task_group.spawn_app.return_value = AsyncMock() # type: ignore stream.app_put = AsyncMock() @@ -181,13 +183,14 @@ async def test_handle_request(stream: WSStream) -> None: headers=[(b"sec-websocket-version", b"13")], raw_path=b"/?a=b", method="GET", + state=ConnectionState({}), ) ) stream.task_group.spawn_app.assert_called() # type: ignore scope = stream.task_group.spawn_app.call_args[0][2] # type: ignore assert scope == { "type": "websocket", - "asgi": {"spec_version": "2.3"}, + "asgi": {"spec_version": "2.3", "version": "3.0"}, "scheme": "ws", "http_version": "2", "path": "/", @@ -199,9 +202,40 @@ async def test_handle_request(stream: WSStream) -> None: "server": None, "subprotocols": [], "extensions": {"websocket.http.response": {}}, + "state": ConnectionState({}), } +@pytest.mark.asyncio +async def test_handle_data_before_acceptance(stream: WSStream) -> None: + await stream.handle( + Request( + stream_id=1, + http_version="2", + headers=[(b"sec-websocket-version", b"13")], + raw_path=b"/?a=b", + method="GET", + state=ConnectionState({}), + ) + ) + await stream.handle( + Data( + stream_id=1, + data=b"X", + ) + ) + assert stream.send.call_args_list == [ # type: ignore + call( + Response( + stream_id=1, + headers=[(b"content-length", b"0"), (b"connection", b"close")], + status_code=400, + ) + ), + call(EndBody(stream_id=1)), + ] + + @pytest.mark.asyncio async def test_handle_connection(stream: WSStream) -> None: await stream.handle( @@ -211,6 +245,7 @@ async def test_handle_connection(stream: WSStream) -> None: headers=[(b"sec-websocket-version", b"13")], raw_path=b"/?a=b", method="GET", + state=ConnectionState({}), ) ) await stream.app_send(cast(WebsocketAcceptEvent, {"type": "websocket.accept"})) @@ -240,6 +275,7 @@ async def test_send_accept(stream: WSStream) -> None: headers=[(b"sec-websocket-version", b"13")], raw_path=b"/", method="GET", + state=ConnectionState({}), ) ) await stream.app_send(cast(WebsocketAcceptEvent, {"type": "websocket.accept"})) @@ -259,6 +295,7 @@ async def test_send_accept_with_additional_headers(stream: WSStream) -> None: headers=[(b"sec-websocket-version", b"13")], raw_path=b"/", method="GET", + state=ConnectionState({}), ) ) await stream.app_send( @@ -283,6 +320,7 @@ async def test_send_reject(stream: WSStream) -> None: headers=[(b"sec-websocket-version", b"13")], raw_path=b"/", method="GET", + state=ConnectionState({}), ) ) await stream.app_send( @@ -297,14 +335,14 @@ async def test_send_reject(stream: WSStream) -> None: await stream.app_send( cast(WebsocketResponseBodyEvent, {"type": "websocket.http.response.body", "body": b"Body"}) ) - assert stream.state == ASGIWebsocketState.HTTPCLOSED - stream.send.assert_called() # type: ignore - assert stream.send.call_args_list == [ # type: ignore + assert stream.state == ASGIWebsocketState.HTTPCLOSED # type: ignore + stream.send.assert_called() + assert stream.send.call_args_list == [ call(Response(stream_id=1, headers=[], status_code=200)), call(Body(stream_id=1, data=b"Body")), call(EndBody(stream_id=1)), ] - stream.config._log.access.assert_called() # type: ignore + stream.config._log.access.assert_called() @pytest.mark.asyncio @@ -317,6 +355,7 @@ async def test_invalid_server_name(stream: WSStream) -> None: headers=[(b"host", b"example.com"), (b"sec-websocket-version", b"13")], raw_path=b"/", method="GET", + state=ConnectionState({}), ) ) assert stream.send.call_args_list == [ # type: ignore @@ -342,6 +381,7 @@ async def test_send_app_error_handshake(stream: WSStream) -> None: headers=[(b"sec-websocket-version", b"13")], raw_path=b"/", method="GET", + state=ConnectionState({}), ) ) await stream.app_send(None) @@ -369,6 +409,7 @@ async def test_send_app_error_connected(stream: WSStream) -> None: headers=[(b"sec-websocket-version", b"13")], raw_path=b"/", method="GET", + state=ConnectionState({}), ) ) await stream.app_send(cast(WebsocketAcceptEvent, {"type": "websocket.accept"})) @@ -376,7 +417,7 @@ async def test_send_app_error_connected(stream: WSStream) -> None: stream.send.assert_called() # type: ignore assert stream.send.call_args_list == [ # type: ignore call(Response(stream_id=1, headers=[], status_code=200)), - call(Data(stream_id=1, data=b"\x88\x02\x03\xe8")), + call(Data(stream_id=1, data=b"\x88\x02\x03\xf3")), call(StreamClosed(stream_id=1)), ] stream.config._log.access.assert_called() # type: ignore @@ -391,6 +432,7 @@ async def test_send_connection(stream: WSStream) -> None: headers=[(b"sec-websocket-version", b"13")], raw_path=b"/", method="GET", + state=ConnectionState({}), ) ) await stream.app_send(cast(WebsocketAcceptEvent, {"type": "websocket.accept"})) @@ -406,7 +448,9 @@ async def test_send_connection(stream: WSStream) -> None: @pytest.mark.asyncio -async def test_pings(stream: WSStream, event_loop: asyncio.AbstractEventLoop) -> None: +async def test_pings(stream: WSStream) -> None: + event_loop: asyncio.AbstractEventLoop = asyncio.get_running_loop() + stream.config.websocket_ping_interval = 0.1 await stream.handle( Request( @@ -415,6 +459,7 @@ async def test_pings(stream: WSStream, event_loop: asyncio.AbstractEventLoop) -> headers=[(b"sec-websocket-version", b"13")], raw_path=b"/?a=b", method="GET", + state=ConnectionState({}), ) ) async with TaskGroup(event_loop) as task_group: diff --git a/tests/middleware/test_wsgi.py b/tests/test_app_wrappers.py similarity index 58% rename from tests/middleware/test_wsgi.py rename to tests/test_app_wrappers.py index c0b9747..0640350 100644 --- a/tests/middleware/test_wsgi.py +++ b/tests/test_app_wrappers.py @@ -1,14 +1,14 @@ from __future__ import annotations import asyncio -from typing import Callable, List +from functools import partial +from typing import Any, Callable, List import pytest import trio -from hypercorn.middleware import AsyncioWSGIMiddleware, TrioWSGIMiddleware -from hypercorn.middleware.wsgi import _build_environ, InvalidPathError -from hypercorn.typing import HTTPScope +from hypercorn.app_wrappers import _build_environ, InvalidPathError, WSGIWrapper +from hypercorn.typing import ASGIReceiveEvent, ASGISendEvent, ConnectionState, HTTPScope def echo_body(environ: dict, start_response: Callable) -> List[bytes]: @@ -24,7 +24,7 @@ def echo_body(environ: dict, start_response: Callable) -> List[bytes]: @pytest.mark.trio async def test_wsgi_trio() -> None: - middleware = TrioWSGIMiddleware(echo_body) + app = WSGIWrapper(echo_body, 2**16) scope: HTTPScope = { "http_version": "1.1", "asgi": {}, @@ -39,30 +39,52 @@ async def test_wsgi_trio() -> None: "client": ("localhost", 80), "server": None, "extensions": {}, + "state": ConnectionState({}), } - send_channel, receive_channel = trio.open_memory_channel(1) - await send_channel.send({"type": "http.request"}) + send_channel, receive_channel = trio.open_memory_channel[ASGIReceiveEvent](1) + await send_channel.send({"type": "http.request"}) # type: ignore messages = [] - async def _send(message: dict) -> None: + async def _send(message: ASGISendEvent) -> None: nonlocal messages messages.append(message) - await middleware(scope, receive_channel.receive, _send) + await app(scope, receive_channel.receive, _send, trio.to_thread.run_sync, trio.from_thread.run) assert messages == [ { "headers": [(b"content-type", b"text/plain; charset=utf-8"), (b"content-length", b"0")], "status": 200, "type": "http.response.start", }, - {"body": bytearray(b""), "type": "http.response.body"}, + {"body": bytearray(b""), "type": "http.response.body", "more_body": True}, + {"body": bytearray(b""), "type": "http.response.body", "more_body": False}, ] +async def _run_app(app: WSGIWrapper, scope: HTTPScope, body: bytes = b"") -> List[ASGISendEvent]: + queue: asyncio.Queue = asyncio.Queue() + await queue.put({"type": "http.request", "body": body}) + + messages = [] + + async def _send(message: ASGISendEvent) -> None: + nonlocal messages + messages.append(message) + + event_loop = asyncio.get_running_loop() + + def _call_soon(func: Callable, *args: Any) -> Any: + future = asyncio.run_coroutine_threadsafe(func(*args), event_loop) + return future.result() + + await app(scope, queue.get, _send, partial(event_loop.run_in_executor, None), _call_soon) + return messages + + @pytest.mark.asyncio async def test_wsgi_asyncio() -> None: - middleware = AsyncioWSGIMiddleware(echo_body) + app = WSGIWrapper(echo_body, 2**16) scope: HTTPScope = { "http_version": "1.1", "asgi": {}, @@ -77,30 +99,23 @@ async def test_wsgi_asyncio() -> None: "client": ("localhost", 80), "server": None, "extensions": {}, + "state": ConnectionState({}), } - queue: asyncio.Queue = asyncio.Queue() - await queue.put({"type": "http.request"}) - - messages = [] - - async def _send(message: dict) -> None: - nonlocal messages - messages.append(message) - - await middleware(scope, queue.get, _send) + messages = await _run_app(app, scope) assert messages == [ { "headers": [(b"content-type", b"text/plain; charset=utf-8"), (b"content-length", b"0")], "status": 200, "type": "http.response.start", }, - {"body": bytearray(b""), "type": "http.response.body"}, + {"body": bytearray(b""), "type": "http.response.body", "more_body": True}, + {"body": bytearray(b""), "type": "http.response.body", "more_body": False}, ] @pytest.mark.asyncio async def test_max_body_size() -> None: - middleware = AsyncioWSGIMiddleware(echo_body, max_body_size=4) + app = WSGIWrapper(echo_body, 4) scope: HTTPScope = { "http_version": "1.1", "asgi": {}, @@ -115,22 +130,42 @@ async def test_max_body_size() -> None: "client": ("localhost", 80), "server": None, "extensions": {}, + "state": ConnectionState({}), } - queue: asyncio.Queue = asyncio.Queue() - await queue.put({"type": "http.request", "body": b"abcde"}) - messages = [] - - async def _send(message: dict) -> None: - nonlocal messages - messages.append(message) - - await middleware(scope, queue.get, _send) + messages = await _run_app(app, scope, b"abcde") assert messages == [ {"headers": [], "status": 400, "type": "http.response.start"}, - {"body": bytearray(b""), "type": "http.response.body"}, + {"body": bytearray(b""), "type": "http.response.body", "more_body": False}, ] +def no_start_response(environ: dict, start_response: Callable) -> List[bytes]: + return [b"result"] + + +@pytest.mark.asyncio +async def test_no_start_response() -> None: + app = WSGIWrapper(no_start_response, 2**16) + scope: HTTPScope = { + "http_version": "1.1", + "asgi": {}, + "method": "GET", + "headers": [], + "path": "/", + "root_path": "/", + "query_string": b"a=b", + "raw_path": b"/", + "scheme": "http", + "type": "http", + "client": ("localhost", 80), + "server": None, + "extensions": {}, + "state": ConnectionState({}), + } + with pytest.raises(RuntimeError): + await _run_app(app, scope) + + def test_build_environ_encoding() -> None: scope: HTTPScope = { "http_version": "1.0", @@ -146,6 +181,7 @@ def test_build_environ_encoding() -> None: "client": ("localhost", 80), "server": None, "extensions": {}, + "state": ConnectionState({}), } environ = _build_environ(scope, b"") assert environ["SCRIPT_NAME"] == "/δΈ­".encode("utf8").decode("latin-1") @@ -167,6 +203,7 @@ def test_build_environ_root_path() -> None: "client": ("localhost", 80), "server": None, "extensions": {}, + "state": ConnectionState({}), } with pytest.raises(InvalidPathError): _build_environ(scope, b"") diff --git a/tests/test_config.py b/tests/test_config.py index fe758b5..fa511cf 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -53,12 +53,23 @@ def test_create_ssl_context() -> None: path = os.path.join(os.path.dirname(__file__), "assets/config_ssl.py") config = Config.from_pyfile(path) context = config.create_ssl_context() - assert context.options & ( - ssl.OP_NO_SSLv2 - | ssl.OP_NO_SSLv3 - | ssl.OP_NO_TLSv1 - | ssl.OP_NO_TLSv1_1 - | ssl.OP_NO_COMPRESSION + + # NOTE: In earlier versions of python context.options is equal to + # hence the ANDing context.options with the specified ssl options results in + # "", which as a Boolean value, is False. + # + # To overcome this, instead of checking that the result in True, we will check that it is + # equal to "context.options". + assert ( + context.options + & ( + ssl.OP_NO_SSLv2 + | ssl.OP_NO_SSLv3 + | ssl.OP_NO_TLSv1 + | ssl.OP_NO_TLSv1_1 + | ssl.OP_NO_COMPRESSION + ) + == context.options ) diff --git a/tests/test_utils.py b/tests/test_utils.py index e589e95..632161a 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,58 +1,26 @@ from __future__ import annotations -from typing import Callable +from typing import Any, Callable, Iterable import pytest -import hypercorn.utils -from hypercorn.typing import ASGIFramework, HTTPScope, Scope +from hypercorn.typing import Scope +from hypercorn.utils import ( + build_and_validate_headers, + filter_pseudo_headers, + is_asgi, + suppress_body, +) @pytest.mark.parametrize( "method, status, expected", [("HEAD", 200, True), ("GET", 200, False), ("GET", 101, True)] ) def test_suppress_body(method: str, status: int, expected: bool) -> None: - assert hypercorn.utils.suppress_body(method, status) is expected - - -@pytest.mark.asyncio -async def test_invoke_asgi_3(http_scope: HTTPScope) -> None: - result: Scope = {} # type: ignore - - async def asgi3_callable(scope: Scope, receive: Callable, send: Callable) -> None: - nonlocal result - result = scope - - await hypercorn.utils.invoke_asgi(asgi3_callable, http_scope, None, None) - assert result["asgi"]["version"] == "3.0" - - -@pytest.mark.asyncio -async def test_invoke_asgi_2(http_scope: HTTPScope) -> None: - result: Scope = {} # type: ignore - - def asgi2_callable(scope: Scope) -> Callable: - nonlocal result - result = scope - - async def inner(receive: Callable, send: Callable) -> None: - pass - - return inner - - await hypercorn.utils.invoke_asgi(asgi2_callable, http_scope, None, None) # type: ignore - assert result["asgi"]["version"] == "2.0" + assert suppress_body(method, status) is expected -class ASGI2Class: - def __init__(self, scope: Scope) -> None: - pass - - async def __call__(self, receive: Callable, send: Callable) -> None: - pass - - -class ASGI3ClassInstance: +class ASGIClassInstance: def __init__(self) -> None: pass @@ -60,49 +28,54 @@ async def __call__(self, scope: Scope, receive: Callable, send: Callable) -> Non pass -def asgi2_callable(scope: Scope) -> Callable: - async def inner(receive: Callable, send: Callable) -> None: +async def asgi_callable(scope: Scope, receive: Callable, send: Callable) -> None: + pass + + +class WSGIClassInstance: + def __init__(self) -> None: pass - return inner + def __call__(self, environ: dict, start_response: Callable) -> Iterable[bytes]: + pass -async def asgi3_callable(scope: Scope, receive: Callable, send: Callable) -> None: +def wsgi_callable(environ: dict, start_response: Callable) -> Iterable[bytes]: pass @pytest.mark.parametrize( - "app, is_asgi_2", + "app, expected", [ - (ASGI2Class, True), - (ASGI3ClassInstance(), False), - (asgi2_callable, True), - (asgi3_callable, False), + (WSGIClassInstance(), False), + (ASGIClassInstance(), True), + (wsgi_callable, False), + (asgi_callable, True), ], ) -def test__is_asgi_2(app: ASGIFramework, is_asgi_2: bool) -> None: - assert hypercorn.utils._is_asgi_2(app) == is_asgi_2 +def test_is_asgi(app: Any, expected: bool) -> None: + assert is_asgi(app) == expected def test_build_and_validate_headers_validate() -> None: with pytest.raises(TypeError): - hypercorn.utils.build_and_validate_headers([("string", "string")]) # type: ignore + build_and_validate_headers([("string", "string")]) # type: ignore def test_build_and_validate_headers_pseudo() -> None: with pytest.raises(ValueError): - hypercorn.utils.build_and_validate_headers([(b":authority", b"quart")]) + build_and_validate_headers([(b":authority", b"quart")]) def test_filter_pseudo_headers() -> None: - result = hypercorn.utils.filter_pseudo_headers( + result = filter_pseudo_headers( [(b":authority", b"quart"), (b":path", b"/"), (b"user-agent", b"something")] ) assert result == [(b"host", b"quart"), (b"user-agent", b"something")] def test_filter_pseudo_headers_no_authority() -> None: - result = hypercorn.utils.filter_pseudo_headers( + result = filter_pseudo_headers( [(b"host", b"quart"), (b":path", b"/"), (b"user-agent", b"something")] ) assert result == [(b"host", b"quart"), (b"user-agent", b"something")] diff --git a/tests/trio/test_keep_alive.py b/tests/trio/test_keep_alive.py index 796c61f..c570a2a 100644 --- a/tests/trio/test_keep_alive.py +++ b/tests/trio/test_keep_alive.py @@ -1,30 +1,45 @@ from __future__ import annotations -from typing import Callable, Generator +from typing import Awaitable, Callable, cast, Generator import h11 import pytest import trio +from hypercorn.app_wrappers import ASGIWrapper from hypercorn.config import Config from hypercorn.trio.tcp_server import TCPServer from hypercorn.trio.worker_context import WorkerContext -from hypercorn.typing import Scope +from hypercorn.typing import ASGIReceiveEvent, ASGISendEvent, Scope from ..helpers import MockSocket +try: + from typing import TypeAlias +except ImportError: + from typing_extensions import TypeAlias + + KEEP_ALIVE_TIMEOUT = 0.01 REQUEST = h11.Request(method="GET", target="/", headers=[(b"host", b"hypercorn")]) +ClientStream: TypeAlias = trio.StapledStream[ + trio.testing.MemorySendStream, trio.testing.MemoryReceiveStream +] + -async def slow_framework(scope: Scope, receive: Callable, send: Callable) -> None: +async def slow_framework( + scope: Scope, + receive: Callable[[], Awaitable[ASGIReceiveEvent]], + send: Callable[[ASGISendEvent], Awaitable[None]], +) -> None: while True: event = await receive() if event["type"] == "http.disconnect": break elif event["type"] == "lifespan.startup": - await send({"type": "lifspan.startup.complete"}) + await send({"type": "lifespan.startup.complete"}) elif event["type"] == "lifespan.shutdown": - await send({"type": "lifspan.shutdown.complete"}) + await send({"type": "lifespan.shutdown.complete"}) elif event["type"] == "http.request" and not event.get("more_body", False): await trio.sleep(2 * KEEP_ALIVE_TIMEOUT) await send( @@ -40,21 +55,20 @@ async def slow_framework(scope: Scope, receive: Callable, send: Callable) -> Non @pytest.fixture(name="client_stream", scope="function") def _client_stream( - nursery: trio._core._run.Nursery, -) -> Generator[trio.testing._memory_streams.MemorySendStream, None, None]: + nursery: trio.Nursery, +) -> Generator[ClientStream, None, None]: config = Config() config.keep_alive_timeout = KEEP_ALIVE_TIMEOUT client_stream, server_stream = trio.testing.memory_stream_pair() + server_stream = cast("trio.SSLStream[trio.SocketStream]", server_stream) server_stream.socket = MockSocket() - server = TCPServer(slow_framework, config, WorkerContext(), server_stream) + server = TCPServer(ASGIWrapper(slow_framework), config, WorkerContext(None), {}, server_stream) nursery.start_soon(server.run) yield client_stream @pytest.mark.trio -async def test_http1_keep_alive_pre_request( - client_stream: trio.testing._memory_streams.MemorySendStream, -) -> None: +async def test_http1_keep_alive_pre_request(client_stream: ClientStream) -> None: await client_stream.send_all(b"GET") await trio.sleep(2 * KEEP_ALIVE_TIMEOUT) # Only way to confirm closure is to invoke an error @@ -64,44 +78,47 @@ async def test_http1_keep_alive_pre_request( @pytest.mark.trio async def test_http1_keep_alive_during( - client_stream: trio.testing._memory_streams.MemorySendStream, + client_stream: ClientStream, ) -> None: client = h11.Connection(h11.CLIENT) - await client_stream.send_all(client.send(REQUEST)) + # client.send(h11.Request) and client.send(h11.EndOfMessage) only returns bytes. + # Fixed on master/ in the h11 repo, once released the ignore's can be removed. + # See https://github.com/python-hyper/h11/issues/175 + await client_stream.send_all(client.send(REQUEST)) # type: ignore[arg-type] await trio.sleep(2 * KEEP_ALIVE_TIMEOUT) # Key is that this doesn't error - await client_stream.send_all(client.send(h11.EndOfMessage())) + await client_stream.send_all(client.send(h11.EndOfMessage())) # type: ignore[arg-type] @pytest.mark.trio async def test_http1_keep_alive( - client_stream: trio.testing._memory_streams.MemorySendStream, + client_stream: ClientStream, ) -> None: client = h11.Connection(h11.CLIENT) - await client_stream.send_all(client.send(REQUEST)) + await client_stream.send_all(client.send(REQUEST)) # type: ignore[arg-type] await trio.sleep(2 * KEEP_ALIVE_TIMEOUT) - await client_stream.send_all(client.send(h11.EndOfMessage())) + await client_stream.send_all(client.send(h11.EndOfMessage())) # type: ignore[arg-type] while True: event = client.next_event() if event == h11.NEED_DATA: - data = await client_stream.receive_some(2 ** 16) + data = await client_stream.receive_some(2**16) client.receive_data(data) elif isinstance(event, h11.EndOfMessage): break client.start_next_cycle() - await client_stream.send_all(client.send(REQUEST)) + await client_stream.send_all(client.send(REQUEST)) # type: ignore[arg-type] await trio.sleep(2 * KEEP_ALIVE_TIMEOUT) # Key is that this doesn't error - await client_stream.send_all(client.send(h11.EndOfMessage())) + await client_stream.send_all(client.send(h11.EndOfMessage())) # type: ignore[arg-type] @pytest.mark.trio async def test_http1_keep_alive_pipelining( - client_stream: trio.testing._memory_streams.MemorySendStream, + client_stream: ClientStream, ) -> None: await client_stream.send_all( b"GET / HTTP/1.1\r\nHost: hypercorn\r\n\r\nGET / HTTP/1.1\r\nHost: hypercorn\r\n\r\n" ) - await client_stream.receive_some(2 ** 16) + await client_stream.receive_some(2**16) await trio.sleep(2 * KEEP_ALIVE_TIMEOUT) await client_stream.send_all(b"") diff --git a/tests/trio/test_lifespan.py b/tests/trio/test_lifespan.py index afa1f83..1dbc008 100644 --- a/tests/trio/test_lifespan.py +++ b/tests/trio/test_lifespan.py @@ -1,31 +1,49 @@ from __future__ import annotations +import sys + +if sys.version_info < (3, 11): + from exceptiongroup import ExceptionGroup + import pytest import trio +from hypercorn.app_wrappers import ASGIWrapper from hypercorn.config import Config from hypercorn.trio.lifespan import Lifespan +from hypercorn.typing import ASGIReceiveCallable, ASGISendCallable, Scope from hypercorn.utils import LifespanFailureError, LifespanTimeoutError -from ..helpers import lifespan_failure, SlowLifespanFramework +from ..helpers import SlowLifespanFramework @pytest.mark.trio async def test_startup_timeout_error(nursery: trio._core._run.Nursery) -> None: config = Config() config.startup_timeout = 0.01 - lifespan = Lifespan(SlowLifespanFramework(0.02, trio.sleep), config) # type: ignore + lifespan = Lifespan(ASGIWrapper(SlowLifespanFramework(0.02, trio.sleep)), config, {}) nursery.start_soon(lifespan.handle_lifespan) with pytest.raises(LifespanTimeoutError) as exc_info: await lifespan.wait_for_startup() assert str(exc_info.value).startswith("Timeout whilst awaiting startup") +async def _lifespan_failure( + scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable +) -> None: + async with trio.open_nursery(): + while True: + message = await receive() + if message["type"] == "lifespan.startup": + await send({"type": "lifespan.startup.failed", "message": "Failure"}) + break + + @pytest.mark.trio async def test_startup_failure() -> None: - lifespan = Lifespan(lifespan_failure, Config()) - with pytest.raises(LifespanFailureError) as exc_info: + lifespan = Lifespan(ASGIWrapper(_lifespan_failure), Config(), {}) + try: async with trio.open_nursery() as lifespan_nursery: await lifespan_nursery.start(lifespan.handle_lifespan) await lifespan.wait_for_startup() - - assert str(exc_info.value) == "Lifespan failure in startup. 'Failure'" + except ExceptionGroup as error: + assert error.subgroup(LifespanFailureError) is not None diff --git a/tests/trio/test_sanity.py b/tests/trio/test_sanity.py index 1263333..bea93f1 100644 --- a/tests/trio/test_sanity.py +++ b/tests/trio/test_sanity.py @@ -1,5 +1,6 @@ from __future__ import annotations +from typing import cast from unittest.mock import Mock, PropertyMock import h2 @@ -8,6 +9,7 @@ import trio import wsproto +from hypercorn.app_wrappers import ASGIWrapper from hypercorn.config import Config from hypercorn.trio.tcp_server import TCPServer from hypercorn.trio.worker_context import WorkerContext @@ -23,12 +25,15 @@ @pytest.mark.trio async def test_http1_request(nursery: trio._core._run.Nursery) -> None: client_stream, server_stream = trio.testing.memory_stream_pair() + server_stream = cast("trio.SSLStream[trio.SocketStream]", server_stream) server_stream.socket = MockSocket() - server = TCPServer(sanity_framework, Config(), WorkerContext(), server_stream) + server = TCPServer( + ASGIWrapper(sanity_framework), Config(), WorkerContext(None), {}, server_stream + ) nursery.start_soon(server.run) client = h11.Connection(h11.CLIENT) await client_stream.send_all( - client.send( + client.send( # type: ignore[arg-type] h11.Request( method="POST", target="/", @@ -40,8 +45,8 @@ async def test_http1_request(nursery: trio._core._run.Nursery) -> None: ) ) ) - await client_stream.send_all(client.send(h11.Data(data=SANITY_BODY))) - await client_stream.send_all(client.send(h11.EndOfMessage())) + await client_stream.send_all(client.send(h11.Data(data=SANITY_BODY))) # type: ignore[arg-type] + await client_stream.send_all(client.send(h11.EndOfMessage())) # type: ignore[arg-type] events = [] while True: event = client.next_event() @@ -74,8 +79,11 @@ async def test_http1_request(nursery: trio._core._run.Nursery) -> None: @pytest.mark.trio async def test_http1_websocket(nursery: trio._core._run.Nursery) -> None: client_stream, server_stream = trio.testing.memory_stream_pair() + server_stream = cast("trio.SSLStream[trio.SocketStream]", server_stream) server_stream.socket = MockSocket() - server = TCPServer(sanity_framework, Config(), WorkerContext(), server_stream) + server = TCPServer( + ASGIWrapper(sanity_framework), Config(), WorkerContext(None), {}, server_stream + ) nursery.start_soon(server.run) client = wsproto.WSConnection(wsproto.ConnectionType.CLIENT) await client_stream.send_all(client.send(wsproto.events.Request(host="hypercorn", target="/"))) @@ -99,10 +107,13 @@ async def test_http1_websocket(nursery: trio._core._run.Nursery) -> None: @pytest.mark.trio async def test_http2_request(nursery: trio._core._run.Nursery) -> None: client_stream, server_stream = trio.testing.memory_stream_pair() + server_stream = cast("trio.SSLStream[trio.SocketStream]", server_stream) server_stream.transport_stream = Mock(return_value=PropertyMock(return_value=MockSocket())) - server_stream.do_handshake = AsyncMock() + server_stream.do_handshake = AsyncMock() # type: ignore[method-assign] server_stream.selected_alpn_protocol = Mock(return_value="h2") - server = TCPServer(sanity_framework, Config(), WorkerContext(), server_stream) + server = TCPServer( + ASGIWrapper(sanity_framework), Config(), WorkerContext(None), {}, server_stream + ) nursery.start_soon(server.run) client = h2.connection.H2Connection() client.initiate_connection() @@ -154,10 +165,13 @@ async def test_http2_request(nursery: trio._core._run.Nursery) -> None: @pytest.mark.trio async def test_http2_websocket(nursery: trio._core._run.Nursery) -> None: client_stream, server_stream = trio.testing.memory_stream_pair() + server_stream = cast("trio.SSLStream[trio.SocketStream]", server_stream) server_stream.transport_stream = Mock(return_value=PropertyMock(return_value=MockSocket())) - server_stream.do_handshake = AsyncMock() + server_stream.do_handshake = AsyncMock() # type: ignore[method-assign] server_stream.selected_alpn_protocol = Mock(return_value="h2") - server = TCPServer(sanity_framework, Config(), WorkerContext(), server_stream) + server = TCPServer( + ASGIWrapper(sanity_framework), Config(), WorkerContext(None), {}, server_stream + ) nursery.start_soon(server.run) h2_client = h2.connection.H2Connection() h2_client.initiate_connection() diff --git a/tox.ini b/tox.ini index 675992b..d931bfa 100644 --- a/tox.ini +++ b/tox.ini @@ -1,11 +1,12 @@ [tox] -envlist = docs,format,mypy,py37,py38,py39,py310,package,pep8 +envlist = docs,format,mypy,py38,py39,py310,py311,py312,package,pep8 minversion = 3.3 isolated_build = true [testenv] deps = py37: mock + httpx hypothesis pytest pytest-asyncio @@ -15,17 +16,18 @@ deps = commands = pytest --cov=hypercorn {posargs} [testenv:docs] -basepython = python3.10 +basepython = python3.12 deps = pydata-sphinx-theme sphinx + sphinxcontrib-mermaid trio commands = sphinx-apidoc -e -f -o docs/reference/source/ src/hypercorn/ src/hypercorn/protocol/quic.py src/hypercorn/protocol/h3.py sphinx-build -W --keep-going -b html -d {envtmpdir}/doctrees docs/ docs/_build/html/ [testenv:format] -basepython = python3.10 +basepython = python3.12 deps = black isort @@ -34,7 +36,7 @@ commands = isort --check --diff src/hypercorn tests [testenv:pep8] -basepython = python3.10 +basepython = python3.12 deps = flake8 pep8-naming @@ -43,16 +45,16 @@ deps = commands = flake8 src/hypercorn/ tests/ [testenv:mypy] -basepython = python3.10 +basepython = python3.12 deps = mypy pytest - types-toml + trio commands = mypy src/hypercorn/ tests/ [testenv:package] -basepython = python3.10 +basepython = python3.12 deps = poetry twine