diff --git a/README.md b/README.md index d4812d4..a462b2b 100644 --- a/README.md +++ b/README.md @@ -118,14 +118,18 @@ For models that support streaming (particularly language models), you can use `r import replicate for event in replicate.stream( - "meta/meta-llama-3-70b-instruct", + "anthropic/claude-4-sonnet", input={ - "prompt": "Please write a haiku about llamas.", + "prompt": "Give me a recipe for tasty smashed avocado on sourdough toast.", + "max_tokens": 8192, + "system_prompt": "You are a helpful assistant", }, ): print(str(event), end="") ``` +The `stream()` method creates a prediction and returns an iterator that yields output chunks as strings. This is useful for language models where you want to display output as it's generated rather than waiting for the entire response. + ## Async usage Simply import `AsyncReplicate` instead of `Replicate` and use `await` with each API call: @@ -172,7 +176,11 @@ async def main(): # Stream a model's output async for event in replicate.stream( - "meta/meta-llama-3-70b-instruct", input={"prompt": "Write a haiku about coding"} + "anthropic/claude-4-sonnet", + input={ + "prompt": "Write a haiku about coding", + "system_prompt": "You are a helpful assistant", + }, ): print(str(event), end="") diff --git a/src/replicate/__init__.py b/src/replicate/__init__.py index 1cfff56..2e2f286 100644 --- a/src/replicate/__init__.py +++ b/src/replicate/__init__.py @@ -109,7 +109,7 @@ if not __name.startswith("__"): try: # Skip symbols that are imported later from _module_client - if __name in ("run", "use"): + if __name in ("run", "use", "stream"): continue __locals[__name].__module__ = "replicate" except (TypeError, AttributeError): @@ -253,6 +253,7 @@ def _reset_client() -> None: # type: ignore[reportUnusedFunction] use as use, files as files, models as models, + stream as stream, account as account, hardware as hardware, webhooks as webhooks, diff --git a/src/replicate/_client.py b/src/replicate/_client.py index 390a552..ab28379 100644 --- a/src/replicate/_client.py +++ b/src/replicate/_client.py @@ -320,6 +320,36 @@ def use( # TODO: Fix mypy overload matching for streaming parameter return _use(self, ref, hint=hint, streaming=streaming) # type: ignore[call-overload, no-any-return] + def stream( + self, + ref: Union[Model, Version, ModelVersionIdentifier, str], + *, + file_encoding_strategy: Optional["FileEncodingStrategy"] = None, + **params: Unpack[PredictionCreateParamsWithoutVersion], + ) -> Iterator[str]: + """ + Stream output from a model prediction. + + Example: + ```python + for event in client.stream( + "meta/meta-llama-3-70b-instruct", + input={"prompt": "Write a haiku about coding"}, + ): + print(str(event), end="") + ``` + + See `replicate.lib._predictions_stream.stream` for full documentation. + """ + from .lib._predictions_stream import stream + + return stream( + self, + ref, + file_encoding_strategy=file_encoding_strategy, + **params, + ) + def copy( self, *, @@ -695,6 +725,37 @@ def use( # TODO: Fix mypy overload matching for streaming parameter return _use(self, ref, hint=hint, streaming=streaming) # type: ignore[call-overload, no-any-return] + async def stream( + self, + ref: Union[Model, Version, ModelVersionIdentifier, str], + *, + file_encoding_strategy: Optional["FileEncodingStrategy"] = None, + **params: Unpack[PredictionCreateParamsWithoutVersion], + ) -> AsyncIterator[str]: + """ + Stream output from a model prediction asynchronously. 
+ + Example: + ```python + async for event in client.stream( + "meta/meta-llama-3-70b-instruct", + input={"prompt": "Write a haiku about coding"}, + ): + print(str(event), end="") + ``` + + See `replicate.lib._predictions_stream.async_stream` for full documentation. + """ + from .lib._predictions_stream import async_stream + + async for chunk in async_stream( + self, + ref, + file_encoding_strategy=file_encoding_strategy, + **params, + ): + yield chunk + def copy( self, *, diff --git a/src/replicate/_module_client.py b/src/replicate/_module_client.py index a3e8ab4..6b7a1f7 100644 --- a/src/replicate/_module_client.py +++ b/src/replicate/_module_client.py @@ -82,6 +82,7 @@ def __load__(self) -> PredictionsResource: __client: Replicate = cast(Replicate, {}) run = __client.run use = __client.use + stream = __client.stream else: def _run(*args, **kwargs): @@ -100,8 +101,12 @@ def _use(ref, *, hint=None, streaming=False, use_async=False, **kwargs): return use(Replicate, ref, hint=hint, streaming=streaming, **kwargs) + def _stream(*args, **kwargs): + return _load_client().stream(*args, **kwargs) + run = _run use = _use + stream = _stream files: FilesResource = FilesResourceProxy().__as_proxied__() models: ModelsResource = ModelsResourceProxy().__as_proxied__() diff --git a/src/replicate/lib/_predictions_stream.py b/src/replicate/lib/_predictions_stream.py new file mode 100644 index 0000000..d087f65 --- /dev/null +++ b/src/replicate/lib/_predictions_stream.py @@ -0,0 +1,163 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Tuple, Union, Iterator, Optional +from collections.abc import AsyncIterator +from typing_extensions import Unpack + +from replicate.lib._files import FileEncodingStrategy +from replicate.types.prediction_create_params import PredictionCreateParamsWithoutVersion + +from ..types import PredictionCreateParams +from ._models import Model, Version, ModelVersionIdentifier, resolve_reference + +if TYPE_CHECKING: + from .._client import Replicate, AsyncReplicate + +_STREAM_DOCSTRING = """ +Stream output from a model prediction. + +This creates a prediction and returns an iterator that yields output chunks +as strings as they become available from the streaming API. + +Args: + ref: Reference to the model or version to run. Can be: + - A string containing a version ID + - A string with owner/name format (e.g. 
"replicate/hello-world") + - A string with owner/name:version format + - A Model instance + - A Version instance + - A ModelVersionIdentifier dictionary + file_encoding_strategy: Strategy for encoding file inputs + **params: Additional parameters including the required "input" dictionary + +Yields: + str: Output chunks from the model as they become available + +Raises: + ValueError: If the reference format is invalid + ReplicateError: If the prediction fails or streaming is not available +""" + + +def _resolve_reference( + ref: Union[Model, Version, ModelVersionIdentifier, str], +) -> Tuple[Optional[Version], Optional[str], Optional[str], Optional[str]]: + """Resolve a model reference to its components, with fallback for plain version IDs.""" + try: + return resolve_reference(ref) + except ValueError: + # If resolution fails, treat it as a version ID if it's a string + if isinstance(ref, str): + return None, None, None, ref + else: + raise + + +def stream( + client: "Replicate", + ref: Union[Model, Version, ModelVersionIdentifier, str], + *, + file_encoding_strategy: Optional["FileEncodingStrategy"] = None, + **params: Unpack[PredictionCreateParamsWithoutVersion], +) -> Iterator[str]: + __doc__ = _STREAM_DOCSTRING + _version, owner, name, version_id = _resolve_reference(ref) + + # Create prediction + if version_id is not None: + params_with_version: PredictionCreateParams = {**params, "version": version_id} + prediction = client.predictions.create(file_encoding_strategy=file_encoding_strategy, **params_with_version) + elif owner and name: + prediction = client.models.predictions.create( + file_encoding_strategy=file_encoding_strategy, model_owner=owner, model_name=name, **params + ) + else: + if isinstance(ref, str): + params_with_version = {**params, "version": ref} + prediction = client.predictions.create(file_encoding_strategy=file_encoding_strategy, **params_with_version) + else: + raise ValueError( + f"Invalid reference format: {ref}. Expected a model name ('owner/name'), " + "a version ID, a Model object, a Version object, or a ModelVersionIdentifier." + ) + + # Check if streaming URL is available + if not prediction.urls or not prediction.urls.stream: + raise ValueError("Model does not support streaming. 
+
+    stream_url = prediction.urls.stream
+
+    with client._client.stream(
+        "GET",
+        stream_url,
+        headers={
+            "Accept": "text/event-stream",
+            "Cache-Control": "no-store",
+        },
+        timeout=None,  # No timeout for streaming
+    ) as response:
+        response.raise_for_status()
+
+        # Parse SSE events and yield output chunks
+        decoder = client._make_sse_decoder()
+        for sse in decoder.iter_bytes(response.iter_bytes()):
+            # The SSE data contains the output chunks
+            if sse.data:
+                yield sse.data
+
+
+# Attach the shared docstring; assigning to a local __doc__ inside the function body would not set it.
+stream.__doc__ = _STREAM_DOCSTRING
+
+
+async def async_stream(
+    client: "AsyncReplicate",
+    ref: Union[Model, Version, ModelVersionIdentifier, str],
+    *,
+    file_encoding_strategy: Optional["FileEncodingStrategy"] = None,
+    **params: Unpack[PredictionCreateParamsWithoutVersion],
+) -> AsyncIterator[str]:
+    _version, owner, name, version_id = _resolve_reference(ref)
+
+    # Create prediction
+    if version_id is not None:
+        params_with_version: PredictionCreateParams = {**params, "version": version_id}
+        prediction = await client.predictions.create(
+            file_encoding_strategy=file_encoding_strategy, **params_with_version
+        )
+    elif owner and name:
+        prediction = await client.models.predictions.create(
+            file_encoding_strategy=file_encoding_strategy, model_owner=owner, model_name=name, **params
+        )
+    else:
+        if isinstance(ref, str):
+            params_with_version = {**params, "version": ref}
+            prediction = await client.predictions.create(
+                file_encoding_strategy=file_encoding_strategy, **params_with_version
+            )
+        else:
+            raise ValueError(
+                f"Invalid reference format: {ref}. Expected a model name ('owner/name'), "
+                "a version ID, a Model object, a Version object, or a ModelVersionIdentifier."
+            )
+
+    # Check if streaming URL is available
+    if not prediction.urls or not prediction.urls.stream:
+        raise ValueError("Model does not support streaming. The prediction URLs do not include a stream endpoint.")
+
+    stream_url = prediction.urls.stream
+
+    async with client._client.stream(
+        "GET",
+        stream_url,
+        headers={
+            "Accept": "text/event-stream",
+            "Cache-Control": "no-store",
+        },
+        timeout=None,  # No timeout for streaming
+    ) as response:
+        response.raise_for_status()
+
+        # Parse SSE events and yield output chunks
+        decoder = client._make_sse_decoder()
+        async for sse in decoder.aiter_bytes(response.aiter_bytes()):
+            # The SSE data contains the output chunks
+            if sse.data:
+                yield sse.data
+
+
+async_stream.__doc__ = _STREAM_DOCSTRING
diff --git a/tests/lib/test_stream.py b/tests/lib/test_stream.py
new file mode 100644
index 0000000..2d58172
--- /dev/null
+++ b/tests/lib/test_stream.py
@@ -0,0 +1,223 @@
+from __future__ import annotations
+
+import os
+from typing import Any, Iterator
+
+import httpx
+import pytest
+from respx import MockRouter
+
+from replicate import Replicate, AsyncReplicate
+
+base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
+bearer_token = "My Bearer Token"
+
+
+def create_mock_prediction_json(stream_url: str | None = None) -> dict[str, Any]:
+    """Helper to create a complete prediction JSON response"""
+    prediction: dict[str, Any] = {
+        "id": "test-prediction-id",
+        "created_at": "2023-01-01T00:00:00Z",
+        "data_removed": False,
+        "input": {"prompt": "Test"},
+        "model": "test-model",
+        "output": None,
+        "status": "starting",
+        "version": "test-version-id",
+        "urls": {
+            "get": f"{base_url}/predictions/test-prediction-id",
+            "cancel": f"{base_url}/predictions/test-prediction-id/cancel",
+            "web": "https://replicate.com/p/test-prediction-id",
+        },
+    }
+    if stream_url:
+        prediction["urls"]["stream"] = stream_url  # type: ignore[index]
+    return prediction
+
+
+def test_stream_with_model_owner_name(respx_mock: MockRouter) -> None:
+    """Test streaming with owner/name format"""
+    client = Replicate(base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True)
+
+    # Mock the prediction creation
+    respx_mock.post(f"{base_url}/models/meta/meta-llama-3-70b-instruct/predictions").mock(
+        return_value=httpx.Response(
+            201,
+            json=create_mock_prediction_json(stream_url=f"{base_url}/stream/test-prediction-id"),
+        )
+    )
+
+    # Mock the SSE stream endpoint
+    def stream_content() -> Iterator[bytes]:
+        yield b"data: Hello\n\n"
+        yield b"data:  world\n\n"  # double space: the SSE decoder strips one, so the chunk is " world"
+        yield b"data: !\n\n"
+
+    respx_mock.get(f"{base_url}/stream/test-prediction-id").mock(
+        return_value=httpx.Response(
+            200,
+            headers={"content-type": "text/event-stream"},
+            content=stream_content(),
+        )
+    )
+
+    # Stream the model
+    output: list[str] = []
+    for chunk in client.stream(
+        "meta/meta-llama-3-70b-instruct",
+        input={"prompt": "Say hello"},
+    ):
+        output.append(chunk)
+
+    assert output == ["Hello", " world", "!"]
+
+
+def test_stream_with_version_id(respx_mock: MockRouter) -> None:
+    """Test streaming with version ID"""
+    client = Replicate(base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True)
+    version_id = "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa"
+
+    # Mock the prediction creation
+    respx_mock.post(f"{base_url}/predictions").mock(
+        return_value=httpx.Response(
+            201,
+            json=create_mock_prediction_json(stream_url=f"{base_url}/stream/test-prediction-id"),
+        )
+    )
+
+    # Mock the SSE stream endpoint
+    def stream_content() -> Iterator[bytes]:
+        yield b"data: Test\n\n"
+        yield b"data: output\n\n"
+
+    respx_mock.get(f"{base_url}/stream/test-prediction-id").mock(
+        return_value=httpx.Response(
+            200,
headers={"content-type": "text/event-stream"}, + content=stream_content(), + ) + ) + + # Stream the model + output: list[str] = [] + for chunk in client.stream( + version_id, + input={"prompt": "Test"}, + ): + output.append(chunk) + + assert output == ["Test", "output"] + + +def test_stream_no_stream_url_raises_error(respx_mock: MockRouter) -> None: + """Test that streaming raises an error when model doesn't support streaming""" + client = Replicate(base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True) + + # Mock the prediction creation without stream URL + respx_mock.post(f"{base_url}/models/owner/model/predictions").mock( + return_value=httpx.Response( + 201, + json=create_mock_prediction_json(stream_url=None), + ) + ) + + # Try to stream and expect an error + with pytest.raises(ValueError, match="Model does not support streaming"): + for _ in client.stream("owner/model", input={"prompt": "Test"}): + pass + + +@pytest.mark.asyncio +async def test_async_stream_with_model_owner_name(respx_mock: MockRouter) -> None: + """Test async streaming with owner/name format""" + async_client = AsyncReplicate(base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True) + + # Mock the prediction creation + respx_mock.post(f"{base_url}/models/meta/meta-llama-3-70b-instruct/predictions").mock( + return_value=httpx.Response( + 201, + json=create_mock_prediction_json(stream_url=f"{base_url}/stream/test-prediction-id"), + ) + ) + + # Mock the SSE stream endpoint + async def stream_content(): + yield b"data: Async\n\n" + yield b"data: test\n\n" + + respx_mock.get(f"{base_url}/stream/test-prediction-id").mock( + return_value=httpx.Response( + 200, + headers={"content-type": "text/event-stream"}, + content=stream_content(), + ) + ) + + # Stream the model + output: list[str] = [] + async for chunk in async_client.stream( + "meta/meta-llama-3-70b-instruct", + input={"prompt": "Say hello"}, + ): + output.append(chunk) + + assert output == ["Async", " test"] + + +@pytest.mark.asyncio +async def test_async_stream_no_stream_url_raises_error(respx_mock: MockRouter) -> None: + """Test that async streaming raises an error when model doesn't support streaming""" + async_client = AsyncReplicate(base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True) + + # Mock the prediction creation without stream URL + respx_mock.post(f"{base_url}/models/owner/model/predictions").mock( + return_value=httpx.Response( + 201, + json=create_mock_prediction_json(stream_url=None), + ) + ) + + # Try to stream and expect an error + with pytest.raises(ValueError, match="Model does not support streaming"): + async for _ in async_client.stream("owner/model", input={"prompt": "Test"}): + pass + + +def test_stream_module_level(respx_mock: MockRouter) -> None: + """Test that module-level stream function works""" + import replicate + + # Set up module level client configuration + replicate.base_url = base_url + replicate.bearer_token = bearer_token + + # Mock the prediction creation + respx_mock.post(f"{base_url}/models/meta/meta-llama-3-70b-instruct/predictions").mock( + return_value=httpx.Response( + 201, + json=create_mock_prediction_json(stream_url=f"{base_url}/stream/test-prediction-id"), + ) + ) + + # Mock the SSE stream endpoint + def stream_content() -> Iterator[bytes]: + yield b"data: Module\n\n" + yield b"data: level\n\n" + + respx_mock.get(f"{base_url}/stream/test-prediction-id").mock( + return_value=httpx.Response( + 200, + headers={"content-type": "text/event-stream"}, 
+ content=stream_content(), + ) + ) + + # Stream using module-level function + output: list[str] = [] + for chunk in replicate.stream( + "meta/meta-llama-3-70b-instruct", + input={"prompt": "Test"}, + ): + output.append(chunk) + + assert output == ["Module", " level"]
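
A minimal usage sketch of the module-level `stream()` entry point introduced in this diff. The model reference and prompt are illustrative, and it assumes an API token is already configured for the default client.

```python
# Minimal sketch: module-level streaming usage (illustrative model reference;
# assumes a Replicate API token is configured for the default client).
import replicate

chunks: list[str] = []
for chunk in replicate.stream(
    "meta/meta-llama-3-70b-instruct",
    input={"prompt": "Write a haiku about coding"},
):
    chunks.append(chunk)  # each chunk is a plain string decoded from the SSE stream

print("".join(chunks))  # the concatenated chunks form the complete model output
```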