diff --git a/livekit-rtc/livekit/rtc/chat.py b/livekit-rtc/livekit/rtc/chat.py index 22e4b0c8..f779dae3 100644 --- a/livekit-rtc/livekit/rtc/chat.py +++ b/livekit-rtc/livekit/rtc/chat.py @@ -12,18 +12,28 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations from dataclasses import dataclass, field from datetime import datetime import json import logging from typing import Any, Dict, Literal, Optional - -from .room import Room, Participant, DataPacket +import asyncio + +from .room import ( + Room, + Participant, + DataPacket, + TextStreamReader, + RemoteParticipant, + LocalParticipant, +) from .event_emitter import EventEmitter from ._utils import generate_random_base62 -_CHAT_TOPIC = "lk-chat-topic" -_CHAT_UPDATE_TOPIC = "lk-chat-update-topic" +_LEGACY_CHAT_TOPIC = "lk-chat-topic" +_LEGACY_CHAT_UPDATE_TOPIC = "lk-chat-update-topic" +_DS_CHAT_TOPIC = "lk.chat" EventTypes = Literal["message_received",] @@ -38,11 +48,17 @@ def __init__(self, room: Room): super().__init__() self._lp = room.local_participant self._room = room + self._tasks: list[asyncio.Task] = [] room.on("data_received", self._on_data_received) + room.unregister_text_stream_handler(_DS_CHAT_TOPIC) + room.register_text_stream_handler(_DS_CHAT_TOPIC, self._on_text_stream_received) def close(self): self._room.off("data_received", self._on_data_received) + self._room.unregister_text_stream_handler(_DS_CHAT_TOPIC) + for task in self._tasks: + task.cancel() async def send_message(self, message: str) -> "ChatMessage": """Send a chat message to the end user using LiveKit Chat Protocol. @@ -58,9 +74,12 @@ async def send_message(self, message: str) -> "ChatMessage": is_local=True, participant=self._lp, ) + msg_dict = msg.asjsondict() + msg_dict["ignoreLegacy"] = True + await self._lp.send_text(message, topic=_DS_CHAT_TOPIC) await self._lp.publish_data( - payload=json.dumps(msg.asjsondict()), - topic=_CHAT_TOPIC, + payload=json.dumps(msg_dict), + topic=_LEGACY_CHAT_TOPIC, ) return msg @@ -72,15 +91,18 @@ async def update_message(self, message: "ChatMessage"): """ await self._lp.publish_data( payload=json.dumps(message.asjsondict()), - topic=_CHAT_UPDATE_TOPIC, + topic=_LEGACY_CHAT_UPDATE_TOPIC, ) def _on_data_received(self, dp: DataPacket): # handle both new and updates the same way, as long as the ID is in there # the user can decide how to replace the previous message - if dp.topic == _CHAT_TOPIC or dp.topic == _CHAT_UPDATE_TOPIC: + if dp.topic == _LEGACY_CHAT_TOPIC or dp.topic == _LEGACY_CHAT_UPDATE_TOPIC: try: parsed = json.loads(dp.data) + # if the message is marked as ignoreLegacy, we'll skip it + if parsed.get("ignoreLegacy"): + return msg = ChatMessage.from_jsondict(parsed) if dp.participant: msg.participant = dp.participant @@ -88,6 +110,31 @@ def _on_data_received(self, dp: DataPacket): except Exception as e: logging.warning("failed to parse chat message: %s", e, exc_info=e) + def _on_text_stream_received( + self, stream: TextStreamReader, participant_identity: str + ): + task = asyncio.create_task( + self._handle_text_stream(stream, participant_identity) + ) + self._tasks.append(task) + task.add_done_callback(self._tasks.remove) + + async def _handle_text_stream( + self, stream: TextStreamReader, participant_identity: str + ): + msg = await ChatMessage.from_text_stream(stream) + participant: RemoteParticipant | LocalParticipant | None = ( + self._room._remote_participants.get(participant_identity) + ) + if ( + participant is None + and self._room.local_participant.identity == participant_identity + ): + participant = self._room.local_participant + msg.is_local = True + msg.participant = participant + self.emit("message_received", msg) + @dataclass class ChatMessage: @@ -129,3 +176,13 @@ def asjsondict(self): if self.deleted: d["deleted"] = True return d + + @classmethod + async def from_text_stream(cls, stream: TextStreamReader): + message_text = await stream.read_all() + timestamp = datetime.fromtimestamp(stream.info.timestamp / 1000.0) + return cls( + message=message_text, + timestamp=timestamp, + id=stream.info.stream_id, + )