from __future__ import annotations import asyncio from collections import deque from dataclasses import dataclass from enum import Enum from inspect import isawaitable from typing import Any, cast from sanic_routing import BaseRouter, Route, RouteGroup from sanic_routing.exceptions import NotFound from sanic_routing.utils import path_to_parts from sanic.exceptions import InvalidSignal from sanic.log import error_logger, logger from sanic.models.handler_types import SignalHandler class Event(Enum): """Event names for the SignalRouter""" SERVER_EXCEPTION_REPORT = "server.exception.report" SERVER_INIT_AFTER = "server.init.after" SERVER_INIT_BEFORE = "server.init.before" SERVER_SHUTDOWN_AFTER = "server.shutdown.after" SERVER_SHUTDOWN_BEFORE = "server.shutdown.before" HTTP_LIFECYCLE_BEGIN = "http.lifecycle.begin" HTTP_LIFECYCLE_COMPLETE = "http.lifecycle.complete" HTTP_LIFECYCLE_EXCEPTION = "http.lifecycle.exception" HTTP_LIFECYCLE_HANDLE = "http.lifecycle.handle" HTTP_LIFECYCLE_READ_BODY = "http.lifecycle.read_body" HTTP_LIFECYCLE_READ_HEAD = "http.lifecycle.read_head" HTTP_LIFECYCLE_REQUEST = "http.lifecycle.request" HTTP_LIFECYCLE_RESPONSE = "http.lifecycle.response" HTTP_ROUTING_AFTER = "http.routing.after" HTTP_ROUTING_BEFORE = "http.routing.before" HTTP_HANDLER_AFTER = "http.handler.after" HTTP_HANDLER_BEFORE = "http.handler.before" HTTP_LIFECYCLE_SEND = "http.lifecycle.send" HTTP_MIDDLEWARE_AFTER = "http.middleware.after" HTTP_MIDDLEWARE_BEFORE = "http.middleware.before" WEBSOCKET_HANDLER_AFTER = "websocket.handler.after" WEBSOCKET_HANDLER_BEFORE = "websocket.handler.before" WEBSOCKET_HANDLER_EXCEPTION = "websocket.handler.exception" RESERVED_NAMESPACES = { "server": ( Event.SERVER_EXCEPTION_REPORT.value, Event.SERVER_INIT_AFTER.value, Event.SERVER_INIT_BEFORE.value, Event.SERVER_SHUTDOWN_AFTER.value, Event.SERVER_SHUTDOWN_BEFORE.value, ), "http": ( Event.HTTP_LIFECYCLE_BEGIN.value, Event.HTTP_LIFECYCLE_COMPLETE.value, Event.HTTP_LIFECYCLE_EXCEPTION.value, Event.HTTP_LIFECYCLE_HANDLE.value, Event.HTTP_LIFECYCLE_READ_BODY.value, Event.HTTP_LIFECYCLE_READ_HEAD.value, Event.HTTP_LIFECYCLE_REQUEST.value, Event.HTTP_LIFECYCLE_RESPONSE.value, Event.HTTP_ROUTING_AFTER.value, Event.HTTP_ROUTING_BEFORE.value, Event.HTTP_HANDLER_AFTER.value, Event.HTTP_HANDLER_BEFORE.value, Event.HTTP_LIFECYCLE_SEND.value, Event.HTTP_MIDDLEWARE_AFTER.value, Event.HTTP_MIDDLEWARE_BEFORE.value, ), "websocket": { Event.WEBSOCKET_HANDLER_AFTER.value, Event.WEBSOCKET_HANDLER_BEFORE.value, Event.WEBSOCKET_HANDLER_EXCEPTION.value, }, } GENERIC_SIGNAL_FORMAT = "__generic__.__signal__.%s" def _blank(): ... class Signal(Route): """A `Route` that is used to dispatch signals to handlers""" @dataclass class SignalWaiter: """A record representing a future waiting for a signal""" signal: Signal event_definition: str trigger: str = "" requirements: dict[str, str] | None = None exclusive: bool = True future: asyncio.Future | None = None async def wait(self): """Block until the signal is next dispatched. Return the context of the signal dispatch, if any. """ loop = asyncio.get_running_loop() self.future = loop.create_future() self.signal.ctx.waiters.append(self) try: return await self.future finally: self.signal.ctx.waiters.remove(self) def matches(self, event, condition): return ( (condition is None and not self.exclusive) or (condition is None and not self.requirements) or condition == self.requirements ) and (self.trigger or event == self.event_definition) class SignalGroup(RouteGroup): """A `RouteGroup` that is used to dispatch signals to handlers""" class SignalRouter(BaseRouter): """A `BaseRouter` that is used to dispatch signals to handlers""" def __init__(self) -> None: super().__init__( delimiter=".", route_class=Signal, group_class=SignalGroup, stacking=True, ) self.allow_fail_builtin = True self.ctx.loop = None @staticmethod def format_event(event: str | Enum) -> str: """Ensure event strings in proper format Args: event (str): event string Returns: str: formatted event string """ if isinstance(event, Enum): event = str(event.value) if "." not in event: event = GENERIC_SIGNAL_FORMAT % event return event def get( # type: ignore self, event: str | Enum, condition: dict[str, str] | None = None, ): """Get the handlers for a signal Args: event (str): The event to get the handlers for condition (Optional[Dict[str, str]], optional): A dictionary of conditions to match against the handlers. Defaults to `None`. Returns: Tuple[SignalGroup, List[SignalHandler], Dict[str, Any]]: A tuple of the `SignalGroup` that matched, a list of the handlers that matched, and a dictionary of the params that matched Raises: NotFound: If no handlers are found """ # noqa: E501 event = self.format_event(event) extra = condition or {} try: group, param_basket = self.find_route( f".{event}", self.DEFAULT_METHOD, self, {"__params__": {}, "__matches__": {}}, extra=extra, ) except NotFound: message = "Could not find signal %s" terms: list[str | dict[str, str] | None] = [event] if extra: message += " with %s" terms.append(extra) raise NotFound(message % tuple(terms)) # Regex routes evaluate and can extract params directly. They are set # on param_basket["__params__"] params = param_basket["__params__"] if not params: # If param_basket["__params__"] does not exist, we might have # param_basket["__matches__"], which are indexed based matches # on path segments. They should already be cast types. params = { param.name: param_basket["__matches__"][idx] for idx, param in group.params.items() } return group, [route.handler for route in group], params async def _dispatch( self, event: str, context: dict[str, Any] | None = None, condition: dict[str, str] | None = None, fail_not_found: bool = True, reverse: bool = False, ) -> Any: event = self.format_event(event) try: group, handlers, params = self.get(event, condition=condition) except NotFound as e: is_reserved = event.split(".", 1)[0] in RESERVED_NAMESPACES if fail_not_found and (not is_reserved or self.allow_fail_builtin): raise e else: if self.ctx.app.debug and self.ctx.app.state.verbosity >= 1: error_logger.warning(str(e)) return None if context: params.update(context) params.pop("__trigger__", None) signals = group.routes if not reverse: signals = signals[::-1] try: for signal in signals: for waiter in signal.ctx.waiters: if waiter.matches(event, condition): waiter.future.set_result(dict(params)) for signal in signals: requirements = signal.extra.requirements if ( (condition is None and signal.ctx.exclusive is False) or (condition is None and not requirements) or (condition == requirements) ) and (signal.ctx.trigger or event == signal.ctx.definition): maybe_coroutine = signal.handler(**params) if isawaitable(maybe_coroutine): retval = await maybe_coroutine if retval: return retval elif maybe_coroutine: return maybe_coroutine return None except Exception as e: if self.ctx.app.debug and self.ctx.app.state.verbosity >= 1: error_logger.exception(e) if event != Event.SERVER_EXCEPTION_REPORT.value: await self.dispatch( Event.SERVER_EXCEPTION_REPORT.value, context={"exception": e}, ) setattr(e, "__dispatched__", True) raise e async def dispatch( self, event: str | Enum, *, context: dict[str, Any] | None = None, condition: dict[str, str] | None = None, fail_not_found: bool = True, inline: bool = False, reverse: bool = False, ) -> asyncio.Task | Any: """Dispatch a signal to all handlers that match the event Args: event (str): The event to dispatch context (Optional[Dict[str, Any]], optional): A dictionary of context to pass to the handlers. Defaults to `None`. condition (Optional[Dict[str, str]], optional): A dictionary of conditions to match against the handlers. Defaults to `None`. fail_not_found (bool, optional): Whether to raise an exception if no handlers are found. Defaults to `True`. inline (bool, optional): Whether to run the handlers inline. An inline run means it will return the value of the signal handler. When `False` (which is the default) the signal handler will run in a background task. Defaults to `False`. reverse (bool, optional): Whether to run the handlers in reverse order. Defaults to `False`. Returns: Union[asyncio.Task, Any]: If `inline` is `True` then the return value of the signal handler will be returned. If `inline` is `False` then an `asyncio.Task` will be returned. Raises: RuntimeError: If the signal is dispatched outside of an event loop """ # noqa: E501 event = self.format_event(event) dispatch = self._dispatch( event, context=context, condition=condition, fail_not_found=fail_not_found and inline, reverse=reverse, ) logger.debug(f"Dispatching signal: {event}", extra={"verbosity": 1}) if inline: return await dispatch task = asyncio.get_running_loop().create_task(dispatch) await asyncio.sleep(0) return task def get_waiter( self, event: str | Enum, condition: dict[str, Any] | None = None, exclusive: bool = True, ) -> SignalWaiter | None: event_definition = self.format_event(event) name, trigger, _ = self._get_event_parts(event_definition) signal = cast(Signal, self.name_index.get(name)) if not signal: return None if event_definition.endswith(".*") and not trigger: trigger = "*" return SignalWaiter( signal=signal, event_definition=event_definition, trigger=trigger, requirements=condition, exclusive=bool(exclusive), ) def _get_event_parts(self, event: str) -> tuple[str, str, str]: parts = self._build_event_parts(event) if parts[2].startswith("<"): name = ".".join([*parts[:-1], "*"]) trigger = self._clean_trigger(parts[2]) else: name = event trigger = "" if not trigger: event = ".".join([*parts[:2], "<__trigger__>"]) return name, trigger, event def add( # type: ignore self, handler: SignalHandler, event: str | Enum, condition: dict[str, Any] | None = None, exclusive: bool = True, *, priority: int = 0, ) -> Signal: event_definition = self.format_event(event) name, trigger, event_string = self._get_event_parts(event_definition) signal = super().add( event_string, handler, name=name, append=True, priority=priority, ) # type: ignore signal.ctx.exclusive = exclusive signal.ctx.trigger = trigger signal.ctx.definition = event_definition signal.extra.requirements = condition return cast(Signal, signal) def finalize(self, do_compile: bool = True, do_optimize: bool = False): """Finalize the router and compile the routes Args: do_compile (bool, optional): Whether to compile the routes. Defaults to `True`. do_optimize (bool, optional): Whether to optimize the routes. Defaults to `False`. Returns: SignalRouter: The router Raises: RuntimeError: If the router is finalized outside of an event loop """ # noqa: E501 self.add(_blank, "sanic.__signal__.__init__") try: self.ctx.loop = asyncio.get_running_loop() except RuntimeError: raise RuntimeError("Cannot finalize signals outside of event loop") for signal in self.routes: signal.ctx.waiters = deque() return super().finalize(do_compile=do_compile, do_optimize=do_optimize) def _build_event_parts(self, event: str) -> tuple[str, str, str]: parts = path_to_parts(event, self.delimiter) if ( len(parts) != 3 or parts[0].startswith("<") or parts[1].startswith("<") ): raise InvalidSignal("Invalid signal event: %s" % event) if ( parts[0] in RESERVED_NAMESPACES and event not in RESERVED_NAMESPACES[parts[0]] and not (parts[2].startswith("<") and parts[2].endswith(">")) ): raise InvalidSignal( "Cannot declare reserved signal event: %s" % event ) return parts def _clean_trigger(self, trigger: str) -> str: trigger = trigger[1:-1] if ":" in trigger: trigger, _ = trigger.split(":") return trigger