diff --git a/interactions/api/gateway/__init__.py b/interactions/api/gateway/__init__.py index 3f0423487..903f17955 100644 --- a/interactions/api/gateway/__init__.py +++ b/interactions/api/gateway/__init__.py @@ -6,3 +6,4 @@ """ from .client import * # noqa: F401 F403 from .heartbeat import * # noqa: F401 F403 +from .ratelimit import * # noqa: F401 F403 diff --git a/interactions/api/gateway/client.py b/interactions/api/gateway/client.py index 667092fe6..ef4cfc698 100644 --- a/interactions/api/gateway/client.py +++ b/interactions/api/gateway/client.py @@ -4,20 +4,23 @@ from json import dumps, loads from asyncio import ( + FIRST_COMPLETED, Event, + Lock, Task, - ensure_future, + TimeoutError, + create_task, get_event_loop, get_running_loop, new_event_loop, - sleep, + wait, + wait_for, ) from sys import platform, version_info from time import perf_counter from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union from aiohttp import ClientWebSocketResponse, WSMessage, WSMsgType -from aiohttp.http import WS_CLOSED_MESSAGE, WS_CLOSING_MESSAGE from ...base import get_logger from ...client.enums import InteractionType, OptionType @@ -33,6 +36,7 @@ from ..models.misc import Snowflake from ..models.presence import ClientPresence from .heartbeat import _Heartbeat +from .ratelimit import WSRateLimit if TYPE_CHECKING: from ...client.context import _Context @@ -47,47 +51,63 @@ class WebSocketClient: """ A class representing the client's connection to the Gateway via. WebSocket. + .. note :: + The ``__heartbeat_event`` Event object is different from the one built in to the Heartbeater object. + The latter is used to trace heartbeat acknowledgement. + :ivar AbstractEventLoop _loop: The asynchronous event loop. :ivar Listener _dispatch: The built-in event dispatcher. + :ivar WSRateLimit _ratelimiter: The websocket ratelimiter object. :ivar HTTPClient _http: The user-facing HTTP client. :ivar ClientWebSocketResponse _client: The WebSocket data of the connection. - :ivar bool _closed: Whether the connection has been closed or not. + :ivar Event __closed: Whether the connection has been closed or not. :ivar dict _options: The connection options made during connection. :ivar Intents _intents: The gateway intents used for connection. :ivar dict _ready: The contents of the application returned when ready. :ivar _Heartbeat __heartbeater: The context state of a "heartbeat" made to the Gateway. + :ivar Event __heartbeat_event: The state of the overall heartbeat process. :ivar Optional[List[Tuple[int]]] __shard: The shards used during connection. :ivar Optional[ClientPresence] __presence: The presence used in connection. :ivar Event ready: The ready state of the client as an ``asyncio.Event``. - :ivar Task __task: The closing task for ending connections. + :ivar Task _task: The task containing the heartbeat manager process. :ivar bool __started: Whether the client has started. :ivar Optional[str] session_id: The ID of the ongoing session. :ivar Optional[int] sequence: The sequence identifier of the ongoing session. :ivar float _last_send: The latest time of the last send_packet function call since connection creation, in seconds. :ivar float _last_ack: The latest time of the last ``HEARTBEAT_ACK`` event since connection creation, in seconds. - :ivar float latency: The latency of the connection, in seconds. + :ivar Optional[str] resume_url: The Websocket ratelimit URL for resuming connections, if any. + :ivar Optional[str] ws_url: The Websocket URL for instantiating connections without resuming. + :ivar Lock reconnect_lock: The lock used for reconnecting the client. + :ivar Lock _closing_lock: The lock used for closing the client. + :ivar Optional[Task] __stopping: The task containing stopping the client, if any. """ __slots__ = ( "_loop", "_dispatch", + "_ratelimiter", "_http", "_client", - "_closed", + "__closed", # placeholder to work with variables atm. its event variant of "_closed" "_options", "_intents", "_ready", "__heartbeater", "__shard", "__presence", - "__task", + "_task", + "__heartbeat_event", "__started", "session_id", "sequence", "ready", "_last_send", "_last_ack", - "latency", + "resume_url", + "ws_url", + "reconnect_lock", + "_closing_lock", + "__stopping", ) def __init__( @@ -112,160 +132,182 @@ def __init__( except RuntimeError: self._loop = new_event_loop() self._dispatch: Listener = Listener() + self._ratelimiter = ( + WSRateLimit(loop=self._loop) if version_info < (3, 10) else WSRateLimit() + ) + self.__heartbeater: _Heartbeat = _Heartbeat( + loop=self._loop if version_info < (3, 10) else None + ) self._http: HTTPClient = HTTPClient(token) self._client: Optional["ClientWebSocketResponse"] = None - self._closed: bool = False + + self.__closed: Event = Event(loop=self._loop) if version_info < (3, 10) else Event() self._options: dict = { "max_msg_size": 1024**2, "timeout": 60, "autoclose": False, "compress": 0, + "headers": {"User-Agent": self._http._req._headers["User-Agent"]}, } + self._intents: Intents = intents - self.__heartbeater: _Heartbeat = _Heartbeat( - loop=self._loop if version_info < (3, 10) else None - ) self.__shard: Optional[List[Tuple[int]]] = None self.__presence: Optional[ClientPresence] = None - self.__task: Optional[Task] = None + + self._task: Optional[Task] = None + self.__heartbeat_event = Event(loop=self._loop) if version_info < (3, 10) else Event() self.__started: bool = False + self.session_id: Optional[str] = None if session_id is MISSING else session_id self.sequence: Optional[str] = None if sequence is MISSING else sequence self.ready: Event = Event(loop=self._loop) if version_info < (3, 10) else Event() self._last_send: float = perf_counter() self._last_ack: float = perf_counter() - self.latency: float = float("nan") # noqa: F821 - # self.latency has to be noqa, this is valid in python but not in Flake8. + + self.resume_url: Optional[str] = None + self.ws_url: Optional[str] = None + self.reconnect_lock = Lock(loop=self._loop) if version_info < (3, 10) else Lock() + + self._closing_lock = Event(loop=self._loop) if version_info < (3, 10) else Event() + + self.__stopping: Optional[Task] = None + + @property + def latency(self) -> float: + """ + The latency of the connection, in seconds. + """ + return self._last_ack - self._last_send + + async def run_heartbeat(self) -> None: + """Controls the heartbeat manager. Do note that this shouldn't be executed by outside processes.""" + + if self.__heartbeat_event.is_set(): # resets task of heartbeat event mgr loop + # Because we're hardresetting the process every instance its called, also helps with recursion + self.__heartbeat_event.clear() + + if not self.__heartbeater.event.is_set(): # resets task of heartbeat ack event + self.__heartbeater.event.set() + + try: + await self._manage_heartbeat() + except Exception: + self._closing_lock.set() + log.exception("Heartbeater exception.") async def _manage_heartbeat(self) -> None: """Manages the heartbeat loop.""" - while True: - if self._closed: - await self.__restart() - if self.__heartbeater.event.is_set(): - await self.__heartbeat() - self.__heartbeater.event.clear() - await sleep(self.__heartbeater.delay / 1000) - else: + log.debug(f"Sending heartbeat every {self.__heartbeater.delay / 1000} seconds...") + while not self.__heartbeat_event.is_set(): + + log.debug("Sending heartbeat...") + if not self.__heartbeater.event.is_set(): log.debug("HEARTBEAT_ACK missing, reconnecting...") - await self.__restart() - break - - async def __restart(self) -> None: - """Restart the client's connection and heartbeat with the Gateway.""" - if self.__task: - self.__task.cancel() - self._client = None # clear pending waits - self.__heartbeater.event.clear() - await self._establish_connection(self.__shard, self.__presence) - - async def _establish_connection( - self, - shard: Optional[List[Tuple[int]]] = MISSING, - presence: Optional[ClientPresence] = MISSING, - ) -> None: - """ - Establishes a client connection with the Gateway. + await self._reconnect(True) # resume here. - :param shard?: The shards to establish a connection with. Defaults to ``None``. - :type shard?: Optional[List[Tuple[int]]] - :param presence: The presence to carry with. Defaults to ``None``. - :type presence: Optional[ClientPresence] + self.__heartbeater.event.clear() + await self.__heartbeat() + + try: + # wait for next iteration, accounting for latency + await wait_for( + self.__heartbeat_event.wait(), timeout=self.__heartbeater.delay / 1000 + ) + except TimeoutError: + continue # Then we can check heartbeat ack this way and then like it autorestarts. + else: + return # break loop because something went wrong. + + async def run(self) -> None: + """ + Handles the client's connection with the Gateway. """ - self._client = None - self.__heartbeater.delay = 0.0 - self._closed = False - self._options["headers"] = {"User-Agent": self._http._req._headers["User-Agent"]} + + # Credit to NAFF for inspiration for the Gateway logic. + url = await self._http.get_gateway() + self.ws_url = url + self._client = await self._http._req._session.ws_connect(url, **self._options) + + data = await self.__receive_packet(True) # First data is the hello packet. - async with self._http._req._session.ws_connect(url, **self._options) as self._client: - self._closed = self._client.closed + self.__heartbeater.delay = data["d"]["heartbeat_interval"] - if self._closed: - await self._establish_connection(self.__shard, self.__presence) + self._task = create_task(self.run_heartbeat()) - while not self._closed: - stream = await self.__receive_packet_stream + await self.__identify(self.__shard, self.__presence) - if stream is None: - continue - if self._client is None or stream == WS_CLOSED_MESSAGE or stream == WSMsgType.CLOSE: - await self._establish_connection(self.__shard, self.__presence) - break + self.__closed.set() + self.__heartbeater.event.set() - if self._client.close_code in range(4010, 4014) or self._client.close_code == 4004: - raise LibraryException(self._client.close_code) + while True: + if self.__stopping is None: + self.__stopping = create_task(self._closing_lock.wait()) + _receive = create_task(self.__receive_packet()) - await self._handle_connection(stream, shard, presence) + done, _ = await wait({self.__stopping, _receive}, return_when=FIRST_COMPLETED) + # Using asyncio.wait to find which one reaches first, when its *closed* or when a message is + # *received* - async def _handle_connection( - self, - stream: Dict[str, Any], - shard: Optional[List[Tuple[int]]] = MISSING, - presence: Optional[ClientPresence] = MISSING, - ) -> None: + if _receive in done: + msg = await _receive + else: + await self.__stopping + _receive.cancel() + return + + await self._handle_stream(msg) + + async def _handle_stream(self, stream: Dict[str, Any]): """ - Handles the client's connection with the Gateway. + Parses raw stream data recieved from the Gateway, including Gateway opcodes and events. + + .. note :: + This should never be called directly. :param stream: The packet stream to handle. :type stream: Dict[str, Any] - :param shard?: The shards to establish a connection with. Defaults to ``None``. - :type shard?: Optional[List[Tuple[int]]] - :param presence: The presence to carry with. Defaults to ``None``. - :type presence: Optional[ClientPresence] """ op: Optional[int] = stream.get("op") event: Optional[str] = stream.get("t") data: Optional[Dict[str, Any]] = stream.get("d") + seq: Optional[str] = stream.get("s") + if seq: + self.sequence = seq + if op != OpCodeType.DISPATCH: log.debug(data) - if op == OpCodeType.HELLO: - self.__heartbeater.delay = data["heartbeat_interval"] - self.__heartbeater.event.set() - - if self.__task: - self.__task.cancel() # so we can reduce redundant heartbeat bg tasks. - - self.__task = ensure_future(self._manage_heartbeat()) - - if not self.session_id: - await self.__identify(shard, presence) - else: - await self.__resume() if op == OpCodeType.HEARTBEAT: await self.__heartbeat() if op == OpCodeType.HEARTBEAT_ACK: self._last_ack = perf_counter() log.debug("HEARTBEAT_ACK") self.__heartbeater.event.set() - self.latency = self._last_ack - self._last_send - if op in (OpCodeType.INVALIDATE_SESSION, OpCodeType.RECONNECT): - log.debug("INVALID_SESSION/RECONNECT") - # if data and op != OpCodeType.RECONNECT: - # self.session_id = None - # self.sequence = None - # self._closed = True + if op == OpCodeType.INVALIDATE_SESSION: + log.debug("INVALID_SESSION") + self.ready.clear() + await self._reconnect(bool(data)) - if not bool(data) and op == OpCodeType.INVALIDATE_SESSION: - self.session_id = None + if op == OpCodeType.RECONNECT: + log.debug("RECONNECT") + await self._reconnect(True) - await self.__restart() elif event == "RESUMED": log.debug(f"RESUMED (session_id: {self.session_id}, seq: {self.sequence})") elif event == "READY": + self.ready.set() + self._dispatch.dispatch("on_ready") self._ready = data self.session_id = data["session_id"] - self.sequence = stream["s"] - self._dispatch.dispatch("on_ready") + self.resume_url = data["resume_gateway_url"] if not self.__started: self.__started = True self._dispatch.dispatch("on_start") log.debug(f"READY (session_id: {self.session_id}, seq: {self.sequence})") - self.ready.set() else: log.debug(f"{event}: {str(data).encode('utf-8')}") self._dispatch_event(event, data) @@ -646,32 +688,112 @@ def __option_type_context(self, context: "_Context", type: int) -> dict: } return _resolved - async def restart(self) -> None: - await self.__restart() + async def _reconnect(self, to_resume: bool, code: Optional[int] = 1012) -> None: + """ + Restarts the client's connection and heartbeat with the Gateway. + """ - @property - async def __receive_packet_stream(self) -> Optional[Union[Dict[str, Any], WSMessage]]: + self._ready.clear() + + async with self.reconnect_lock: + self.__closed.clear() + + if self._client is not None: + await self._client.close(code=code) + + self._client = None + + if not to_resume: + url = self.ws_url if self.ws_url else await self._http.get_gateway() + else: + url = self.resume_url + + self._client = await self._http._req._session.ws_connect(url, **self._options) + + data = await self.__receive_packet(True) # First data is the hello packet. + + self.__heartbeater.delay = data["d"]["heartbeat_interval"] + + if self._task: + self._task.cancel() + if self.__heartbeat_event.is_set(): + self.__heartbeat_event.clear() # Because we're hardresetting the process + + self._task = create_task(self.run_heartbeat()) + + if not to_resume: + await self.__identify(self.__shard, self.__presence) + else: + await self.__resume() + + self.__closed.set() + self.__heartbeat_event.set() + + async def __receive_packet(self, ignore_lock: bool = False) -> Optional[Dict[str, Any]]: """ - Receives a stream of packets sent from the Gateway. + Receives a stream of packets sent from the Gateway in an async process. :return: The packet stream. :rtype: Optional[Dict[str, Any]] """ - packet: WSMessage = await self._client.receive() + while True: + + if not ignore_lock: + # meaning if we're reconnecting or something because of tasks + await self.__closed.wait() + + packet: WSMessage = await self._client.receive() - if packet == WSMsgType.CLOSE: - await self._client.close() - return packet + if packet.type == WSMsgType.CLOSE: + log.debug(f"Disconnecting from gateway = {packet.data}::{packet.extra}") - elif packet == WS_CLOSED_MESSAGE: - return packet + if packet.data >= 4000: # suppress 4001 because of weird presence errors + # This means that the error code is 4000+, which may signify Discord-provided error codes. - elif packet == WS_CLOSING_MESSAGE: - await self._client.close() - return WS_CLOSED_MESSAGE + # However, we suppress 4001 because of weird presence errors with change_presence + # The payload is correct, and the presence object persists. /shrug + + raise LibraryException(packet.data) + + if ignore_lock: + raise LibraryException( + message="Discord unexpectedly wants to close the WS on receiving by force.", + severity=50, + ) - return loads(packet.data) if packet and isinstance(packet.data, str) else None + await self._reconnect(packet.data != 1000, packet.data) + continue + + elif packet.type == WSMsgType.CLOSED: + # We need to wait/reconnect depending about other event holders. + + if ignore_lock: + raise LibraryException( + message="Discord unexpectedly closed on receiving by force.", severity=50 + ) + + if not self.__closed.is_set(): + await self.__closed.wait() + + # Edge case on force reconnecting if we dont + else: + await self._reconnect(True) + + elif packet.type == WSMsgType.CLOSING: + + if ignore_lock: + raise LibraryException( + message="Discord unexpectedly closing on receiving by force.", severity=50 + ) + + await self.__closed.wait() + continue + + if packet.data is None: + continue # We just loop it over because it could just be processing something. + + return loads(packet.data) if isinstance(packet.data, str) else None async def _send_packet(self, data: Dict[str, Any]) -> None: """ @@ -680,11 +802,16 @@ async def _send_packet(self, data: Dict[str, Any]) -> None: :param data: The data to send to the Gateway. :type data: Dict[str, Any] """ - self._last_send = perf_counter() _data = dumps(data) if isinstance(data, dict) else data packet: str = _data.decode("utf-8") if isinstance(_data, bytes) else _data - await self._client.send_str(packet) + + if data["op"] != OpCodeType.HEARTBEAT.value: + # This is because the ratelimiter limits already accounts for this. + await self._ratelimiter.block() + + self._last_send = perf_counter() log.debug(packet) + await self._client.send_str(packet) async def __identify( self, shard: Optional[List[Tuple[int]]] = None, presence: Optional[ClientPresence] = None @@ -700,7 +827,7 @@ async def __identify( self.__shard = shard self.__presence = presence payload: dict = { - "op": OpCodeType.IDENTIFY, + "op": OpCodeType.IDENTIFY.value, "d": { "token": self._http.token, "intents": self._intents.value, @@ -724,7 +851,7 @@ async def __identify( async def __resume(self) -> None: """Sends a ``RESUME`` packet to the gateway.""" payload: dict = { - "op": OpCodeType.RESUME, + "op": OpCodeType.RESUME.value, "d": {"token": self._http.token, "seq": self.sequence, "session_id": self.session_id}, } log.debug(f"RESUMING: {payload}") @@ -733,7 +860,7 @@ async def __resume(self) -> None: async def __heartbeat(self) -> None: """Sends a ``HEARTBEAT`` packet to the gateway.""" - payload: dict = {"op": OpCodeType.HEARTBEAT, "d": self.sequence} + payload: dict = {"op": OpCodeType.HEARTBEAT.value, "d": self.sequence} await self._send_packet(payload) log.debug("HEARTBEAT") @@ -759,7 +886,7 @@ async def _update_presence(self, presence: ClientPresence) -> None: :param presence: The presence to change the bot to on identify. :type presence: ClientPresence """ - payload: dict = {"op": OpCodeType.PRESENCE, "d": presence._json} + payload: dict = {"op": OpCodeType.PRESENCE.value, "d": presence._json} await self._send_packet(payload) log.debug(f"UPDATE_PRESENCE: {presence._json}") self.__presence = presence diff --git a/interactions/api/gateway/ratelimit.py b/interactions/api/gateway/ratelimit.py new file mode 100644 index 000000000..01e30466f --- /dev/null +++ b/interactions/api/gateway/ratelimit.py @@ -0,0 +1,79 @@ +import asyncio +import logging +from sys import version_info +from time import time +from typing import Optional + +log = logging.getLogger("gateway.ratelimit") + + +class WSRateLimit: + """ + A class that controls Gateway ratelimits using locking and a timer. + + .. note :: + While the docs state that the Gateway ratelimits are 120/60 (120 requests per 60 seconds), + this ratelimit offsets to 115 instead of 120 for room. + + :ivar Lock lock: The gateway Lock object. + :ivar int max: The upper limit of the ratelimit in seconds. Defaults to `115`. + :ivar int remaining: How many requests are left per ``per_second``. This is automatically decremented and reset. + :ivar float current_limit: When this cooldown session began. This is defined automatically. + :ivar float per_second: A constant denoting how many requests can be done per unit of seconds. (i.e., per 60 seconds, per 45, etc.) + """ + + def __init__(self, loop: Optional[asyncio.AbstractEventLoop] = None): + self.lock = asyncio.Lock(loop=loop) if version_info < (3, 10) else asyncio.Lock() + # To conserve timings, we need to do 115/60 + # Also, credit to d.py for their ratelimiter inspiration + + self.max = self.remaining = 115 + self.per_second = 60.0 + self.current_limit = 0.0 + + @property + def ratelimited(self) -> bool: + """ + An attribute that reflects whenever the websocket ratelimiter is rate-limited. + + :return: Whether it's rate-limited or not. + :rtype: bool + """ + current = time() + if current > self.current_limit + self.per_second: + return False + return self.remaining == 0 + + @property + def delay(self) -> float: + """ + An attribute that reflects how long we need to wait for ratelimit to pass, if any. + + :return: How long to wait in seconds, if any. Defaults to ``0.0``. + :rtype: float + """ + current = time() + + if current > self.current_limit + self.per_second: + self.remaining = self.max + + if self.remaining == self.max: + self.current_limit = current + + if self.remaining == 0: + return self.per_second - (current - self.current_limit) + + self.remaining -= 1 + if self.remaining == 0: + self.current_limit = current + + return 0.0 + + async def block(self) -> None: + """ + A function that uses the internal Lock to check for rate-limits and cooldown whenever necessary. + """ + async with self.lock: + if delta := self.delay: + log.warning(f"We are rate-limited. Please wait {round(delta, 2)} seconds...") + await asyncio.sleep(delta) diff --git a/interactions/api/models/presence.py b/interactions/api/models/presence.py index 9ca3d1fc5..305774c64 100644 --- a/interactions/api/models/presence.py +++ b/interactions/api/models/presence.py @@ -181,3 +181,7 @@ def __attrs_post_init__(self): # If since is not provided by the developer... self.since = int(time.time() * 1000) if self.status == "idle" else 0 self._json["since"] = self.since + if not self._json.get("afk"): + self.afk = self._json["afk"] = False + if not self._json.get("activities"): + self.activities = self._json["activities"] = [] diff --git a/interactions/client/bot.py b/interactions/client/bot.py index 6161b659b..06079a4b6 100644 --- a/interactions/client/bot.py +++ b/interactions/client/bot.py @@ -390,10 +390,42 @@ async def _ready(self) -> None: log.debug("Client is now ready.") await self._login() + async def _stop(self) -> None: + """Stops the websocket connection gracefully.""" + + log.debug("Shutting down the client....") + self._websocket.ready.clear() # Clears ready state. + self._websocket._closing_lock.set() # Toggles the "ready-to-shutdown" state for the bot. + # And subsequently, the processes will close itself. + + await self._http._req._session.close() # Closes the HTTP session associated with the client. + async def _login(self) -> None: """Makes a login with the Discord API.""" - while not self._websocket._closed: - await self._websocket._establish_connection(self._shards, self._presence) + + try: + await self._websocket.run() + except Exception: + log.exception("Websocket have raised an exception, closing.") + + if self._websocket._closing_lock.is_set(): + # signal for closing. + + try: + if self._websocket._task is not None: + self._websocket.__heartbeat_event.set() + try: + # Wait for the keep-alive handler to finish so we can discard it gracefully + await self._websocket._task + finally: + self._websocket._task = None + finally: # then the overall WS client + if self._websocket._client is not None: + # This needs to be properly closed + try: + await self._websocket._client.close(code=1000) + finally: + self._websocket._client = None async def wait_until_ready(self) -> None: """Helper method that waits until the websocket is ready."""