@@ -426,3 +426,84 @@ def prediction_with_status(
426426
427427 assert output1 .read () == b"Hello,"
428428 assert output2 .read () == b" world!"
429+
430+
431+ @pytest .mark .asyncio
432+ async def test_run_with_file_output_data_uri (mock_replicate_api_token ):
433+ def prediction_with_status (
434+ status : str , output : str | list [str ] | None = None
435+ ) -> dict :
436+ return {
437+ "id" : "p1" ,
438+ "model" : "test/example" ,
439+ "version" : "v1" ,
440+ "urls" : {
441+ "get" : "https://api.replicate.com/v1/predictions/p1" ,
442+ "cancel" : "https://api.replicate.com/v1/predictions/p1/cancel" ,
443+ },
444+ "created_at" : "2023-10-05T12:00:00.000000Z" ,
445+ "source" : "api" ,
446+ "status" : status ,
447+ "input" : {"text" : "world" },
448+ "output" : output ,
449+ "error" : "OOM" if status == "failed" else None ,
450+ "logs" : "" ,
451+ }
452+
453+ router = respx .Router (base_url = "https://api.replicate.com/v1" )
454+ router .route (method = "POST" , path = "/predictions" ).mock (
455+ return_value = httpx .Response (
456+ 201 ,
457+ json = prediction_with_status ("processing" ),
458+ )
459+ )
460+ router .route (method = "GET" , path = "/predictions/p1" ).mock (
461+ return_value = httpx .Response (
462+ 200 ,
463+ json = prediction_with_status (
464+ "succeeded" ,
465+ "data:text/plain;base64,SGVsbG8sIHdvcmxkIQ==" ,
466+ ),
467+ )
468+ )
469+ router .route (
470+ method = "GET" ,
471+ path = "/models/test/example/versions/v1" ,
472+ ).mock (
473+ return_value = httpx .Response (
474+ 201 ,
475+ json = {
476+ "id" : "f2d6b24e6002f25f77ae89c2b0a5987daa6d0bf751b858b94b8416e8542434d1" ,
477+ "created_at" : "2024-07-18T00:35:56.210272Z" ,
478+ "cog_version" : "0.9.10" ,
479+ "openapi_schema" : {
480+ "openapi" : "3.0.2" ,
481+ },
482+ },
483+ )
484+ )
485+
486+ client = Client (
487+ api_token = "test-token" , transport = httpx .MockTransport (router .handler )
488+ )
489+ client .poll_interval = 0.001
490+
491+ output = cast (
492+ FileOutput ,
493+ client .run (
494+ "test/example:v1" ,
495+ input = {
496+ "text" : "Hello, world!" ,
497+ },
498+ use_file_output = True ,
499+ ),
500+ )
501+
502+ assert output .url == "data:text/plain;base64,SGVsbG8sIHdvcmxkIQ=="
503+ assert output .read () == b"Hello, world!"
504+ for chunk in output :
505+ assert chunk == b"Hello, world!"
506+
507+ assert await output .aread () == b"Hello, world!"
508+ async for chunk in output :
509+ assert chunk == b"Hello, world!"
0 commit comments