Skip to content

Commit 84ad65f

Browse files
committed
Add support for files API endpoints
Signed-off-by: Mattt Zmuda <[email protected]>
1 parent b0bbc04 commit 84ad65f

File tree

7 files changed

+272
-43
lines changed

7 files changed

+272
-43
lines changed

replicate/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,9 @@
1414
async_paginate = _async_paginate
1515

1616
collections = default_client.collections
17-
hardware = default_client.hardware
1817
deployments = default_client.deployments
18+
files = default_client.files
19+
hardware = default_client.hardware
1920
models = default_client.models
2021
predictions = default_client.predictions
2122
trainings = default_client.trainings

replicate/client.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from replicate.collection import Collections
2525
from replicate.deployment import Deployments
2626
from replicate.exceptions import ReplicateError
27+
from replicate.file import Files
2728
from replicate.hardware import HardwareNamespace as Hardware
2829
from replicate.model import Models
2930
from replicate.prediction import Predictions
@@ -117,6 +118,13 @@ def deployments(self) -> Deployments:
117118
"""
118119
return Deployments(client=self)
119120

121+
@property
122+
def files(self) -> Files:
123+
"""
124+
Namespace for operations related to files.
125+
"""
126+
return Files(client=self)
127+
120128
@property
121129
def hardware(self) -> Hardware:
122130
"""

replicate/file.py

Lines changed: 202 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,202 @@
1+
import base64
2+
import io
3+
import json
4+
import mimetypes
5+
import os
6+
import pathlib
7+
from typing import Any, BinaryIO, Dict, List, Optional, Tuple, TypedDict, Union
8+
9+
import httpx
10+
from typing_extensions import NotRequired, Unpack
11+
12+
from replicate.resource import Namespace, Resource
13+
14+
15+
class File(Resource):
16+
"""
17+
A file uploaded to Replicate that can be used as an input to a model.
18+
"""
19+
20+
id: str
21+
"""The ID of the file."""
22+
23+
name: str
24+
"""The name of the file."""
25+
26+
content_type: str
27+
"""The content type of the file."""
28+
29+
size: int
30+
"""The size of the file in bytes."""
31+
32+
etag: str
33+
"""The ETag of the file."""
34+
35+
checksums: Dict[str, str]
36+
"""The checksums of the file."""
37+
38+
metadata: Dict[str, Any]
39+
"""The metadata of the file."""
40+
41+
created_at: str
42+
"""The time the file was created."""
43+
44+
expires_at: Optional[str]
45+
"""The time the file will expire."""
46+
47+
urls: Dict[str, str]
48+
"""The URLs of the file."""
49+
50+
51+
class Files(Namespace):
52+
class CreateFileParams(TypedDict):
53+
"""Parameters for creating a file."""
54+
55+
filename: NotRequired[str]
56+
"""The name of the file."""
57+
58+
content_type: NotRequired[str]
59+
"""The content type of the file."""
60+
61+
metadata: NotRequired[Dict[str, Any]]
62+
"""The file metadata."""
63+
64+
def create(
65+
self,
66+
file: Union[str, pathlib.Path, BinaryIO, io.IOBase],
67+
**params: Unpack["Files.CreateFileParams"],
68+
) -> File:
69+
"""
70+
Upload a file that can be passed as an input when running a model.
71+
"""
72+
73+
if isinstance(file, (str, pathlib.Path)):
74+
with open(file, "rb") as f:
75+
return self.create(f, **params)
76+
elif not isinstance(file, (io.IOBase, BinaryIO)):
77+
raise ValueError(
78+
"Unsupported file type. Must be a file path or file-like object."
79+
)
80+
81+
resp = self._client._request(
82+
"POST", "/v1/files", timeout=None, **_create_file_params(file, **params)
83+
)
84+
85+
return _json_to_file(resp.json())
86+
87+
async def async_create(
88+
self,
89+
file: Union[str, pathlib.Path, BinaryIO, io.IOBase],
90+
**params: Unpack["Files.CreateFileParams"],
91+
) -> File:
92+
"""Upload a file asynchronously that can be passed as an input when running a model."""
93+
94+
if isinstance(file, (str, pathlib.Path)):
95+
with open(file, "rb") as f:
96+
return self.create(f, **params)
97+
elif not isinstance(file, (io.IOBase, BinaryIO)):
98+
raise ValueError(
99+
"Unsupported file type. Must be a file path or file-like object."
100+
)
101+
102+
resp = await self._client._async_request(
103+
"POST", "/v1/files", timeout=None, **_create_file_params(file, **params)
104+
)
105+
106+
return _json_to_file(resp.json())
107+
108+
def get(self, file_id: str) -> File:
109+
"""Get an uploaded file by its ID."""
110+
111+
resp = self._client._request("GET", f"/v1/files/{file_id}")
112+
return _json_to_file(resp.json())
113+
114+
async def async_get(self, file_id: str) -> File:
115+
"""Get an uploaded file by its ID asynchronously."""
116+
117+
resp = await self._client._async_request("GET", f"/v1/files/{file_id}")
118+
return _json_to_file(resp.json())
119+
120+
def list(self) -> List[File]:
121+
"""List all uploaded files."""
122+
123+
resp = self._client._request("GET", "/v1/files")
124+
return [_json_to_file(obj) for obj in resp.json().get("results", [])]
125+
126+
async def async_list(self) -> List[File]:
127+
"""List all uploaded files asynchronously."""
128+
129+
resp = await self._client._async_request("GET", "/v1/files")
130+
return [_json_to_file(obj) for obj in resp.json().get("results", [])]
131+
132+
def delete(self, file_id: str) -> None:
133+
"""Delete an uploaded file by its ID."""
134+
135+
_ =self._client._request("DELETE", f"/v1/files/{file_id}")
136+
137+
async def async_delete(self, file_id: str) -> None:
138+
"""Delete an uploaded file by its ID asynchronously."""
139+
140+
_ = await self._client._async_request("DELETE", f"/v1/files/{file_id}")
141+
142+
def _create_file_params(
143+
file: Union[BinaryIO, io.IOBase],
144+
**params: Unpack["Files.CreateFileParams"],
145+
) -> Dict[str, Any]:
146+
file.seek(0)
147+
148+
if params is None:
149+
params = {}
150+
151+
filename = params.get("filename", os.path.basename(getattr(file, "name", "file")))
152+
content_type = (
153+
params.get("content_type")
154+
or mimetypes.guess_type(filename)[0]
155+
or "application/octet-stream"
156+
)
157+
metadata = params.get("metadata")
158+
159+
data = {}
160+
if metadata:
161+
data["metadata"] = json.dumps(metadata)
162+
163+
return {
164+
"files": {"content": (filename, file, content_type)},
165+
"data": data,
166+
}
167+
168+
def _json_to_file(json: Dict[str, Any]) -> File: # pylint: disable=redefined-outer-name
169+
return File(**json)
170+
171+
172+
def upload_file(file: io.IOBase, output_file_prefix: Optional[str] = None) -> str:
173+
"""
174+
Upload a file to the server.
175+
176+
Args:
177+
file: A file handle to upload.
178+
output_file_prefix: A string to prepend to the output file name.
179+
Returns:
180+
str: A URL to the uploaded file.
181+
"""
182+
# Lifted straight from cog.files
183+
184+
file.seek(0)
185+
186+
if output_file_prefix is not None:
187+
name = getattr(file, "name", "output")
188+
url = output_file_prefix + os.path.basename(name)
189+
resp = httpx.put(url, files={"file": file}, timeout=None) # type: ignore
190+
resp.raise_for_status()
191+
192+
return url
193+
194+
body = file.read()
195+
# Ensure the file handle is in bytes
196+
body = body.encode("utf-8") if isinstance(body, str) else body
197+
encoded_body = base64.b64encode(body).decode("utf-8")
198+
# Use getattr to avoid mypy complaints about io.IOBase having no attribute name
199+
mime_type = (
200+
mimetypes.guess_type(getattr(file, "name", ""))[0] or "application/octet-stream"
201+
)
202+
return f"data:{mime_type};base64,{encoded_body}"

replicate/files.py

Lines changed: 0 additions & 40 deletions
This file was deleted.

replicate/prediction.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from typing_extensions import NotRequired, TypedDict, Unpack
2020

2121
from replicate.exceptions import ModelError, ReplicateError
22-
from replicate.files import upload_file
22+
from replicate.file import upload_file
2323
from replicate.json import encode_json
2424
from replicate.pagination import Page
2525
from replicate.resource import Namespace, Resource

replicate/training.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
from typing_extensions import NotRequired, Unpack
1515

16-
from replicate.files import upload_file
16+
from replicate.file import upload_file
1717
from replicate.identifier import ModelVersionIdentifier
1818
from replicate.json import encode_json
1919
from replicate.model import Model

tests/test_file.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import tempfile
2+
3+
import pytest
4+
5+
import replicate
6+
7+
8+
# @pytest.mark.vcr("files-operations.yaml")
9+
@pytest.mark.asyncio
10+
@pytest.mark.parametrize("async_flag", [True, False])
11+
async def test_file_operations(async_flag):
12+
# Create a sample file
13+
with tempfile.NamedTemporaryFile(
14+
mode="wb", delete=False, prefix="test_file", suffix=".txt"
15+
) as temp_file:
16+
temp_file.write(b"Hello, Replicate!")
17+
18+
# Test create
19+
if async_flag:
20+
created_file = await replicate.files.async_create(temp_file.name)
21+
else:
22+
created_file = replicate.files.create(temp_file.name)
23+
24+
assert created_file.name.startswith("test_file")
25+
assert created_file.name.endswith(".txt")
26+
file_id = created_file.id
27+
28+
# Test get
29+
if async_flag:
30+
retrieved_file = await replicate.files.async_get(file_id)
31+
else:
32+
retrieved_file = replicate.files.get(file_id)
33+
34+
assert retrieved_file.id == file_id
35+
36+
# Test list
37+
if async_flag:
38+
file_list = await replicate.files.async_list()
39+
else:
40+
file_list = replicate.files.list()
41+
42+
assert file_list is not None
43+
assert len(file_list) > 0
44+
assert any(f.id == file_id for f in file_list)
45+
46+
# Test delete
47+
if async_flag:
48+
await replicate.files.async_delete(file_id)
49+
else:
50+
replicate.files.delete(file_id)
51+
52+
# Verify file is deleted
53+
if async_flag:
54+
file_list = await replicate.files.async_list()
55+
else:
56+
file_list = replicate.files.list()
57+
58+
assert all(f.id != file_id for f in file_list)

0 commit comments

Comments
 (0)