Skip to content
Merged
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 61 additions & 27 deletions interactions/client/context.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import asyncio
from datetime import datetime
from logging import Logger
from typing import TYPE_CHECKING, List, Optional, Tuple, Union

from ..api.error import LibraryException
from ..api.models.channel import Channel
from ..api.models.channel import Channel, Thread
from ..api.models.flags import MessageFlags, Permissions
from ..api.models.guild import Guild
from ..api.models.member import Member
Expand Down Expand Up @@ -41,11 +43,8 @@ class _Context(ClientSerializerMixin):
"""

message: Optional[Message] = field(converter=Message, default=None, add_client=True)
author: Member = field(converter=Member, default=None, add_client=True)
member: Member = field(converter=Member, add_client=True)
member: Member = field(default=None, converter=Member, add_client=True) # DMs?
user: User = field(converter=User, default=None, add_client=True)
channel: Optional[Channel] = field(converter=Channel, default=None, add_client=True)
guild: Optional[Guild] = field(converter=Guild, default=None, add_client=True)
id: Snowflake = field(converter=Snowflake)
application_id: Snowflake = field(converter=Snowflake)
type: InteractionType = field(converter=InteractionType)
Expand All @@ -63,6 +62,16 @@ class _Context(ClientSerializerMixin):
locale: Optional[Locale] = field(converter=Locale, default=None)
guild_locale: Optional[Locale] = field(converter=Locale, default=None)

def __attrs_post_init__(self) -> None:
if self.member and self.guild_id:
self.member._extras["guild_id"] = self.guild_id

if self.user is None:
self.user = self.member.user if self.member else None

if self.member and not self.member.user and self.user:
self.member.user = self.user

@property
def deferred_ephemeral(self) -> bool:
"""
Expand All @@ -75,48 +84,79 @@ def deferred_ephemeral(self) -> bool:
and self.message.flags & MessageFlags.LOADING
)

def __attrs_post_init__(self) -> None:
if self.member and self.guild_id:
self.member._extras["guild_id"] = self.guild_id
@property
def created_at(self) -> datetime:
"""
.. versionadded:: 4.4.0

self.author = self.member
Returns when the interaction was created.
"""
return self.id.timestamp

if self.user is None:
self.user = self.member.user if self.member else None
@property
def author(self) -> Member:
"""
Returns the author/member that invoked the interaction.
"""
return self.member

@property
def channel(self) -> Optional[Channel]:
"""
.. versionadded:: 4.1.0
.. versionchanged:: 4.4.0
Channel now returns ``None`` instead of ``MISSING`` if it is not found to avoid confusion

if self.guild is None and self.guild_id is not None:
self.guild = self._client.cache[Guild].get(self.guild_id, MISSING)
Returns the current channel, if cached.
"""
return self._client.cache[Channel].get(self.channel_id, None) or self._client.cache[
Thread
].get(self.channel_id, None)

@property
def guild(self) -> Optional[Guild]:
"""
.. versionadded:: 4.1.0
.. versionchanged:: 4.4.0
Guild now returns ``None`` instead of ``MISSING`` if it is not found to avoid confusion

Returns the current guild, if cached.
"""

if self.channel is None:
self.channel = self._client.cache[Channel].get(self.channel_id, MISSING)
return self._client.cache[Guild].get(self.guild_id, None)

async def get_channel(self) -> Channel:
"""
.. versionadded:: 4.1.0

This gets the channel the context was invoked in.
This gets the channel the context was invoked in. If the channel is not cached, an HTTP request is made.

:return: The channel as object
:rtype: Channel
"""
if channel := self.channel:
await asyncio.sleep(0)
return channel

res = await self._client.get_channel(int(self.channel_id))
self.channel = Channel(**res, _client=self._client)
return self.channel
return Channel(**res, _client=self._client)

async def get_guild(self) -> Guild:
"""
.. versionadded:: 4.1.0

This gets the guild the context was invoked in.
This gets the guild the context was invoked in. If the guild is not cached, an HTTP request is made.

:return: The guild as object
:rtype: Guild
"""

if guild := self.guild:
await asyncio.sleep(0)
return guild

res = await self._client.get_guild(int(self.guild_id))
self.guild = Guild(**res, _client=self._client)
return self.guild
return Guild(**res, _client=self._client)

async def send(
self,
Expand Down Expand Up @@ -389,10 +429,7 @@ class CommandContext(_Context):
:ivar str token: The token of the interaction response.
:ivar Snowflake guild_id: The ID of the current guild.
:ivar Snowflake channel_id: The ID of the current channel.
:ivar Member author: The member data model.
:ivar User user: The user data model.
:ivar Optional[Channel] channel: The channel data model.
:ivar Optional[Guild] guild: The guild data model.
:ivar bool responded: Whether an original response was made or not.
:ivar bool deferred: Whether the response was deferred or not.
:ivar Optional[Locale] locale: The selected language of the user invoking the interaction.
Expand Down Expand Up @@ -661,10 +698,7 @@ class ComponentContext(_Context):
:ivar Snowflake guild_id: The ID of the current guild.
:ivar Snowflake channel_id: The ID of the current channel.
:ivar Optional[Message] message: The message data model.
:ivar Member author: The member data model.
:ivar User user: The user data model.
:ivar Optional[Channel] channel: The channel data model.
:ivar Optional[Guild] guild: The guild data model.
:ivar bool responded: Whether an original response was made or not.
:ivar bool deferred: Whether the response was deferred or not.
:ivar str locale: The selected language of the user invoking the interaction.
Expand Down