Skip to content

Commit 81d1173

Browse files
committed
Implement initial multiturn support
1 parent 2e74a1b commit 81d1173

File tree

2 files changed

+43
-10
lines changed

2 files changed

+43
-10
lines changed

src/guidellm/request/loader.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,13 +107,20 @@ def __init__(
107107
self._preserved_iter = None
108108

109109
def __iter__(self) -> Iterator[GenerativeRequestSession]:
    """Yield request sessions, each bundling ``turns`` consecutive requests.

    Requests are drawn lazily from :meth:`_create_requests`; iteration stops
    once the underlying request iterator is exhausted.
    """
    # Local import so this method stays self-contained within the diff.
    from itertools import islice

    # TODO: make the number of turns per session configurable; a single
    # turn reproduces the previous one-request-per-session behavior.
    turns = 1

    data_iter = self._create_requests()
    # islice consumes exactly `turns` requests per session; the walrus loop
    # exits when the iterator yields nothing (empty list is falsy).
    while requests := list(islice(data_iter, turns)):
        yield GenerativeRequestSession(requests)
116+
def _create_requests(self) -> Iterator[GenerationRequest]:
    """Lazily produce one ``GenerationRequest`` per dataset item.

    Repeatedly asks for a fresh dataset iterator (tracking how many scopes
    have been created) and drains it, clearing the preserved iterator once
    a scope is fully consumed. Stops when no further iterator is available.
    """
    created_scopes = 0

    while True:
        dataset_iter = self._get_dataset_iter(created_scopes)
        if dataset_iter is None:
            break
        created_scopes += 1

        for item in dataset_iter:
            yield self._create_request(item)

        # The scope was fully drained; drop any preserved iterator state.
        self._preserved_iter = None
119126

src/guidellm/request/session.py

Lines changed: 35 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import itertools
12
from abc import ABC, abstractmethod
23
from typing import Generic, TypeVar
34

@@ -29,24 +30,49 @@ def push_response(self, response: ResponseT) -> None: ...
2930
def complete(self) -> bool: ...
3031

3132

32-
# FIXME: Bad implementation. Can only handle string requests
class GenerativeRequestSession(RequestSession[GenerationRequest, ResponseSummary]):
    """Multiturn session over a fixed list of prompts.

    Each turn replays the whole conversation so far — prompt/response pairs
    joined into one string — plus the next prompt, so the backend receives
    full context on every request. Token stats and output constraints are
    accumulated across turns accordingly.
    """

    def __init__(self, prompts: list[GenerationRequest]) -> None:
        if not prompts:
            raise ValueError("Prompts cannot be empty")

        self.prompts = prompts
        # Completed turns; responses[i] is the answer to prompts[i].
        self.responses: list[str] = []

    def __len__(self) -> int:
        # Session length is the total number of turns, not completed ones.
        return len(self.prompts)

    def get_next_request(self) -> GenerationRequest:
        """Build the request for the next (first unanswered) turn.

        The content is the interleaving
        ``prompt[0] + response[0] + ... + prompt[k]`` and the prompt-token
        stats / output-token constraints are cumulative sums over turns
        ``0..k``.

        Raises:
            ValueError: if every prompt already has a response.
        """
        completed = len(self.responses)
        if completed >= len(self.prompts):
            # Explicit misuse error, consistent with push_response;
            # previously this surfaced as an opaque IndexError below.
            raise ValueError("Session complete; no next request")

        base_request = self.prompts[completed].model_copy(deep=True)

        # Interleave prompts with responses; the trailing "" pairs with the
        # current prompt so zip() keeps it (zip stops at the shorter side,
        # truncating the not-yet-reached prompts).
        base_request.content = "".join(
            itertools.chain.from_iterable(
                zip((x.content for x in self.prompts), self.responses + [""])
            )
        )

        # Cumulative token accounting over every turn up to and including
        # the current one (slice hoisted — it was computed twice before).
        turns_so_far = self.prompts[: completed + 1]
        base_request.stats["prompt_tokens"] = sum(
            x.stats["prompt_tokens"] for x in turns_so_far
        )
        base_request.constraints["output_tokens"] = sum(
            x.constraints["output_tokens"] for x in turns_so_far
        )

        return base_request

    def get_next_delay(self) -> float:
        # No think-time between turns (yet).
        return 0.0

    def push_response(self, response: ResponseSummary) -> None:
        """Record the response for the current turn.

        Raises:
            ValueError: if every prompt already has a response.
        """
        if len(self.responses) >= len(self.prompts):
            raise ValueError("Response list full")

        # Pin the observed output length on this turn's prompt so the
        # cumulative output_tokens constraint in later turns reflects what
        # the backend actually generated.
        if response.response_output_tokens is not None:
            self.prompts[len(self.responses)].constraints["output_tokens"] = (
                response.response_output_tokens
            )
        self.responses.append(response.value)

    @property
    def complete(self) -> bool:
        # Done once every prompt has a recorded response.
        return len(self.responses) >= len(self.prompts)

0 commit comments

Comments (0)