from collections.abc import Sequence
from functools import partial
from inspect import isawaitable
from operator import itemgetter
from typing import Union

from sanic import Sanic
from sanic.constants import HTTPMethod
from sanic.exceptions import SanicException
from sanic.response import empty, raw

from sanic_ext.config import PRIORITY
from sanic_ext.extensions.openapi import openapi
from sanic_ext.utils.route import clean_route_name


def add_http_methods(
    app: Sanic, methods: Sequence[Union[str, HTTPMethod]]
) -> None:
    """
    Adds HTTP methods to an app

    :param app: Your Sanic app
    :type app: Sanic
    :param methods: The http methods being added, eg: CONNECT, TRACE
    :type methods: Sequence[str]
    """

    # The router's allowed methods are replaced with a new tuple holding the
    # existing entries followed by the additional ones.
    app.router.ALLOWED_METHODS = tuple(
        [*app.router.ALLOWED_METHODS, *methods]  # type: ignore
    )


def _filter_trace_head(head: bytes, excluded_headers) -> bytes:
    """
    Drop excluded header lines from a raw request head

    Every kept line is re-terminated with CRLF. The field name of a line is
    the text before the first space or colon, so ``Name: value`` and
    ``Name:value`` are both recognised, and the request line is keyed by its
    method token.

    :param head: Raw request head, lines separated by CRLF
    :type head: bytes
    :param excluded_headers: Lower-case header names that must not be echoed
    :return: The head with excluded header lines removed
    :rtype: bytes
    """
    kept = []
    for line in head.split(b"\r\n"):
        if not line:
            # Nothing to classify; skipping avoids emitting stray blank lines
            # inside the echoed head.
            continue
        # Splitting on the space first and the colon second copes with both
        # "Name: value" and "Name:value", and never raises when the line has
        # no space at all.
        name = line.split(b" ", 1)[0].split(b":", 1)[0]
        if (
            name.lower().decode("utf-8", errors="replace")
            not in excluded_headers
        ):
            kept.append(line + b"\r\n")
    return b"".join(kept)


def add_auto_handlers(
    app: Sanic, auto_head: bool, auto_options: bool, auto_trace: bool
) -> None:
    """
    Registers a startup listener that adds automatic HEAD, OPTIONS and TRACE
    routes to an app

    :param app: Your Sanic app
    :type app: Sanic
    :param auto_head: Add a HEAD route for each GET route in a group that has
        no HEAD method
    :type auto_head: bool
    :param auto_options: Add an OPTIONS route for each route group that has
        no OPTIONS method
    :type auto_options: bool
    :param auto_trace: Add a TRACE route for each route group that has no
        TRACE method
    :type auto_trace: bool
    :raises SanicException: If ``auto_trace`` is set while TRACE is not one
        of the router's allowed methods
    """
    if auto_trace and "TRACE" not in app.router.ALLOWED_METHODS:
        raise SanicException(
            "Cannot use apply(..., auto_trace=True) if TRACE is not an "
            "allowed HTTP method. Make sure apply(..., all_http_methods=True) "
            "has been set."
        )

    async def head_handler(request, get_handler, *args, **kwargs):
        # Delegates to the GET handler, which may be sync or async, and
        # returns whatever it produced.
        retval = get_handler(request, *args, **kwargs)
        if isawaitable(retval):
            retval = await retval
        return retval

    async def options_handler(request, methods, *args, **kwargs):
        # Empty response advertising the group's methods plus OPTIONS.
        resp = empty()
        resp.headers["allow"] = ",".join([*methods, "OPTIONS"])
        return resp

    async def trace_handler(request):
        # Echo the request back with the headers listed in
        # TRACE_EXCLUDED_HEADERS removed.
        cleaned_head = _filter_trace_head(
            request.head, request.app.config.TRACE_EXCLUDED_HEADERS
        )
        # Built as bytes: a head or body that is not valid UTF-8 is echoed
        # unchanged instead of failing on a decode/encode round trip.
        message = b"\r\n\r\n".join([cleaned_head, request.body])
        return raw(message, content_type="message/http")

    @app.before_server_start(priority=PRIORITY)
    def _add_handlers(app, _):
        # Each section resets the router before adding routes and finalizes
        # the app afterwards (presumably because a finalized router does not
        # accept new routes -- confirm against the Sanic router docs).
        if auto_head:
            app.router.reset()
            for group in app.router.groups.values():
                if "GET" in group.methods and "HEAD" not in group.methods:
                    for route in group:
                        if "GET" in route.methods:
                            host = route.requirements.get("host")
                            name = f"{route.name}_head"
                            handler = openapi.definition(
                                summary=clean_route_name(route.name).title(),
                                description="Retrieve HEAD details",
                            )(partial(head_handler, get_handler=route.handler))
                            # Markers identifying the generated handler and
                            # the GET handler it wraps.
                            handler.__auto_handler__ = True
                            handler.__route_handler__ = route.handler
                            app.add_route(
                                handler=handler,
                                uri=group.uri,
                                methods=["HEAD"],
                                strict_slashes=group.strict,
                                name=name,
                                host=host,
                                unquote=group.unquote,
                            )
            app.finalize()

        if auto_trace:
            app.router.reset()
            for group in app.router.groups.values():
                if "TRACE" not in group.methods:
                    app.add_route(
                        handler=trace_handler,
                        uri=group.uri,
                        methods=["TRACE"],
                        strict_slashes=group.strict,
                    )
            app.finalize()

        if auto_options:
            app.router.reset()
            for group in app.router.groups.values():
                if "OPTIONS" not in group.methods:
                    if not group.requirements:
                        hosts = [None]
                    else:
                        hosts = set(
                            map(itemgetter("host"), group.requirements)
                        )
                    # Name the OPTIONS route after a route that is not one of
                    # the generated "_head" routes, falling back to the first
                    # route in the group.
                    try:
                        base_route = next(
                            r for r in group if not r.name.endswith("_head")
                        )
                    except StopIteration:
                        base_route = group[0]
                    name = f"{base_route.name}_options"
                    handler = openapi.definition(
                        summary=clean_route_name(base_route.name).title(),
                        description="Retrieve OPTIONS details",
                    )(partial(options_handler, methods=group.methods))
                    handler.__auto_handler__ = True
                    app.add_route(
                        handler=handler,
                        uri=group.uri,
                        methods=["OPTIONS"],
                        strict_slashes=group.strict,
                        name=name,
                        host=hosts,
                        unquote=group.unquote,
                    )
            app.finalize()