Skip to content

Commit 4687e28

Browse files
committed
Move FileOutput to helpers module
Signed-off-by: Mattt Zmuda <[email protected]>
1 parent c238c39 commit 4687e28

File tree

3 files changed

+43
-46
lines changed

3 files changed

+43
-46
lines changed

replicate/helpers.py

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
from collections.abc import Mapping, Sequence
55
from pathlib import Path
66
from types import GeneratorType
7-
from typing import TYPE_CHECKING, Any, Optional
7+
from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, Optional
8+
9+
import httpx
810

911
if TYPE_CHECKING:
1012
from replicate.client import Client
@@ -111,9 +113,44 @@ def base64_encode_file(file: io.IOBase) -> str:
111113
return f"data:{mime_type};base64,{encoded_body}"
112114

113115

114-
def transform_output(value: Any, client: "Client") -> Any:
115-
from replicate.stream import FileOutput # pylint: disable=import-outside-toplevel
116+
class FileOutput(httpx.ByteStream, httpx.AsyncByteStream):
117+
url: str
118+
client: "Client"
119+
120+
def __init__(self, url: str, client: "Client"):
121+
self.url = url
122+
self.client = client
123+
124+
def read(self) -> bytes:
125+
with self.client._client.stream("GET", self.url) as response:
126+
response.raise_for_status()
127+
return response.read()
128+
129+
def __iter__(self) -> Iterator[bytes]:
130+
with self.client._client.stream("GET", self.url) as response:
131+
response.raise_for_status()
132+
for chunk in response.iter_bytes():
133+
yield chunk
134+
135+
async def aread(self) -> bytes:
136+
async with self.client._async_client.stream("GET", self.url) as response:
137+
response.raise_for_status()
138+
return await response.aread()
116139

140+
async def __aiter__(self) -> AsyncIterator[bytes]:
141+
async with self.client._async_client.stream("GET", self.url) as response:
142+
response.raise_for_status()
143+
async for chunk in response.aiter_bytes():
144+
yield chunk
145+
146+
def __str__(self) -> str:
147+
return self.url
148+
149+
def __repr__(self) -> str:
150+
return self.url
151+
152+
153+
def transform_output(value: Any, client: "Client") -> Any:
117154
def transform(obj: Any) -> Any:
118155
if isinstance(obj, Mapping):
119156
return {k: transform(v) for k, v in obj.items()}

replicate/stream.py

Lines changed: 2 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
1-
import io
2-
import base64
3-
import httpx
41
from enum import Enum
52
from typing import (
63
TYPE_CHECKING,
@@ -12,9 +9,9 @@
129
Optional,
1310
Union,
1411
)
15-
from contextlib import asynccontextmanager, contextmanager
16-
from typing_extensions import Unpack
1712

13+
import httpx
14+
from typing_extensions import Unpack
1815

1916
from replicate import identifier
2017
from replicate.exceptions import ReplicateError
@@ -33,43 +30,6 @@
3330
from replicate.version import Version
3431

3532

36-
class FileOutput(httpx.ByteStream, httpx.AsyncByteStream):
37-
url: str
38-
client: "Client"
39-
40-
def __init__(self, url: str, client: "Client"):
41-
self.url = url
42-
self.client = client
43-
44-
def read(self) -> bytes:
45-
with self.client._client.stream("GET", self.url) as response:
46-
response.raise_for_status()
47-
return response.read()
48-
49-
def __iter__(self) -> Iterator[bytes]:
50-
with self.client._client.stream("GET", self.url) as response:
51-
response.raise_for_status()
52-
for chunk in response.iter_bytes():
53-
yield chunk
54-
55-
async def aread(self) -> bytes:
56-
async with self.client._async_client.stream("GET", self.url) as response:
57-
response.raise_for_status()
58-
return await response.aread()
59-
60-
async def __aiter__(self) -> AsyncIterator[bytes]:
61-
async with self.client._async_client.stream("GET", self.url) as response:
62-
response.raise_for_status()
63-
async for chunk in response.aiter_bytes():
64-
yield chunk
65-
66-
def __str__(self) -> str:
67-
return self.url
68-
69-
def __repr__(self) -> str:
70-
return self.url
71-
72-
7333
class ServerSentEvent(pydantic.BaseModel): # type: ignore
7434
"""
7535
A server-sent event.

tests/test_run.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import replicate
1010
from replicate.client import Client
1111
from replicate.exceptions import ModelError, ReplicateError
12-
from replicate.stream import FileOutput
12+
from replicate.helpers import FileOutput
1313

1414

1515
@pytest.mark.vcr("run.yaml")

0 commit comments

Comments
 (0)