diff --git a/interactions/client/auto_shard_client.py b/interactions/client/auto_shard_client.py index 935ccaa1f..8eab52ba5 100644 --- a/interactions/client/auto_shard_client.py +++ b/interactions/client/auto_shard_client.py @@ -2,7 +2,7 @@ import time from datetime import datetime from collections import defaultdict -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Optional, List import interactions.api.events as events from interactions.api.events import ShardConnect @@ -35,6 +35,8 @@ def __init__(self, *args, **kwargs) -> None: self.auto_sharding = "total_shards" not in kwargs super().__init__(*args, **kwargs) + self.shard_ids: Optional[List[int]] = kwargs.get("shard_ids", None) + self._connection_state = None self._connection_states: list[ConnectionState] = [] @@ -244,9 +246,13 @@ async def login(self, token: str | None = None) -> None: ) self.logger.debug(f"Starting bot with {self.total_shards} shard{'s' if self.total_shards != 1 else ''}") - self._connection_states: list[ConnectionState] = [ - ConnectionState(self, self.intents, shard_id) for shard_id in range(self.total_shards) - ] + + if self.shard_ids: + self._connection_states = [ConnectionState(self, self.intents, shard_id) for shard_id in self.shard_ids] + else: + self._connection_states = [ + ConnectionState(self, self.intents, shard_id) for shard_id in range(self.total_shards) + ] async def change_presence( self,