Skip to content

Commit 81dc0df

Browse files
committed
Set filename parameter when delegating to create or async_create
Signed-off-by: Mattt Zmuda <[email protected]>
1 parent 33ac93a commit 81dc0df

File tree

2 files changed

+102
-1
lines changed

2 files changed

+102
-1
lines changed

replicate/file.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@ def create(
7171
"""
7272

7373
if isinstance(file, (str, pathlib.Path)):
74+
file_path = pathlib.Path(file)
75+
params["filename"] = file_path.name
7476
with open(file, "rb") as f:
7577
return self.create(f, **params)
7678
elif not isinstance(file, (io.IOBase, BinaryIO)):
@@ -92,7 +94,9 @@ async def async_create(
9294
"""Upload a file asynchronously that can be passed as an input when running a model."""
9395

9496
if isinstance(file, (str, pathlib.Path)):
95-
with open(file, "rb") as f:
97+
file_path = pathlib.Path(file)
98+
params["filename"] = file_path.name
99+
with open(file_path, "rb") as f:
96100
return await self.async_create(f, **params)
97101
elif not isinstance(file, (io.IOBase, BinaryIO)):
98102
raise ValueError(

tests/test_file.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,108 @@
33

44
import httpx
55
import pytest
6+
import respx
67

78
import replicate
9+
from replicate.client import Client
810

911
from .conftest import skip_if_no_token
1012

13+
router = respx.Router(base_url="https://api.replicate.com/v1")
14+
15+
router.route(
16+
method="POST",
17+
path="/files",
18+
name="files.create",
19+
).mock(
20+
return_value=httpx.Response(
21+
201,
22+
json={
23+
"id": "0ZjcyLWFhZjkNGZiNmY2YzQtMThhZi0tODg4NTY0NWNlMDEy",
24+
"name": "hello.txt",
25+
"size": 14,
26+
"content_type": "text/plain",
27+
"etag": "746308829575e17c3331bbcb00c0898b",
28+
"checksums": {
29+
"md5": "746308829575e17c3331bbcb00c0898b",
30+
"sha256": "d9014c4624844aa5bac314773d6b689ad467fa4e1d1a50a1b8a99d5a95f72ff5",
31+
},
32+
"metadata": {
33+
"foo": "bar",
34+
},
35+
"urls": {
36+
"get": "https://api.replicate.com/v1/files/0ZjcyLWFhZjkNGZiNmY2YzQtMThhZi0tODg4NTY0NWNlMDEy",
37+
},
38+
"created_at": "2024-08-22T12:26:51.079Z",
39+
"expires_at": "2024-08-22T13:26:51.079Z",
40+
},
41+
)
42+
)
43+
44+
45+
@pytest.mark.asyncio
46+
@pytest.mark.parametrize("async_flag", [True, False])
47+
@pytest.mark.parametrize("use_path", [True, False])
48+
async def test_file_create(async_flag, use_path):
49+
client = Client(
50+
api_token="test-token", transport=httpx.MockTransport(router.handler)
51+
)
52+
53+
temp_dir = tempfile.mkdtemp()
54+
temp_file_path = os.path.join(temp_dir, "hello.txt")
55+
56+
try:
57+
with open(temp_file_path, "w", encoding="utf-8") as temp_file:
58+
temp_file.write("Hello, world!")
59+
60+
metadata = {"foo": "bar"}
61+
62+
if use_path:
63+
file_arg = temp_file_path
64+
if async_flag:
65+
created_file = await client.files.async_create(
66+
file_arg, metadata=metadata
67+
)
68+
else:
69+
created_file = client.files.create(file_arg, metadata=metadata)
70+
else:
71+
with open(temp_file_path, "rb") as file_arg:
72+
if async_flag:
73+
created_file = await client.files.async_create(
74+
file_arg, metadata=metadata
75+
)
76+
else:
77+
created_file = client.files.create(file_arg, metadata=metadata)
78+
79+
assert router["files.create"].called
80+
request = router["files.create"].calls[0].request
81+
82+
# Check that the request is multipart/form-data
83+
assert request.headers["Content-Type"].startswith("multipart/form-data")
84+
85+
# Check that the filename is included and matches the fixed file name
86+
assert b'filename="hello.txt"' in request.content
87+
assert b"Hello, world!" in request.content
88+
89+
# Check the response
90+
assert created_file.id == "0ZjcyLWFhZjkNGZiNmY2YzQtMThhZi0tODg4NTY0NWNlMDEy"
91+
assert created_file.name == "hello.txt"
92+
assert created_file.size == 14
93+
assert created_file.content_type == "text/plain"
94+
assert created_file.etag == "746308829575e17c3331bbcb00c0898b"
95+
assert created_file.checksums == {
96+
"md5": "746308829575e17c3331bbcb00c0898b",
97+
"sha256": "d9014c4624844aa5bac314773d6b689ad467fa4e1d1a50a1b8a99d5a95f72ff5",
98+
}
99+
assert created_file.metadata == metadata
100+
assert created_file.urls == {
101+
"get": "https://api.replicate.com/v1/files/0ZjcyLWFhZjkNGZiNmY2YzQtMThhZi0tODg4NTY0NWNlMDEy",
102+
}
103+
104+
finally:
105+
os.unlink(temp_file_path)
106+
os.rmdir(temp_dir)
107+
11108

12109
@skip_if_no_token
13110
@pytest.mark.skipif(os.environ.get("CI") is not None, reason="Do not run on CI")

0 commit comments

Comments
 (0)