diff --git a/src/codegate/dashboard/post_processing.py b/src/codegate/dashboard/post_processing.py index 9d5bbeb6..2ffa841e 100644 --- a/src/codegate/dashboard/post_processing.py +++ b/src/codegate/dashboard/post_processing.py @@ -1,7 +1,8 @@ import asyncio import json import re -from typing import List, Optional, Tuple, Union +from collections import defaultdict +from typing import List, Optional, Union import structlog @@ -9,7 +10,8 @@ AlertConversation, ChatMessage, Conversation, - PartialConversation, + PartialQuestionAnswer, + PartialQuestions, QuestionAnswer, ) from codegate.db.models import GetAlertsWithPromptAndOutputRow, GetPromptWithOutputsRow @@ -74,60 +76,57 @@ async def parse_request(request_str: str) -> Optional[str]: return None # Only respond with the latest message - return messages[-1] + return messages -async def parse_output(output_str: str) -> Tuple[Optional[str], Optional[str]]: +async def parse_output(output_str: str) -> Optional[str]: """ - Parse the output string from the pipeline and return the message and chat_id. + Parse the output string from the pipeline and return the message. """ try: if output_str is None: - return None, None + return None output = json.loads(output_str) except Exception as e: logger.warning(f"Error parsing output: {output_str}. {e}") - return None, None + return None def _parse_single_output(single_output: dict) -> str: - single_chat_id = single_output.get("id") single_output_message = "" for choice in single_output.get("choices", []): if not isinstance(choice, dict): continue content_dict = choice.get("delta", {}) or choice.get("message", {}) single_output_message += content_dict.get("content", "") - return single_output_message, single_chat_id + return single_output_message full_output_message = "" - chat_id = None if isinstance(output, list): for output_chunk in output: - output_message, output_chat_id = "", None + output_message = "" if isinstance(output_chunk, dict): - output_message, output_chat_id = _parse_single_output(output_chunk) + output_message = _parse_single_output(output_chunk) elif isinstance(output_chunk, str): try: output_decoded = json.loads(output_chunk) - output_message, output_chat_id = _parse_single_output(output_decoded) + output_message = _parse_single_output(output_decoded) except Exception: logger.error(f"Error reading chunk: {output_chunk}") else: logger.warning( f"Could not handle output: {output_chunk}", out_type=type(output_chunk) ) - chat_id = chat_id or output_chat_id full_output_message += output_message elif isinstance(output, dict): - full_output_message, chat_id = _parse_single_output(output) + full_output_message = _parse_single_output(output) - return full_output_message, chat_id + return full_output_message async def _get_question_answer( row: Union[GetPromptWithOutputsRow, GetAlertsWithPromptAndOutputRow] -) -> Tuple[Optional[QuestionAnswer], Optional[str]]: +) -> Optional[PartialQuestionAnswer]: """ Parse a row from the get_prompt_with_outputs query and return a PartialConversation @@ -137,17 +136,19 @@ async def _get_question_answer( request_task = tg.create_task(parse_request(row.request)) output_task = tg.create_task(parse_output(row.output)) - request_msg_str = request_task.result() - output_msg_str, chat_id = output_task.result() + request_user_msgs = request_task.result() + output_msg_str = output_task.result() - # If we couldn't parse the request or output, return None - if not request_msg_str: - return None, None + # If we couldn't parse the request, return None + if not request_user_msgs: + return None - request_message = ChatMessage( - message=request_msg_str, + request_message = PartialQuestions( + messages=request_user_msgs, timestamp=row.timestamp, message_id=row.id, + provider=row.provider, + type=row.type, ) if output_msg_str: output_message = ChatMessage( @@ -157,28 +158,7 @@ async def _get_question_answer( ) else: output_message = None - chat_id = row.id - return QuestionAnswer(question=request_message, answer=output_message), chat_id - - -async def parse_get_prompt_with_output( - row: GetPromptWithOutputsRow, -) -> Optional[PartialConversation]: - """ - Parse a row from the get_prompt_with_outputs query and return a PartialConversation - - The row contains the raw request and output strings from the pipeline. - """ - question_answer, chat_id = await _get_question_answer(row) - if not question_answer or not chat_id: - return None - return PartialConversation( - question_answer=question_answer, - provider=row.provider, - type=row.type, - chat_id=chat_id, - request_timestamp=row.timestamp, - ) + return PartialQuestionAnswer(partial_questions=request_message, answer=output_message) def parse_question_answer(input_text: str) -> str: @@ -195,50 +175,135 @@ def parse_question_answer(input_text: str) -> str: return input_text +def _group_partial_messages(pq_list: List[PartialQuestions]) -> List[List[PartialQuestions]]: + """ + A PartialQuestion is an object that contains several user messages provided from a + chat conversation. Example: + - PartialQuestion(messages=["Hello"], timestamp=2022-01-01T00:00:00Z) + - PartialQuestion(messages=["Hello", "How are you?"], timestamp=2022-01-01T00:00:01Z) + In the above example both PartialQuestions are part of the same conversation and should be + matched together. + Group PartialQuestions objects such that: + - If one PartialQuestion (pq) is a subset of another pq's messages, group them together. + - If multiple subsets exist for the same superset, choose only the one + closest in timestamp to the superset. + - Leave any unpaired pq by itself. + - Finally, sort the resulting groups by the earliest timestamp in each group. + """ + # 1) Sort by length of messages descending (largest/most-complete first), + # then by timestamp ascending for stable processing. + pq_list_sorted = sorted(pq_list, key=lambda x: (-len(x.messages), x.timestamp)) + + used = set() + groups = [] + + # 2) Iterate in order of "largest messages first" + for sup in pq_list_sorted: + if sup.message_id in used: + continue # Already grouped + + # Find all potential subsets of 'sup' that are not yet used + # (If sup's messages == sub's messages, that also counts, because sub ⊆ sup) + possible_subsets = [] + for sub in pq_list_sorted: + if sub.message_id == sup.message_id: + continue + if sub.message_id in used: + continue + if ( + set(sub.messages).issubset(set(sup.messages)) + and sub.provider == sup.provider + and set(sub.messages) != set(sup.messages) + ): + possible_subsets.append(sub) + + # 3) If there are no subsets, this sup stands alone + if not possible_subsets: + groups.append([sup]) + used.add(sup.message_id) + else: + # 4) Group subsets by messages to discard duplicates e.g.: 2 subsets with single 'hello' + subs_group_by_messages = defaultdict(list) + for q in possible_subsets: + subs_group_by_messages[tuple(q.messages)].append(q) + + new_group = [sup] + used.add(sup.message_id) + for subs_same_message in subs_group_by_messages.values(): + # If more than one pick the one subset closest in time to sup + closest_subset = min( + subs_same_message, key=lambda s: abs(s.timestamp - sup.timestamp) + ) + new_group.append(closest_subset) + used.add(closest_subset.message_id) + groups.append(new_group) + + # 5) Sort the groups by the earliest timestamp within each group + groups.sort(key=lambda g: min(pq.timestamp for pq in g)) + return groups + + +def _get_question_answer_from_partial( + partial_question_answer: PartialQuestionAnswer, +) -> QuestionAnswer: + """ + Get a QuestionAnswer object from a PartialQuestionAnswer object. + """ + # Get the last user message as the question + question = ChatMessage( + message=partial_question_answer.partial_questions.messages[-1], + timestamp=partial_question_answer.partial_questions.timestamp, + message_id=partial_question_answer.partial_questions.message_id, + ) + + return QuestionAnswer(question=question, answer=partial_question_answer.answer) + + async def match_conversations( - partial_conversations: List[Optional[PartialConversation]], + partial_question_answers: List[Optional[PartialQuestionAnswer]], ) -> List[Conversation]: """ Match partial conversations to form a complete conversation. """ - convers = {} - for partial_conversation in partial_conversations: - if not partial_conversation: - continue - - # Group by chat_id - if partial_conversation.chat_id not in convers: - convers[partial_conversation.chat_id] = [] - convers[partial_conversation.chat_id].append(partial_conversation) + valid_partial_qas = [ + partial_qas for partial_qas in partial_question_answers if partial_qas is not None + ] + grouped_partial_questions = _group_partial_messages( + [partial_qs_a.partial_questions for partial_qs_a in valid_partial_qas] + ) - # Sort by timestamp - sorted_convers = { - chat_id: sorted(conversations, key=lambda x: x.request_timestamp) - for chat_id, conversations in convers.items() - } # Create the conversation objects conversations = [] - for chat_id, sorted_convers in sorted_convers.items(): + for group in grouped_partial_questions: questions_answers = [] - first_partial_conversation = None - for partial_conversation in sorted_convers: + first_partial_qa = None + for partial_question in sorted(group, key=lambda x: x.timestamp): + # Partial questions don't contain the answer, so we need to find the corresponding + selected_partial_qa = None + for partial_qa in valid_partial_qas: + if partial_question.message_id == partial_qa.partial_questions.message_id: + selected_partial_qa = partial_qa + break + # check if we have an answer, otherwise do not add it - if partial_conversation.question_answer.answer is not None: - first_partial_conversation = partial_conversation - partial_conversation.question_answer.question.message = parse_question_answer( - partial_conversation.question_answer.question.message + if selected_partial_qa.answer is not None: + # if we don't have a first question, set it + first_partial_qa = first_partial_qa or selected_partial_qa + question_answer = _get_question_answer_from_partial(selected_partial_qa) + question_answer.question.message = parse_question_answer( + question_answer.question.message ) - questions_answers.append(partial_conversation.question_answer) + questions_answers.append(question_answer) # only add conversation if we have some answers - if len(questions_answers) > 0 and first_partial_conversation is not None: + if len(questions_answers) > 0 and first_partial_qa is not None: conversations.append( Conversation( question_answers=questions_answers, - provider=first_partial_conversation.provider, - type=first_partial_conversation.type, - chat_id=chat_id, - conversation_timestamp=sorted_convers[0].request_timestamp, + provider=first_partial_qa.partial_questions.provider, + type=first_partial_qa.partial_questions.type, + chat_id=first_partial_qa.partial_questions.message_id, + conversation_timestamp=first_partial_qa.partial_questions.timestamp, ) ) @@ -254,10 +319,10 @@ async def parse_messages_in_conversations( # Parse the prompts and outputs in parallel async with asyncio.TaskGroup() as tg: - tasks = [tg.create_task(parse_get_prompt_with_output(row)) for row in prompts_outputs] - partial_conversations = [task.result() for task in tasks] + tasks = [tg.create_task(_get_question_answer(row)) for row in prompts_outputs] + partial_question_answers = [task.result() for task in tasks] - conversations = await match_conversations(partial_conversations) + conversations = await match_conversations(partial_question_answers) return conversations @@ -269,15 +334,17 @@ async def parse_row_alert_conversation( The row contains the raw request and output strings from the pipeline. """ - question_answer, chat_id = await _get_question_answer(row) - if not question_answer or not chat_id: + partial_qa = await _get_question_answer(row) + if not partial_qa: return None + question_answer = _get_question_answer_from_partial(partial_qa) + conversation = Conversation( question_answers=[question_answer], provider=row.provider, type=row.type, - chat_id=chat_id or "chat-id-not-found", + chat_id=row.id, conversation_timestamp=row.timestamp, ) code_snippet = json.loads(row.code_snippet) if row.code_snippet else None diff --git a/src/codegate/dashboard/request_models.py b/src/codegate/dashboard/request_models.py index 8f13a03c..d36d9391 100644 --- a/src/codegate/dashboard/request_models.py +++ b/src/codegate/dashboard/request_models.py @@ -25,16 +25,25 @@ class QuestionAnswer(BaseModel): answer: Optional[ChatMessage] -class PartialConversation(BaseModel): +class PartialQuestions(BaseModel): """ - Represents a partial conversation obtained from a DB row. + Represents all user messages obtained from a DB row. """ - question_answer: QuestionAnswer + messages: List[str] + timestamp: datetime.datetime + message_id: str provider: Optional[str] type: str - chat_id: str - request_timestamp: datetime.datetime + + +class PartialQuestionAnswer(BaseModel): + """ + Represents a partial conversation. + """ + + partial_questions: PartialQuestions + answer: Optional[ChatMessage] class Conversation(BaseModel): diff --git a/tests/dashboard/test_post_processing.py b/tests/dashboard/test_post_processing.py index cbdb18a5..aa35cff2 100644 --- a/tests/dashboard/test_post_processing.py +++ b/tests/dashboard/test_post_processing.py @@ -5,16 +5,14 @@ import pytest from codegate.dashboard.post_processing import ( + _get_question_answer, + _group_partial_messages, _is_system_prompt, - parse_get_prompt_with_output, parse_output, parse_request, ) from codegate.dashboard.request_models import ( - ChatMessage, - Conversation, - PartialConversation, - QuestionAnswer, + PartialQuestions, ) from codegate.db.models import GetPromptWithOutputsRow @@ -37,11 +35,11 @@ async def test_is_system_prompt(message, expected_bool): @pytest.mark.asyncio @pytest.mark.parametrize( - "request_dict, expected_str", + "request_dict, expected_str_list", [ ( {"messages": [{"role": "user", "content": "Hello, how can I help you?"}]}, - "Hello, how can I help you?", + ["Hello, how can I help you?"], ), ( { @@ -61,7 +59,7 @@ async def test_is_system_prompt(message, expected_bool): {"role": "user", "content": "Hello, latest"}, ] }, - "Hello, latest", + ["Hello, how can I help you?", "Hello, latest"], ), ( { @@ -72,20 +70,20 @@ async def test_is_system_prompt(message, expected_bool): }, ] }, - "Hello, how can I help you?", + ["Hello, how can I help you?"], ), - ({"prompt": "Hello, how can I help you?"}, "Hello, how can I help you?"), + ({"prompt": "Hello, how can I help you?"}, ["Hello, how can I help you?"]), ], ) -async def test_parse_request(request_dict, expected_str): +async def test_parse_request(request_dict, expected_str_list): request_str = json.dumps(request_dict) result = await parse_request(request_str) - assert result == expected_str + assert result == expected_str_list @pytest.mark.asyncio @pytest.mark.parametrize( - "output_dict, expected_str, expected_chat_id", + "output_dict, expected_str", [ ( [ # Stream output with multiple chunks @@ -115,7 +113,6 @@ async def test_parse_request(request_dict, expected_str): }, ], "Hello world", - "chatcmpl-AaQw9O1O2u360mhba5UbMoPwFgqEl", ), ( { @@ -133,24 +130,21 @@ async def test_parse_request(request_dict, expected_str): ], }, "User seeks", - "chatcmpl-AaQw9O1O2u360mhba5UbMoPwFgqEa", ), ], ) -async def test_parse_output(output_dict, expected_str, expected_chat_id): +async def test_parse_output(output_dict, expected_str): request_str = json.dumps(output_dict) - output_message, chat_id = await parse_output(request_str) + output_message = await parse_output(request_str) assert output_message == expected_str - assert chat_id == expected_chat_id timestamp_now = datetime.datetime.now(datetime.timezone.utc) @pytest.mark.asyncio -@pytest.mark.parametrize("request_msg_str", ["Hello", None]) +@pytest.mark.parametrize("request_msg_list", [["Hello"], None]) @pytest.mark.parametrize("output_msg_str", ["Hello, how can I help you?", None]) -@pytest.mark.parametrize("chat_id", ["chatcmpl-AaQw9O1O2u360mhba5UbMoPwFgqEl"]) @pytest.mark.parametrize( "row", [ @@ -166,7 +160,7 @@ async def test_parse_output(output_dict, expected_str, expected_chat_id): ) ], ) -async def test_parse_get_prompt_with_output(request_msg_str, output_msg_str, chat_id, row): +async def test_get_question_answer(request_msg_list, output_msg_str, row): with patch( "codegate.dashboard.post_processing.parse_request", new_callable=AsyncMock ) as mock_parse_request: @@ -174,67 +168,239 @@ async def test_parse_get_prompt_with_output(request_msg_str, output_msg_str, cha "codegate.dashboard.post_processing.parse_output", new_callable=AsyncMock ) as mock_parse_output: # Set return values for the mocks - mock_parse_request.return_value = request_msg_str - mock_parse_output.return_value = (output_msg_str, chat_id) - result = await parse_get_prompt_with_output(row) + mock_parse_request.return_value = request_msg_list + mock_parse_output.return_value = output_msg_str + result = await _get_question_answer(row) mock_parse_request.assert_called_once() mock_parse_output.assert_called_once() - if request_msg_str is None: + if request_msg_list is None: assert result is None else: - assert result.question_answer.question.message == request_msg_str + assert result.partial_questions.messages == request_msg_list if output_msg_str is not None: - assert result.question_answer.answer.message == output_msg_str - assert result.chat_id == chat_id - assert result.provider == "provider" - assert result.type == "chat" - assert result.request_timestamp == timestamp_now - - -question_answer = QuestionAnswer( - question=ChatMessage( - message="Hello, how can I help you?", - timestamp=timestamp_now, - message_id="1", - ), - answer=ChatMessage( - message="Hello, how can I help you?", - timestamp=timestamp_now, - message_id="2", - ), -) + assert result.answer.message == output_msg_str + assert result.partial_questions.provider == "provider" + assert result.partial_questions.type == "chat" -@pytest.mark.asyncio @pytest.mark.parametrize( - "partial_conversations, expected_conversations", + "pq_list,expected_group_ids", [ - ([None], []), # Test empty list + # 1) No subsets: all items stand alone ( [ - None, - PartialConversation( # Test partial conversation with None - question_answer=question_answer, - provider="provider", + PartialQuestions( + messages=["A"], + timestamp=datetime.datetime(2023, 1, 1, 0, 0, 0), + message_id="pq1", + provider="providerA", + type="chat", + ), + PartialQuestions( + messages=["B"], + timestamp=datetime.datetime(2023, 1, 1, 0, 0, 1), + message_id="pq2", + provider="providerA", type="chat", - chat_id="chat_id", - request_timestamp=timestamp_now, ), ], + [["pq1"], ["pq2"]], + ), + # 2) Single subset: one is a subset of the other + # - "Hello" is a subset of "Hello, how are you?" + ( [ - Conversation( - question_answers=[question_answer], - provider="provider", + PartialQuestions( + messages=["Hello"], + timestamp=datetime.datetime(2022, 1, 1, 0, 0, 0), + message_id="pq1", + provider="providerA", type="chat", - chat_id="chat_id", - conversation_timestamp=timestamp_now, - ) + ), + PartialQuestions( + messages=["Hello", "How are you?"], + timestamp=datetime.datetime(2022, 1, 1, 0, 0, 10), + message_id="pq2", + provider="providerA", + type="chat", + ), ], + [["pq1", "pq2"]], + ), + # 3) Multiple identical subsets: + # We have 3 partial questions with messages=["Hello"], + # plus a superset with messages=["Hello", "Bye"]. + # Only the single subset that is closest in timestamp to the superset is grouped with the + # superset. + ( + [ + PartialQuestions( + messages=["Hello"], + timestamp=datetime.datetime(2023, 1, 1, 10, 0, 0), + message_id="pq1", + provider="providerA", + type="chat", + ), + PartialQuestions( + messages=["Hello"], + timestamp=datetime.datetime(2023, 1, 1, 11, 0, 0), + message_id="pq2", + provider="providerA", + type="chat", + ), + PartialQuestions( + messages=["Hello"], + timestamp=datetime.datetime(2023, 1, 1, 12, 0, 0), + message_id="pq3", + provider="providerA", + type="chat", + ), + PartialQuestions( + messages=["Hello", "Bye"], + timestamp=datetime.datetime(2023, 1, 1, 11, 0, 5), + message_id="pq4", + provider="providerA", + type="chat", + ), + ], + # pq4 is the superset => subsets are pq1, pq2, pq3. + # The closest subset to pq4(11:00:05) is pq2(11:00:00). + # So group = [pq2, pq4]. + # The other two remain alone in their own group. + # The final sorted order is by earliest timestamp in each group: + # group with pq1 => [pq1], earliest 10:00:00 + # group with pq2, pq4 => earliest 11:00:00 + # group with pq3 => earliest 12:00:00 + [["pq1"], ["pq2", "pq4"], ["pq3"]], + ), + # 4) Mixed: multiple subsets, multiple supersets, verifying group logic + ( + [ + # Superset + PartialQuestions( + messages=["hi", "welcome", "bye"], + timestamp=datetime.datetime(2023, 5, 1, 9, 0, 0), + message_id="pqS1", + provider="providerB", + type="chat", + ), + # Subsets for pqS1 + PartialQuestions( + messages=["hi", "welcome"], + timestamp=datetime.datetime(2023, 5, 1, 9, 0, 5), + message_id="pqA1", + provider="providerB", + type="chat", + ), + PartialQuestions( + messages=["hi", "bye"], + timestamp=datetime.datetime(2023, 5, 1, 9, 0, 10), + message_id="pqA2", + provider="providerB", + type="chat", + ), + PartialQuestions( + messages=["hi", "bye"], + timestamp=datetime.datetime(2023, 5, 1, 9, 0, 12), + message_id="pqA3", + provider="providerB", + type="chat", + ), + # Another superset + PartialQuestions( + messages=["apple", "banana", "cherry"], + timestamp=datetime.datetime(2023, 5, 2, 10, 0, 0), + message_id="pqS2", + provider="providerB", + type="chat", + ), + # Subsets for pqS2 + PartialQuestions( + messages=["banana"], + timestamp=datetime.datetime(2023, 5, 2, 10, 0, 1), + message_id="pqB1", + provider="providerB", + type="chat", + ), + PartialQuestions( + messages=["apple", "banana"], + timestamp=datetime.datetime(2023, 5, 2, 10, 0, 3), + message_id="pqB2", + provider="providerB", + type="chat", + ), + # Another item alone, not a subset nor superset + PartialQuestions( + messages=["xyz"], + timestamp=datetime.datetime(2023, 5, 3, 8, 0, 0), + message_id="pqC1", + provider="providerB", + type="chat", + ), + # Different provider => should remain separate + PartialQuestions( + messages=["hi", "welcome"], + timestamp=datetime.datetime(2023, 5, 1, 9, 0, 10), + message_id="pqProvDiff", + provider="providerX", + type="chat", + ), + ], + # Expected: + # For pqS1 (["hi","welcome","bye"]) => subsets are pqA1(["hi","welcome"]), + # pqA2 & pqA3 (["hi","bye"]) + # Among pqA2 and pqA3, we pick the one closest in time to 09:00:00 => + # that is pqA2(09:00:10) vs pqA3(09:00:12). + # The absolute difference: + # pqA2 => 10 seconds + # pqA3 => 12 seconds + # So we pick pqA2. Group => [pqS1, pqA1, pqA2] + # + # For pqS2 (["apple","banana","cherry"]) => subsets are pqB1(["banana"]), + # pqB2(["apple","banana"]) + # Among them, we group them all (because they have distinct messages). + # So => [pqS2, pqB1, pqB2] + # + # pqC1 stands alone => ["pqC1"] + # pqProvDiff stands alone => ["pqProvDiff"] because provider is different + # + # Then we sort by earliest timestamp in each group: + # group with pqS1 => earliest is 09:00:00 + # group with pqProvDiff => earliest is 09:00:10 + # group with pqS2 => earliest is 10:00:00 + # group with pqC1 => earliest is 08:00:00 on 5/3 => actually this is the last date, + # so let's see: + # 2023-05-01 is earlier than 2023-05-02, which is earlier than 2023-05-03. + # Actually, 2023-05-03 is later. So "pqC1" group is last in chronological order. + # + # Correct chronological order of earliest timestamps: + # 1) [pqS1, pqA1, pqA2] => earliest 2023-05-01 09:00:00 + # 2) [pqProvDiff] => earliest 2023-05-01 09:00:10 + # 3) [pqS2, pqB1, pqB2] => earliest 2023-05-02 10:00:00 + # 4) [pqC1] => earliest 2023-05-03 08:00:00 + [["pqS1", "pqA1", "pqA2"], ["pqProvDiff"], ["pqS2", "pqB1", "pqB2"], ["pqC1"]], ), ], ) -async def match_conversations(partial_conversations, expected_conversations): - result_conversations = await match_conversations(partial_conversations) - assert result_conversations == expected_conversations +def test_group_partial_messages(pq_list, expected_group_ids): + """ + Verify that _group_partial_messages produces the correct grouping + (by message_id) in the correct order. + """ + # Execute + grouped = _group_partial_messages(pq_list) + + # Convert from list[list[PartialQuestions]] -> list[list[str]] + # so we can compare with expected_group_ids easily. + grouped_ids = [[pq.message_id for pq in group] for group in grouped] + + is_matched = False + print(grouped_ids) + for group_id in grouped_ids: + for expected_group in expected_group_ids: + if set(group_id) == set(expected_group): + is_matched = True + break + assert is_matched