Skip to content

Commit c238c39

Browse files
committed
Move transform_output to helpers module
Signed-off-by: Mattt Zmuda <[email protected]>
1 parent 69c95e6 commit c238c39

File tree

3 files changed

+20
-18
lines changed

3 files changed

+20
-18
lines changed

replicate/helpers.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import base64
22
import io
33
import mimetypes
4+
from collections.abc import Mapping, Sequence
45
from pathlib import Path
56
from types import GeneratorType
67
from typing import TYPE_CHECKING, Any, Optional
@@ -108,3 +109,20 @@ def base64_encode_file(file: io.IOBase) -> str:
108109
mimetypes.guess_type(getattr(file, "name", ""))[0] or "application/octet-stream"
109110
)
110111
return f"data:{mime_type};base64,{encoded_body}"
112+
113+
114+
def transform_output(value: Any, client: "Client") -> Any:
115+
from replicate.stream import FileOutput # pylint: disable=import-outside-toplevel
116+
117+
def transform(obj: Any) -> Any:
118+
if isinstance(obj, Mapping):
119+
return {k: transform(v) for k, v in obj.items()}
120+
elif isinstance(obj, Sequence) and not isinstance(obj, str):
121+
return [transform(item) for item in obj]
122+
elif isinstance(obj, str) and (
123+
obj.startswith("https:") or obj.startswith("data:")
124+
):
125+
return FileOutput(obj, client)
126+
return obj
127+
128+
return transform(value)

replicate/run.py

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
from collections.abc import Mapping, Sequence
21
from typing import (
32
TYPE_CHECKING,
43
Any,
@@ -14,10 +13,10 @@
1413

1514
from replicate import identifier
1615
from replicate.exceptions import ModelError
16+
from replicate.helpers import transform_output
1717
from replicate.model import Model
1818
from replicate.prediction import Prediction
1919
from replicate.schema import make_schema_backwards_compatible
20-
from replicate.stream import FileOutput
2120
from replicate.version import Version, Versions
2221

2322
if TYPE_CHECKING:
@@ -140,19 +139,4 @@ def _make_async_output_iterator(
140139
return None
141140

142141

143-
def transform_output(value: Any, client: "Client") -> Any:
144-
def transform(obj: Any) -> Any:
145-
if isinstance(obj, Mapping):
146-
return {k: transform(v) for k, v in obj.items()}
147-
elif isinstance(obj, Sequence) and not isinstance(obj, str):
148-
return [transform(item) for item in obj]
149-
elif isinstance(obj, str) and (
150-
obj.startswith("https:") or obj.startswith("data:")
151-
):
152-
return FileOutput(obj, client)
153-
return obj
154-
155-
return transform(value)
156-
157-
158142
__all__: List = []

tests/test_run.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import asyncio
22
import sys
3+
from typing import cast
34

45
import httpx
56
import pytest
67
import respx
78

8-
from typing import cast
99
import replicate
1010
from replicate.client import Client
1111
from replicate.exceptions import ModelError, ReplicateError

0 commit comments

Comments
 (0)