Skip to content

Commit 548f71b

Browse files
committed
Set filename parameter when delegating to create or async_create
Signed-off-by: Mattt Zmuda <[email protected]>
1 parent 33ac93a commit 548f71b

File tree

2 files changed

+97
-1
lines changed

2 files changed

+97
-1
lines changed

replicate/file.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@ def create(
7171
"""
7272

7373
if isinstance(file, (str, pathlib.Path)):
74+
file_path = pathlib.Path(file)
75+
params["filename"] = file_path.name
7476
with open(file, "rb") as f:
7577
return self.create(f, **params)
7678
elif not isinstance(file, (io.IOBase, BinaryIO)):
@@ -92,7 +94,9 @@ async def async_create(
9294
"""Upload a file asynchronously that can be passed as an input when running a model."""
9395

9496
if isinstance(file, (str, pathlib.Path)):
95-
with open(file, "rb") as f:
97+
file_path = pathlib.Path(file)
98+
params["filename"] = file_path.name
99+
with open(file_path, "rb") as f:
96100
return await self.async_create(f, **params)
97101
elif not isinstance(file, (io.IOBase, BinaryIO)):
98102
raise ValueError(

tests/test_file.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,103 @@
33

44
import httpx
55
import pytest
6+
import respx
67

78
import replicate
9+
from replicate.client import Client
810

911
from .conftest import skip_if_no_token
1012

13+
router = respx.Router(base_url="https://api.replicate.com/v1")
14+
15+
router.route(
16+
method="POST",
17+
path="/files",
18+
name="files.create",
19+
).mock(
20+
return_value=httpx.Response(
21+
201,
22+
json={
23+
"id": "0ZjcyLWFhZjkNGZiNmY2YzQtMThhZi0tODg4NTY0NWNlMDEy",
24+
"name": "hello.txt",
25+
"size": 14,
26+
"content_type": "text/plain",
27+
"etag": "746308829575e17c3331bbcb00c0898b",
28+
"checksums": {
29+
"md5": "746308829575e17c3331bbcb00c0898b",
30+
"sha256": "d9014c4624844aa5bac314773d6b689ad467fa4e1d1a50a1b8a99d5a95f72ff5"
31+
},
32+
"metadata": {
33+
"foo": "bar",
34+
},
35+
"urls": {
36+
"get": "https://api.replicate.com/v1/files/0ZjcyLWFhZjkNGZiNmY2YzQtMThhZi0tODg4NTY0NWNlMDEy",
37+
},
38+
"created_at": "2024-08-22T12:26:51.079Z",
39+
"expires_at": "2024-08-22T13:26:51.079Z",
40+
},
41+
)
42+
)
43+
44+
@pytest.mark.asyncio
45+
@pytest.mark.parametrize("async_flag", [True, False])
46+
@pytest.mark.parametrize("use_path", [True, False])
47+
async def test_file_create(async_flag, use_path):
48+
client = Client(
49+
api_token="test-token", transport=httpx.MockTransport(router.handler)
50+
)
51+
52+
temp_dir = tempfile.mkdtemp()
53+
temp_file_path = os.path.join(temp_dir, "hello.txt")
54+
55+
try:
56+
with open(temp_file_path, "w", encoding="utf-8") as temp_file:
57+
temp_file.write("Hello, world!")
58+
59+
metadata = {"foo": "bar"}
60+
61+
if use_path:
62+
file_arg = temp_file_path
63+
if async_flag:
64+
created_file = await client.files.async_create(file_arg, metadata=metadata)
65+
else:
66+
created_file = client.files.create(file_arg, metadata=metadata)
67+
else:
68+
with open(temp_file_path, "rb") as file_arg:
69+
if async_flag:
70+
created_file = await client.files.async_create(file_arg, metadata=metadata)
71+
else:
72+
created_file = client.files.create(file_arg, metadata=metadata)
73+
74+
assert router["files.create"].called
75+
request = router["files.create"].calls[0].request
76+
77+
# Check that the request is multipart/form-data
78+
assert request.headers["Content-Type"].startswith("multipart/form-data")
79+
80+
# Check that the filename is included and matches the fixed file name
81+
assert b'filename="hello.txt"' in request.content
82+
assert b"Hello, world!" in request.content
83+
84+
# Check the response
85+
assert created_file.id == "0ZjcyLWFhZjkNGZiNmY2YzQtMThhZi0tODg4NTY0NWNlMDEy"
86+
assert created_file.name == "hello.txt"
87+
assert created_file.size == 14
88+
assert created_file.content_type == "text/plain"
89+
assert created_file.etag == "746308829575e17c3331bbcb00c0898b"
90+
assert created_file.checksums == {
91+
"md5": "746308829575e17c3331bbcb00c0898b",
92+
"sha256": "d9014c4624844aa5bac314773d6b689ad467fa4e1d1a50a1b8a99d5a95f72ff5"
93+
}
94+
assert created_file.metadata == metadata
95+
assert created_file.urls == {
96+
"get": "https://api.replicate.com/v1/files/0ZjcyLWFhZjkNGZiNmY2YzQtMThhZi0tODg4NTY0NWNlMDEy",
97+
}
98+
99+
finally:
100+
os.unlink(temp_file_path)
101+
os.rmdir(temp_dir)
102+
11103

12104
@skip_if_no_token
13105
@pytest.mark.skipif(os.environ.get("CI") is not None, reason="Do not run on CI")

0 commit comments

Comments
 (0)