Skip to content

Commit 8a19a33

Browse files
committed
Overhaul encode_json logic to accommodate file encoding strategies
Signed-off-by: Mattt Zmuda <[email protected]>
1 parent e19cb78 commit 8a19a33

File tree

11 files changed

+30541
-127
lines changed

11 files changed

+30541
-127
lines changed

replicate/deployment.py

Lines changed: 11 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
1-
import asyncio
21
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, TypedDict, Union
32

43
from typing_extensions import Unpack, deprecated
54

65
from replicate.account import Account
7-
from replicate.file import base64_encode_file
8-
from replicate.json import encode_json
6+
from replicate.json import async_encode_json, encode_json
97
from replicate.pagination import Page
108
from replicate.prediction import (
119
Prediction,
@@ -424,9 +422,8 @@ def create(
424422
if input is not None:
425423
input = encode_json(
426424
input,
427-
upload_file=base64_encode_file
428-
if file_encoding_strategy == "base64"
429-
else lambda file: self._client.files.create(file).urls["get"],
425+
client=self._client,
426+
file_encoding_strategy=file_encoding_strategy,
430427
)
431428
body = _create_prediction_body(version=None, input=input, **params)
432429

@@ -449,13 +446,10 @@ async def async_create(
449446

450447
file_encoding_strategy = params.pop("file_encoding_strategy", None)
451448
if input is not None:
452-
input = encode_json(
449+
input = await async_encode_json(
453450
input,
454-
upload_file=base64_encode_file
455-
if file_encoding_strategy == "base64"
456-
else lambda file: asyncio.get_event_loop()
457-
.run_until_complete(self._client.files.async_create(file))
458-
.urls["get"],
451+
client=self._client,
452+
file_encoding_strategy=file_encoding_strategy,
459453
)
460454
body = _create_prediction_body(version=None, input=input, **params)
461455

@@ -489,9 +483,8 @@ def create(
489483
if input is not None:
490484
input = encode_json(
491485
input,
492-
upload_file=base64_encode_file
493-
if file_encoding_strategy == "base64"
494-
else lambda file: self._client.files.create(file).urls["get"],
486+
client=self._client,
487+
file_encoding_strategy=file_encoding_strategy,
495488
)
496489
body = _create_prediction_body(version=None, input=input, **params)
497490

@@ -517,13 +510,10 @@ async def async_create(
517510

518511
file_encoding_strategy = params.pop("file_encoding_strategy", None)
519512
if input is not None:
520-
input = encode_json(
513+
input = await async_encode_json(
521514
input,
522-
upload_file=base64_encode_file
523-
if file_encoding_strategy == "base64"
524-
else lambda file: asyncio.get_event_loop()
525-
.run_until_complete(self._client.files.async_create(file))
526-
.urls["get"],
515+
client=self._client,
516+
file_encoding_strategy=file_encoding_strategy,
527517
)
528518
body = _create_prediction_body(version=None, input=input, **params)
529519

replicate/file.py

Lines changed: 3 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
1-
import base64
21
import io
32
import json
43
import mimetypes
54
import os
65
import pathlib
76
from typing import Any, BinaryIO, Dict, List, Optional, TypedDict, Union
87

9-
from typing_extensions import NotRequired, Unpack
8+
from typing_extensions import Literal, NotRequired, Unpack
109

1110
from replicate.resource import Namespace, Resource
1211

12+
FileEncodingStrategy = Literal["base64", "url"]
13+
1314

1415
class File(Resource):
1516
"""
@@ -168,26 +169,3 @@ def _create_file_params(
168169

169170
def _json_to_file(json: Dict[str, Any]) -> File: # pylint: disable=redefined-outer-name
170171
return File(**json)
171-
172-
173-
def base64_encode_file(file: io.IOBase) -> str:
174-
"""
175-
Base64 encode a file.
176-
177-
Args:
178-
file: A file handle to upload.
179-
Returns:
180-
str: A base64-encoded data URI.
181-
"""
182-
183-
file.seek(0)
184-
body = file.read()
185-
186-
# Ensure the file handle is in bytes
187-
body = body.encode("utf-8") if isinstance(body, str) else body
188-
encoded_body = base64.b64encode(body).decode("utf-8")
189-
190-
mime_type = (
191-
mimetypes.guess_type(getattr(file, "name", ""))[0] or "application/octet-stream"
192-
)
193-
return f"data:{mime_type};base64,{encoded_body}"

replicate/json.py

Lines changed: 77 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,14 @@
1+
import base64
12
import io
3+
import mimetypes
24
from pathlib import Path
35
from types import GeneratorType
4-
from typing import Any, Callable
6+
from typing import TYPE_CHECKING, Any, Optional
7+
8+
if TYPE_CHECKING:
9+
from replicate.client import Client
10+
from replicate.file import FileEncodingStrategy
11+
512

613
try:
714
import numpy as np # type: ignore
@@ -14,22 +21,62 @@
1421
# pylint: disable=too-many-return-statements
1522
def encode_json(
1623
obj: Any, # noqa: ANN401
17-
upload_file: Callable[[io.IOBase], str],
24+
client: "Client",
25+
file_encoding_strategy: Optional["FileEncodingStrategy"] = None,
1826
) -> Any: # noqa: ANN401
1927
"""
2028
Return a JSON-compatible version of the object.
2129
"""
22-
# Effectively the same thing as cog.json.encode_json.
2330

2431
if isinstance(obj, dict):
25-
return {key: encode_json(value, upload_file) for key, value in obj.items()}
32+
return {
33+
key: encode_json(value, client, file_encoding_strategy)
34+
for key, value in obj.items()
35+
}
36+
if isinstance(obj, (list, set, frozenset, GeneratorType, tuple)):
37+
return [encode_json(value, client, file_encoding_strategy) for value in obj]
38+
if isinstance(obj, Path):
39+
with obj.open("rb") as file:
40+
return encode_json(file, client, file_encoding_strategy)
41+
if isinstance(obj, io.IOBase):
42+
if file_encoding_strategy == "base64":
43+
return base64.b64encode(obj.read()).decode("utf-8")
44+
else:
45+
return client.files.create(obj).urls["get"]
46+
if HAS_NUMPY:
47+
if isinstance(obj, np.integer): # type: ignore
48+
return int(obj)
49+
if isinstance(obj, np.floating): # type: ignore
50+
return float(obj)
51+
if isinstance(obj, np.ndarray): # type: ignore
52+
return obj.tolist()
53+
return obj
54+
55+
56+
async def async_encode_json(
57+
obj: Any, # noqa: ANN401
58+
client: "Client",
59+
file_encoding_strategy: Optional["FileEncodingStrategy"] = None,
60+
) -> Any: # noqa: ANN401
61+
"""
62+
Asynchronously return a JSON-compatible version of the object.
63+
"""
64+
65+
if isinstance(obj, dict):
66+
return {
67+
key: (await async_encode_json(value, client, file_encoding_strategy))
68+
for key, value in obj.items()
69+
}
2670
if isinstance(obj, (list, set, frozenset, GeneratorType, tuple)):
27-
return [encode_json(value, upload_file) for value in obj]
71+
return [
72+
(await async_encode_json(value, client, file_encoding_strategy))
73+
for value in obj
74+
]
2875
if isinstance(obj, Path):
2976
with obj.open("rb") as file:
30-
return upload_file(file)
77+
return encode_json(file, client, file_encoding_strategy)
3178
if isinstance(obj, io.IOBase):
32-
return upload_file(obj)
79+
return (await client.files.async_create(obj)).urls["get"]
3380
if HAS_NUMPY:
3481
if isinstance(obj, np.integer): # type: ignore
3582
return int(obj)
@@ -38,3 +85,26 @@ def encode_json(
3885
if isinstance(obj, np.ndarray): # type: ignore
3986
return obj.tolist()
4087
return obj
88+
89+
90+
def base64_encode_file(file: io.IOBase) -> str:
91+
"""
92+
Base64 encode a file.
93+
94+
Args:
95+
file: A file handle to upload.
96+
Returns:
97+
str: A base64-encoded data URI.
98+
"""
99+
100+
file.seek(0)
101+
body = file.read()
102+
103+
# Ensure the file handle is in bytes
104+
body = body.encode("utf-8") if isinstance(body, str) else body
105+
encoded_body = base64.b64encode(body).decode("utf-8")
106+
107+
mime_type = (
108+
mimetypes.guess_type(getattr(file, "name", ""))[0] or "application/octet-stream"
109+
)
110+
return f"data:{mime_type};base64,{encoded_body}"

replicate/model.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
1-
import asyncio
21
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Tuple, Union, overload
32

43
from typing_extensions import NotRequired, TypedDict, Unpack, deprecated
54

65
from replicate.exceptions import ReplicateException
7-
from replicate.file import base64_encode_file
86
from replicate.identifier import ModelVersionIdentifier
9-
from replicate.json import encode_json
7+
from replicate.json import async_encode_json, encode_json
108
from replicate.pagination import Page
119
from replicate.prediction import (
1210
Prediction,
@@ -399,9 +397,8 @@ def create(
399397
if input is not None:
400398
input = encode_json(
401399
input,
402-
upload_file=base64_encode_file
403-
if file_encoding_strategy == "base64"
404-
else lambda file: self._client.files.create(file).urls["get"],
400+
client=self._client,
401+
file_encoding_strategy=file_encoding_strategy,
405402
)
406403
body = _create_prediction_body(version=None, input=input, **params)
407404

@@ -427,13 +424,10 @@ async def async_create(
427424

428425
file_encoding_strategy = params.pop("file_encoding_strategy", None)
429426
if input is not None:
430-
input = encode_json(
427+
input = await async_encode_json(
431428
input,
432-
upload_file=base64_encode_file
433-
if file_encoding_strategy == "base64"
434-
else lambda file: asyncio.get_event_loop()
435-
.run_until_complete(self._client.files.async_create(file))
436-
.urls["get"],
429+
client=self._client,
430+
file_encoding_strategy=file_encoding_strategy,
437431
)
438432
body = _create_prediction_body(version=None, input=input, **params)
439433

replicate/prediction.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919
from typing_extensions import NotRequired, TypedDict, Unpack
2020

2121
from replicate.exceptions import ModelError, ReplicateError
22-
from replicate.file import base64_encode_file
23-
from replicate.json import encode_json
22+
from replicate.file import FileEncodingStrategy
23+
from replicate.json import async_encode_json, encode_json
2424
from replicate.pagination import Page
2525
from replicate.resource import Namespace, Resource
2626
from replicate.stream import EventSource
@@ -383,7 +383,7 @@ class CreatePredictionParams(TypedDict):
383383
stream: NotRequired[bool]
384384
"""Enable streaming of prediction output."""
385385

386-
file_encoding_strategy: NotRequired[Literal["upload", "base64"]]
386+
file_encoding_strategy: NotRequired[FileEncodingStrategy]
387387
"""The strategy to use for encoding files in the prediction input."""
388388

389389
@overload
@@ -460,9 +460,8 @@ def create( # type: ignore
460460
if input is not None:
461461
input = encode_json(
462462
input,
463-
upload_file=base64_encode_file
464-
if file_encoding_strategy == "base64"
465-
else lambda file: self._client.files.create(file).urls["get"],
463+
client=self._client,
464+
file_encoding_strategy=file_encoding_strategy,
466465
)
467466
body = _create_prediction_body(
468467
version,
@@ -550,13 +549,10 @@ async def async_create( # type: ignore
550549

551550
file_encoding_strategy = params.pop("file_encoding_strategy", None)
552551
if input is not None:
553-
input = encode_json(
552+
input = await async_encode_json(
554553
input,
555-
upload_file=base64_encode_file
556-
if file_encoding_strategy == "base64"
557-
else lambda file: asyncio.get_event_loop()
558-
.run_until_complete(self._client.files.async_create(file))
559-
.urls["get"],
554+
client=self._client,
555+
file_encoding_strategy=file_encoding_strategy,
560556
)
561557
body = _create_prediction_body(
562558
version,

replicate/training.py

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import asyncio
21
from typing import (
32
TYPE_CHECKING,
43
Any,
@@ -14,9 +13,8 @@
1413

1514
from typing_extensions import NotRequired, Unpack
1615

17-
from replicate.file import base64_encode_file
1816
from replicate.identifier import ModelVersionIdentifier
19-
from replicate.json import encode_json
17+
from replicate.json import async_encode_json, encode_json
2018
from replicate.model import Model
2119
from replicate.pagination import Page
2220
from replicate.resource import Namespace, Resource
@@ -29,6 +27,7 @@
2927

3028
if TYPE_CHECKING:
3129
from replicate.client import Client
30+
from replicate.file import FileEncodingStrategy
3231

3332

3433
class Training(Resource):
@@ -221,7 +220,7 @@ class CreateTrainingParams(TypedDict):
221220
webhook: NotRequired[str]
222221
webhook_completed: NotRequired[str]
223222
webhook_events_filter: NotRequired[List[str]]
224-
file_encoding_strategy: NotRequired[Literal["upload", "base64"]]
223+
file_encoding_strategy: NotRequired["FileEncodingStrategy"]
225224

226225
@overload
227226
def create( # pylint: disable=too-many-arguments
@@ -283,10 +282,10 @@ def create( # type: ignore
283282
if input is not None:
284283
input = encode_json(
285284
input,
286-
upload_file=base64_encode_file
287-
if file_encoding_strategy == "base64"
288-
else lambda file: self._client.files.create(file).urls["get"],
285+
client=self._client,
286+
file_encoding_strategy=file_encoding_strategy,
289287
)
288+
290289
body = _create_training_body(input, **params)
291290

292291
resp = self._client._request(
@@ -322,13 +321,10 @@ async def async_create(
322321

323322
file_encoding_strategy = params.pop("file_encoding_strategy", None)
324323
if input is not None:
325-
input = encode_json(
324+
input = await async_encode_json(
326325
input,
327-
upload_file=base64_encode_file
328-
if file_encoding_strategy == "base64"
329-
else lambda file: asyncio.get_event_loop()
330-
.run_until_complete(self._client.files.async_create(file))
331-
.urls["get"],
326+
client=self._client,
327+
file_encoding_strategy=file_encoding_strategy,
332328
)
333329
body = _create_training_body(input, **params)
334330

0 commit comments

Comments
 (0)