diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 00000000..88f170ab --- /dev/null +++ b/.gitattributes @@ -0,0 +1 @@ +tests/cassettes/** binary diff --git a/pyproject.toml b/pyproject.toml index 2fb22177..b2f39273 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,7 +15,8 @@ optional-dependencies = { dev = [ "black", "mypy", "pytest", - "responses", + "pytest-asyncio", + "pytest-recording", "ruff", ] } diff --git a/requirements-dev.txt b/requirements-dev.txt index 3f170d83..75c24e8b 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -15,9 +15,13 @@ charset-normalizer==3.2.0 click==8.1.6 # via black idna==3.4 - # via requests + # via + # requests + # yarl iniconfig==2.0.0 # via pytest +multidict==6.0.4 + # via yarl mypy==1.4.1 # via replicate (pyproject.toml) mypy-extensions==1.0.0 @@ -40,25 +44,30 @@ pydantic==2.0.3 pydantic-core==2.3.0 # via pydantic pytest==7.4.0 + # via + # pytest-asyncio + # pytest-recording + # replicate (pyproject.toml) +pytest-asyncio==0.21.1 + # via replicate (pyproject.toml) +pytest-recording==0.13.0 # via replicate (pyproject.toml) pyyaml==6.0.1 - # via responses + # via vcrpy requests==2.31.0 - # via - # replicate (pyproject.toml) - # responses -responses==0.23.1 # via replicate (pyproject.toml) ruff==0.0.278 # via replicate (pyproject.toml) -types-pyyaml==6.0.12.10 - # via responses typing-extensions==4.7.1 # via # mypy # pydantic # pydantic-core urllib3==2.0.3 - # via - # requests - # responses + # via requests +vcrpy==5.1.0 + # via pytest-recording +wrapt==1.15.0 + # via vcrpy +yarl==1.9.2 + # via vcrpy diff --git a/tests/cassettes/predictions-cancel.yaml b/tests/cassettes/predictions-cancel.yaml new file mode 100644 index 00000000..3bb7dce2 --- /dev/null +++ b/tests/cassettes/predictions-cancel.yaml @@ -0,0 +1,232 @@ +interactions: +- request: + body: null + headers: + Accept: + - '*/*' + Accept-Encoding: + - gzip, deflate + Connection: + - keep-alive + User-Agent: + - replicate-python/0.11.0 + method: GET + uri: https://api.replicate.com/v1/models/stability-ai/sdxl/versions/a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5 + response: + body: + string: '{"id":"a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5","created_at":"2023-08-11T17:00:23.916379Z","cog_version":"0.8.5","openapi_schema":{"info":{"title":"Cog","version":"0.1.0"},"paths":{"/":{"get":{"summary":"Root","responses":{"200":{"content":{"application/json":{"schema":{"title":"Response + Root Get"}}},"description":"Successful Response"}},"operationId":"root__get"}},"/shutdown":{"post":{"summary":"Start + Shutdown","responses":{"200":{"content":{"application/json":{"schema":{"title":"Response + Start Shutdown Shutdown Post"}}},"description":"Successful Response"}},"operationId":"start_shutdown_shutdown_post"}},"/predictions":{"post":{"summary":"Predict","responses":{"200":{"content":{"application/json":{"schema":{"$ref":"#/components/schemas/PredictionResponse"}}},"description":"Successful + Response"},"422":{"content":{"application/json":{"schema":{"$ref":"#/components/schemas/HTTPValidationError"}}},"description":"Validation + Error"}},"parameters":[{"in":"header","name":"prefer","schema":{"type":"string","title":"Prefer"},"required":false}],"description":"Run + a single prediction on the model","operationId":"predict_predictions_post","requestBody":{"content":{"application/json":{"schema":{"$ref":"#/components/schemas/PredictionRequest"}}}}}},"/health-check":{"get":{"summary":"Healthcheck","responses":{"200":{"content":{"application/json":{"schema":{"title":"Response + Healthcheck Health Check Get"}}},"description":"Successful Response"}},"operationId":"healthcheck_health_check_get"}},"/predictions/{prediction_id}":{"put":{"summary":"Predict + Idempotent","responses":{"200":{"content":{"application/json":{"schema":{"$ref":"#/components/schemas/PredictionResponse"}}},"description":"Successful + Response"},"422":{"content":{"application/json":{"schema":{"$ref":"#/components/schemas/HTTPValidationError"}}},"description":"Validation + Error"}},"parameters":[{"in":"path","name":"prediction_id","schema":{"type":"string","title":"Prediction + ID"},"required":true},{"in":"header","name":"prefer","schema":{"type":"string","title":"Prefer"},"required":false}],"description":"Run + a single prediction on the model (idempotent creation).","operationId":"predict_idempotent_predictions__prediction_id__put","requestBody":{"content":{"application/json":{"schema":{"allOf":[{"$ref":"#/components/schemas/PredictionRequest"}],"title":"Prediction + Request"}}},"required":true}}},"/predictions/{prediction_id}/cancel":{"post":{"summary":"Cancel","responses":{"200":{"content":{"application/json":{"schema":{"title":"Response + Cancel Predictions Prediction Id Cancel Post"}}},"description":"Successful + Response"},"422":{"content":{"application/json":{"schema":{"$ref":"#/components/schemas/HTTPValidationError"}}},"description":"Validation + Error"}},"parameters":[{"in":"path","name":"prediction_id","schema":{"type":"string","title":"Prediction + ID"},"required":true}],"description":"Cancel a running prediction","operationId":"cancel_predictions__prediction_id__cancel_post"}}},"openapi":"3.0.2","components":{"schemas":{"Input":{"type":"object","title":"Input","properties":{"mask":{"type":"string","title":"Mask","format":"uri","x-order":3,"description":"Input + mask for inpaint mode. Black areas will be preserved, white areas will be + inpainted."},"seed":{"type":"integer","title":"Seed","x-order":11,"description":"Random + seed. Leave blank to randomize the seed"},"image":{"type":"string","title":"Image","format":"uri","x-order":2,"description":"Input + image for img2img or inpaint mode"},"width":{"type":"integer","title":"Width","default":1024,"x-order":4,"description":"Width + of output image"},"height":{"type":"integer","title":"Height","default":1024,"x-order":5,"description":"Height + of output image"},"prompt":{"type":"string","title":"Prompt","default":"An + astronaut riding a rainbow unicorn","x-order":0,"description":"Input prompt"},"refine":{"allOf":[{"$ref":"#/components/schemas/refine"}],"default":"no_refiner","x-order":12,"description":"Which + refine style to use"},"scheduler":{"allOf":[{"$ref":"#/components/schemas/scheduler"}],"default":"K_EULER","x-order":7,"description":"scheduler"},"lora_scale":{"type":"number","title":"Lora + Scale","default":0.6,"maximum":1,"minimum":0,"x-order":16,"description":"LoRA + additive scale. Only applicable on trained models."},"num_outputs":{"type":"integer","title":"Num + Outputs","default":1,"maximum":4,"minimum":1,"x-order":6,"description":"Number + of images to output."},"refine_steps":{"type":"integer","title":"Refine Steps","x-order":14,"description":"For + base_image_refiner, the number of steps to refine, defaults to num_inference_steps"},"guidance_scale":{"type":"number","title":"Guidance + Scale","default":7.5,"maximum":50,"minimum":1,"x-order":9,"description":"Scale + for classifier-free guidance"},"apply_watermark":{"type":"boolean","title":"Apply + Watermark","default":true,"x-order":15,"description":"Applies a watermark + to enable determining if an image is generated in downstream applications. + If you have other provisions for generating or deploying images safely, you + can use this to disable watermarking."},"high_noise_frac":{"type":"number","title":"High + Noise Frac","default":0.8,"maximum":1,"minimum":0,"x-order":13,"description":"For + expert_ensemble_refiner, the fraction of noise to use"},"negative_prompt":{"type":"string","title":"Negative + Prompt","default":"","x-order":1,"description":"Input Negative Prompt"},"prompt_strength":{"type":"number","title":"Prompt + Strength","default":0.8,"maximum":1,"minimum":0,"x-order":10,"description":"Prompt + strength when using img2img / inpaint. 1.0 corresponds to full destruction + of information in image"},"num_inference_steps":{"type":"integer","title":"Num + Inference Steps","default":50,"maximum":500,"minimum":1,"x-order":8,"description":"Number + of denoising steps"}}},"Output":{"type":"array","items":{"type":"string","format":"uri"},"title":"Output"},"Status":{"enum":["starting","processing","succeeded","canceled","failed"],"type":"string","title":"Status","description":"An + enumeration."},"refine":{"enum":["no_refiner","expert_ensemble_refiner","base_image_refiner"],"type":"string","title":"refine","description":"An + enumeration."},"scheduler":{"enum":["DDIM","DPMSolverMultistep","HeunDiscrete","KarrasDPM","K_EULER_ANCESTRAL","K_EULER","PNDM"],"type":"string","title":"scheduler","description":"An + enumeration."},"WebhookEvent":{"enum":["start","output","logs","completed"],"type":"string","title":"WebhookEvent","description":"An + enumeration."},"ValidationError":{"type":"object","title":"ValidationError","required":["loc","msg","type"],"properties":{"loc":{"type":"array","items":{"anyOf":[{"type":"string"},{"type":"integer"}]},"title":"Location"},"msg":{"type":"string","title":"Message"},"type":{"type":"string","title":"Error + Type"}}},"PredictionRequest":{"type":"object","title":"PredictionRequest","properties":{"id":{"type":"string","title":"Id"},"input":{"$ref":"#/components/schemas/Input"},"webhook":{"type":"string","title":"Webhook","format":"uri","maxLength":65536,"minLength":1},"created_at":{"type":"string","title":"Created + At","format":"date-time"},"output_file_prefix":{"type":"string","title":"Output + File Prefix"},"webhook_events_filter":{"type":"array","items":{"$ref":"#/components/schemas/WebhookEvent"},"default":["start","completed","logs","output"],"uniqueItems":true}}},"PredictionResponse":{"type":"object","title":"PredictionResponse","properties":{"id":{"type":"string","title":"Id"},"logs":{"type":"string","title":"Logs","default":""},"error":{"type":"string","title":"Error"},"input":{"$ref":"#/components/schemas/Input"},"output":{"$ref":"#/components/schemas/Output"},"status":{"$ref":"#/components/schemas/Status"},"metrics":{"type":"object","title":"Metrics"},"version":{"type":"string","title":"Version"},"created_at":{"type":"string","title":"Created + At","format":"date-time"},"started_at":{"type":"string","title":"Started At","format":"date-time"},"completed_at":{"type":"string","title":"Completed + At","format":"date-time"}}},"HTTPValidationError":{"type":"object","title":"HTTPValidationError","properties":{"detail":{"type":"array","items":{"$ref":"#/components/schemas/ValidationError"},"title":"Detail"}}}}}}}' + headers: + CF-Cache-Status: + - DYNAMIC + CF-RAY: + - 7fdbf06afe2ec38e-SEA + Connection: + - keep-alive + Content-Encoding: + - gzip + Content-Type: + - application/json + Date: + - Mon, 28 Aug 2023 10:40:58 GMT + NEL: + - '{"success_fraction":0,"report_to":"cf-nel","max_age":604800}' + Report-To: + - '{"endpoints":[{"url":"https:\/\/a.nel.cloudflare.com\/report\/v3?s=lqn6vZtiIEuD6ah43djQbwFcdhqmKZXq1T3A%2Bep1EXRg60x7ti4mgtoXcXd8wvjbi65JeGhwqeW%2B2tMRKx%2B0i0tJRtYcZylABbnBdtYNkp4lTeuq5BY9r7TOIYehrRHx192O"}],"group":"cf-nel","max_age":604800}' + Server: + - cloudflare + Strict-Transport-Security: + - max-age=15552000 + Transfer-Encoding: + - chunked + allow: + - GET, DELETE, HEAD, OPTIONS + content-security-policy-report-only: + - 'media-src ''report-sample'' ''self'' https://replicate.delivery https://*.replicate.delivery + https://*.mux.com https://*.gstatic.com https://*.sentry.io; connect-src ''report-sample'' + ''self'' https://replicate.delivery https://*.replicate.delivery https://*.rudderlabs.com + https://*.rudderstack.com https://*.mux.com https://*.sentry.io; font-src + ''report-sample'' ''self'' data: https://fonts.replicate.ai https://fonts.gstatic.com; + script-src ''report-sample'' ''self'' https://cdn.rudderlabs.com/v1.1/rudder-analytics.min.js; + worker-src ''none''; img-src ''report-sample'' ''self'' data: https://replicate.delivery + https://*.replicate.delivery https://*.githubusercontent.com https://github.com; + style-src ''report-sample'' ''self'' ''unsafe-inline'' https://fonts.googleapis.com; + default-src ''self''; report-uri' + cross-origin-opener-policy: + - same-origin + ratelimit-remaining: + - '2999' + ratelimit-reset: + - '1' + referrer-policy: + - same-origin + vary: + - Cookie, origin + via: + - 1.1 vegur, 1.1 google + x-content-type-options: + - nosniff + x-frame-options: + - DENY + status: + code: 200 + message: OK +- request: + body: '{"version": "a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5", + "input": {"prompt": "a studio photo of a rainbow colored corgi", "width": 512, + "height": 512, "seed": 42069}}' + headers: + Accept: + - '*/*' + Accept-Encoding: + - gzip, deflate + Connection: + - keep-alive + Content-Length: + - '189' + Content-Type: + - application/json + User-Agent: + - replicate-python/0.11.0 + method: POST + uri: https://api.replicate.com/v1/predictions + response: + body: + string: '{"id":"dj2xhz3b6iihe2ewh3d3fdtram","version":"a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5","input":{"height":512,"prompt":"a + studio photo of a rainbow colored corgi","seed":42069,"width":512},"logs":"","error":null,"status":"starting","created_at":"2023-08-28T10:40:58.313900179Z","urls":{"cancel":"https://api.replicate.com/v1/predictions/dj2xhz3b6iihe2ewh3d3fdtram/cancel","get":"https://api.replicate.com/v1/predictions/dj2xhz3b6iihe2ewh3d3fdtram"}} + + ' + headers: + CF-Cache-Status: + - DYNAMIC + CF-RAY: + - 7fdbf06bef6aeb53-SEA + Connection: + - keep-alive + Content-Length: + - '474' + Content-Type: + - application/json + Date: + - Mon, 28 Aug 2023 10:40:58 GMT + NEL: + - '{"success_fraction":0,"report_to":"cf-nel","max_age":604800}' + Report-To: + - '{"endpoints":[{"url":"https:\/\/a.nel.cloudflare.com\/report\/v3?s=%2BvAgJlOECQToRU65Zaf3t3e27iqz%2Fyt3cIYchCouUpOy%2FXvjIiIe%2BDAC4DYVSBKV6OznfDoLuOTF%2BzC9XodVLtQ6NQcTC1jB5sTHLgUf4n5w%2FknHpcXlYPI23jyU9eD72cUt"}],"group":"cf-nel","max_age":604800}' + Server: + - cloudflare + Strict-Transport-Security: + - max-age=15552000 + ratelimit-remaining: + - '599' + ratelimit-reset: + - '1' + via: + - 1.1 google + status: + code: 201 + message: Created +- request: + body: null + headers: + Accept: + - '*/*' + Accept-Encoding: + - gzip, deflate + Connection: + - keep-alive + Content-Length: + - '0' + User-Agent: + - replicate-python/0.11.0 + method: POST + uri: https://api.replicate.com/v1/predictions/dj2xhz3b6iihe2ewh3d3fdtram/cancel + response: + body: + string: '{"completed_at":"2023-08-28T10:40:58.480860Z","created_at":"2023-08-28T10:40:58.385013Z","error":null,"id":"dj2xhz3b6iihe2ewh3d3fdtram","input":{"seed":42069,"width":512,"height":512,"prompt":"a + studio photo of a rainbow colored corgi"},"logs":null,"metrics":{"predict_time":2e-06},"output":null,"started_at":"2023-08-28T10:40:58.480858Z","status":"canceled","urls":{"get":"https://api.replicate.com/v1/predictions/dj2xhz3b6iihe2ewh3d3fdtram","cancel":"https://api.replicate.com/v1/predictions/dj2xhz3b6iihe2ewh3d3fdtram/cancel"},"version":"a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5","webhook_completed":null}' + headers: + CF-Cache-Status: + - DYNAMIC + CF-RAY: + - 7fdbf06d0deac766-SEA + Connection: + - keep-alive + Content-Encoding: + - gzip + Content-Type: + - application/json + Date: + - Mon, 28 Aug 2023 10:40:58 GMT + NEL: + - '{"success_fraction":0,"report_to":"cf-nel","max_age":604800}' + Report-To: + - '{"endpoints":[{"url":"https:\/\/a.nel.cloudflare.com\/report\/v3?s=jeEQzHBiX9tZ4IdkiEgJjenWYtZvy3CafnOWdp%2BQZDH7ox2wiTLTfB57ya0p%2BwCYVTB0ITeL5DAfZBsnpGv9D1zctc%2FkSID7QL9f4QikvyILNB%2BPSvRf8jLYjWN7UfJ2l9g7"}],"group":"cf-nel","max_age":604800}' + Server: + - cloudflare + Strict-Transport-Security: + - max-age=15552000 + Transfer-Encoding: + - chunked + ratelimit-remaining: + - '2999' + ratelimit-reset: + - '1' + via: + - 1.1 google + status: + code: 200 + message: OK +version: 1 diff --git a/tests/cassettes/predictions-create.yaml b/tests/cassettes/predictions-create.yaml new file mode 100644 index 00000000..d1ec591c --- /dev/null +++ b/tests/cassettes/predictions-create.yaml @@ -0,0 +1,180 @@ +interactions: +- request: + body: null + headers: + Accept: + - '*/*' + Accept-Encoding: + - gzip, deflate + Connection: + - keep-alive + User-Agent: + - replicate-python/0.11.0 + method: GET + uri: https://api.replicate.com/v1/models/stability-ai/sdxl/versions/a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5 + response: + body: + string: '{"id":"a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5","created_at":"2023-08-11T17:00:23.916379Z","cog_version":"0.8.5","openapi_schema":{"info":{"title":"Cog","version":"0.1.0"},"paths":{"/":{"get":{"summary":"Root","responses":{"200":{"content":{"application/json":{"schema":{"title":"Response + Root Get"}}},"description":"Successful Response"}},"operationId":"root__get"}},"/shutdown":{"post":{"summary":"Start + Shutdown","responses":{"200":{"content":{"application/json":{"schema":{"title":"Response + Start Shutdown Shutdown Post"}}},"description":"Successful Response"}},"operationId":"start_shutdown_shutdown_post"}},"/predictions":{"post":{"summary":"Predict","responses":{"200":{"content":{"application/json":{"schema":{"$ref":"#/components/schemas/PredictionResponse"}}},"description":"Successful + Response"},"422":{"content":{"application/json":{"schema":{"$ref":"#/components/schemas/HTTPValidationError"}}},"description":"Validation + Error"}},"parameters":[{"in":"header","name":"prefer","schema":{"type":"string","title":"Prefer"},"required":false}],"description":"Run + a single prediction on the model","operationId":"predict_predictions_post","requestBody":{"content":{"application/json":{"schema":{"$ref":"#/components/schemas/PredictionRequest"}}}}}},"/health-check":{"get":{"summary":"Healthcheck","responses":{"200":{"content":{"application/json":{"schema":{"title":"Response + Healthcheck Health Check Get"}}},"description":"Successful Response"}},"operationId":"healthcheck_health_check_get"}},"/predictions/{prediction_id}":{"put":{"summary":"Predict + Idempotent","responses":{"200":{"content":{"application/json":{"schema":{"$ref":"#/components/schemas/PredictionResponse"}}},"description":"Successful + Response"},"422":{"content":{"application/json":{"schema":{"$ref":"#/components/schemas/HTTPValidationError"}}},"description":"Validation + Error"}},"parameters":[{"in":"path","name":"prediction_id","schema":{"type":"string","title":"Prediction + ID"},"required":true},{"in":"header","name":"prefer","schema":{"type":"string","title":"Prefer"},"required":false}],"description":"Run + a single prediction on the model (idempotent creation).","operationId":"predict_idempotent_predictions__prediction_id__put","requestBody":{"content":{"application/json":{"schema":{"allOf":[{"$ref":"#/components/schemas/PredictionRequest"}],"title":"Prediction + Request"}}},"required":true}}},"/predictions/{prediction_id}/cancel":{"post":{"summary":"Cancel","responses":{"200":{"content":{"application/json":{"schema":{"title":"Response + Cancel Predictions Prediction Id Cancel Post"}}},"description":"Successful + Response"},"422":{"content":{"application/json":{"schema":{"$ref":"#/components/schemas/HTTPValidationError"}}},"description":"Validation + Error"}},"parameters":[{"in":"path","name":"prediction_id","schema":{"type":"string","title":"Prediction + ID"},"required":true}],"description":"Cancel a running prediction","operationId":"cancel_predictions__prediction_id__cancel_post"}}},"openapi":"3.0.2","components":{"schemas":{"Input":{"type":"object","title":"Input","properties":{"mask":{"type":"string","title":"Mask","format":"uri","x-order":3,"description":"Input + mask for inpaint mode. Black areas will be preserved, white areas will be + inpainted."},"seed":{"type":"integer","title":"Seed","x-order":11,"description":"Random + seed. Leave blank to randomize the seed"},"image":{"type":"string","title":"Image","format":"uri","x-order":2,"description":"Input + image for img2img or inpaint mode"},"width":{"type":"integer","title":"Width","default":1024,"x-order":4,"description":"Width + of output image"},"height":{"type":"integer","title":"Height","default":1024,"x-order":5,"description":"Height + of output image"},"prompt":{"type":"string","title":"Prompt","default":"An + astronaut riding a rainbow unicorn","x-order":0,"description":"Input prompt"},"refine":{"allOf":[{"$ref":"#/components/schemas/refine"}],"default":"no_refiner","x-order":12,"description":"Which + refine style to use"},"scheduler":{"allOf":[{"$ref":"#/components/schemas/scheduler"}],"default":"K_EULER","x-order":7,"description":"scheduler"},"lora_scale":{"type":"number","title":"Lora + Scale","default":0.6,"maximum":1,"minimum":0,"x-order":16,"description":"LoRA + additive scale. Only applicable on trained models."},"num_outputs":{"type":"integer","title":"Num + Outputs","default":1,"maximum":4,"minimum":1,"x-order":6,"description":"Number + of images to output."},"refine_steps":{"type":"integer","title":"Refine Steps","x-order":14,"description":"For + base_image_refiner, the number of steps to refine, defaults to num_inference_steps"},"guidance_scale":{"type":"number","title":"Guidance + Scale","default":7.5,"maximum":50,"minimum":1,"x-order":9,"description":"Scale + for classifier-free guidance"},"apply_watermark":{"type":"boolean","title":"Apply + Watermark","default":true,"x-order":15,"description":"Applies a watermark + to enable determining if an image is generated in downstream applications. + If you have other provisions for generating or deploying images safely, you + can use this to disable watermarking."},"high_noise_frac":{"type":"number","title":"High + Noise Frac","default":0.8,"maximum":1,"minimum":0,"x-order":13,"description":"For + expert_ensemble_refiner, the fraction of noise to use"},"negative_prompt":{"type":"string","title":"Negative + Prompt","default":"","x-order":1,"description":"Input Negative Prompt"},"prompt_strength":{"type":"number","title":"Prompt + Strength","default":0.8,"maximum":1,"minimum":0,"x-order":10,"description":"Prompt + strength when using img2img / inpaint. 1.0 corresponds to full destruction + of information in image"},"num_inference_steps":{"type":"integer","title":"Num + Inference Steps","default":50,"maximum":500,"minimum":1,"x-order":8,"description":"Number + of denoising steps"}}},"Output":{"type":"array","items":{"type":"string","format":"uri"},"title":"Output"},"Status":{"enum":["starting","processing","succeeded","canceled","failed"],"type":"string","title":"Status","description":"An + enumeration."},"refine":{"enum":["no_refiner","expert_ensemble_refiner","base_image_refiner"],"type":"string","title":"refine","description":"An + enumeration."},"scheduler":{"enum":["DDIM","DPMSolverMultistep","HeunDiscrete","KarrasDPM","K_EULER_ANCESTRAL","K_EULER","PNDM"],"type":"string","title":"scheduler","description":"An + enumeration."},"WebhookEvent":{"enum":["start","output","logs","completed"],"type":"string","title":"WebhookEvent","description":"An + enumeration."},"ValidationError":{"type":"object","title":"ValidationError","required":["loc","msg","type"],"properties":{"loc":{"type":"array","items":{"anyOf":[{"type":"string"},{"type":"integer"}]},"title":"Location"},"msg":{"type":"string","title":"Message"},"type":{"type":"string","title":"Error + Type"}}},"PredictionRequest":{"type":"object","title":"PredictionRequest","properties":{"id":{"type":"string","title":"Id"},"input":{"$ref":"#/components/schemas/Input"},"webhook":{"type":"string","title":"Webhook","format":"uri","maxLength":65536,"minLength":1},"created_at":{"type":"string","title":"Created + At","format":"date-time"},"output_file_prefix":{"type":"string","title":"Output + File Prefix"},"webhook_events_filter":{"type":"array","items":{"$ref":"#/components/schemas/WebhookEvent"},"default":["start","completed","logs","output"],"uniqueItems":true}}},"PredictionResponse":{"type":"object","title":"PredictionResponse","properties":{"id":{"type":"string","title":"Id"},"logs":{"type":"string","title":"Logs","default":""},"error":{"type":"string","title":"Error"},"input":{"$ref":"#/components/schemas/Input"},"output":{"$ref":"#/components/schemas/Output"},"status":{"$ref":"#/components/schemas/Status"},"metrics":{"type":"object","title":"Metrics"},"version":{"type":"string","title":"Version"},"created_at":{"type":"string","title":"Created + At","format":"date-time"},"started_at":{"type":"string","title":"Started At","format":"date-time"},"completed_at":{"type":"string","title":"Completed + At","format":"date-time"}}},"HTTPValidationError":{"type":"object","title":"HTTPValidationError","properties":{"detail":{"type":"array","items":{"$ref":"#/components/schemas/ValidationError"},"title":"Detail"}}}}}}}' + headers: + CF-Cache-Status: + - DYNAMIC + CF-RAY: + - 7fdbf067ef14ec13-SEA + Connection: + - keep-alive + Content-Encoding: + - gzip + Content-Type: + - application/json + Date: + - Mon, 28 Aug 2023 10:40:57 GMT + NEL: + - '{"success_fraction":0,"report_to":"cf-nel","max_age":604800}' + Report-To: + - '{"endpoints":[{"url":"https:\/\/a.nel.cloudflare.com\/report\/v3?s=Uh%2FoF9Q9e0lNAD3h0gUDCIbpo5P2FAU1iwdHyD2fpH3umx0DL7Kkfi%2Bmj2okKR%2FMeY%2BCxLbVobvcCAW8EmeC0IFzMLEAGUalePlNCjeaDcQnNdMCKOUrhf7RXwNMThtPLQIy"}],"group":"cf-nel","max_age":604800}' + Server: + - cloudflare + Strict-Transport-Security: + - max-age=15552000 + Transfer-Encoding: + - chunked + allow: + - GET, DELETE, HEAD, OPTIONS + content-security-policy-report-only: + - 'media-src ''report-sample'' ''self'' https://replicate.delivery https://*.replicate.delivery + https://*.mux.com https://*.gstatic.com https://*.sentry.io; img-src ''report-sample'' + ''self'' data: https://replicate.delivery https://*.replicate.delivery https://*.githubusercontent.com + https://github.com; style-src ''report-sample'' ''self'' ''unsafe-inline'' + https://fonts.googleapis.com; script-src ''report-sample'' ''self'' https://cdn.rudderlabs.com/v1.1/rudder-analytics.min.js; + font-src ''report-sample'' ''self'' data: https://fonts.replicate.ai https://fonts.gstatic.com; + worker-src ''none''; default-src ''self''; connect-src ''report-sample'' ''self'' + https://replicate.delivery https://*.replicate.delivery https://*.rudderlabs.com + https://*.rudderstack.com https://*.mux.com https://*.sentry.io; report-uri' + cross-origin-opener-policy: + - same-origin + ratelimit-remaining: + - '2999' + ratelimit-reset: + - '1' + referrer-policy: + - same-origin + vary: + - Cookie, origin + via: + - 1.1 vegur, 1.1 google + x-content-type-options: + - nosniff + x-frame-options: + - DENY + status: + code: 200 + message: OK +- request: + body: '{"version": "a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5", + "input": {"prompt": "a studio photo of a rainbow colored corgi", "width": 512, + "height": 512, "seed": 42069}}' + headers: + Accept: + - '*/*' + Accept-Encoding: + - gzip, deflate + Connection: + - keep-alive + Content-Length: + - '189' + Content-Type: + - application/json + User-Agent: + - replicate-python/0.11.0 + method: POST + uri: https://api.replicate.com/v1/predictions + response: + body: + string: '{"id":"3d47fqtb4bnhj466vasrt5au2i","version":"a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5","input":{"height":512,"prompt":"a + studio photo of a rainbow colored corgi","seed":42069,"width":512},"logs":"","error":null,"status":"starting","created_at":"2023-08-28T10:40:57.808792401Z","urls":{"cancel":"https://api.replicate.com/v1/predictions/3d47fqtb4bnhj466vasrt5au2i/cancel","get":"https://api.replicate.com/v1/predictions/3d47fqtb4bnhj466vasrt5au2i"}} + + ' + headers: + CF-Cache-Status: + - DYNAMIC + CF-RAY: + - 7fdbf068cb322841-SEA + Connection: + - keep-alive + Content-Length: + - '474' + Content-Type: + - application/json + Date: + - Mon, 28 Aug 2023 10:40:57 GMT + NEL: + - '{"success_fraction":0,"report_to":"cf-nel","max_age":604800}' + Report-To: + - '{"endpoints":[{"url":"https:\/\/a.nel.cloudflare.com\/report\/v3?s=oNoeqo0NBP8H%2BEwqoufE67dj3aeN6GDgA45VtolRjAv8j1ktTZdMZ7GH7R6tAbhaWAX83ARMOlCoYlY%2BfTWq6WeoHW7%2FfMNFOQES%2B9RBaiZgSodPoOvrpqa7DBmkOoGYWegF"}],"group":"cf-nel","max_age":604800}' + Server: + - cloudflare + Strict-Transport-Security: + - max-age=15552000 + ratelimit-remaining: + - '599' + ratelimit-reset: + - '1' + via: + - 1.1 google + status: + code: 201 + message: Created +version: 1 diff --git a/tests/cassettes/predictions-get.yaml b/tests/cassettes/predictions-get.yaml new file mode 100644 index 00000000..e0e69085 --- /dev/null +++ b/tests/cassettes/predictions-get.yaml @@ -0,0 +1,79 @@ +interactions: +- request: + body: null + headers: + Accept: + - '*/*' + Accept-Encoding: + - gzip, deflate + Connection: + - keep-alive + User-Agent: + - replicate-python/0.11.0 + method: GET + uri: https://api.replicate.com/v1/predictions/vgcm4plb7tgzlyznry5d5jkgvu + response: + body: + string: "{\"completed_at\":\"2023-08-16T18:57:12.170420Z\",\"created_at\":\"2023-08-16T18:57:08.394251Z\",\"error\":null,\"id\":\"vgcm4plb7tgzlyznry5d5jkgvu\",\"input\":{\"seed\":42069,\"width\":512,\"height\":512,\"prompt\":\"a + studio photo of a rainbow colored corgi\"},\"logs\":\"Using seed: 42069\\nPrompt: + a studio photo of a rainbow colored corgi\\ntxt2img mode\\n 0%| | + 0/50 [00:00"}' + headers: + Accept: + - '*/*' + Accept-Encoding: + - gzip, deflate + Connection: + - keep-alive + Content-Length: + - '148' + Content-Type: + - application/json + User-Agent: + - replicate-python/0.11.0 + method: POST + uri: https://api.replicate.com/v1/models/stability-ai/sdxl/versions/a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5/trainings + response: + body: + string: '{"detail":"The specified training destination does not exist","status":404} + + ' + headers: + CF-Cache-Status: + - DYNAMIC + CF-RAY: + - 7fdbeebbf9d830ec-SEA + Connection: + - keep-alive + Content-Length: + - '76' + Content-Type: + - application/problem+json + Date: + - Mon, 28 Aug 2023 10:39:49 GMT + NEL: + - '{"success_fraction":0,"report_to":"cf-nel","max_age":604800}' + Report-To: + - '{"endpoints":[{"url":"https:\/\/a.nel.cloudflare.com\/report\/v3?s=taj3Swm7Drwon8liNhr2%2Bm4jj4m92erkPAH%2FZyxdYXPHJVKtA554hzQFf6Bn7YIfZcEjJhV6nF8A7Nk0EongoR4faHt%2FPQTGVB69L2mgoYimRG0EM2Iid3Iu1%2BRP1lpBtZ4f"}],"group":"cf-nel","max_age":604800}' + Server: + - cloudflare + Strict-Transport-Security: + - max-age=15552000 + ratelimit-remaining: + - '2999' + ratelimit-reset: + - '1' + via: + - 1.1 google + status: + code: 404 + message: Not Found +version: 1 diff --git a/tests/cassettes/trainings-get.yaml b/tests/cassettes/trainings-get.yaml new file mode 100644 index 00000000..1f908055 --- /dev/null +++ b/tests/cassettes/trainings-get.yaml @@ -0,0 +1,575 @@ +interactions: +- request: + body: null + headers: + Accept: + - '*/*' + Accept-Encoding: + - gzip, deflate + Connection: + - keep-alive + User-Agent: + - replicate-python/0.11.0 + method: GET + uri: https://api.replicate.com/v1/trainings/67wuwytbabl5hxbvdfcsovdn3m + response: + body: + string: "{\"completed_at\":\"2023-08-28T10:33:24.486168Z\",\"created_at\":\"2023-08-28T10:32:23.542164Z\",\"error\":null,\"id\":\"67wuwytbabl5hxbvdfcsovdn3m\",\"input\":{\"input_images\":\"https://replicate.delivery/pbxt/JMV5OrEWpBAC5gO8rre0tPOyJIOkaXvG0TWfVJ9b4zhLeEUY/data.zip\",\"use_face_detection_instead\":true},\"logs\":\"['./temp_in/rickroll.jpg']\\nGenerating + 1 captions...\\nInput captioning text: a photo of TOK,\\nTOK\\na photo of + tok, with a camera, is shown in a scene in the video\\nGenerated captions + ['a photo of tok, with a camera, is shown in a scene in the video']\\n 0%| + \ | 0/1 [00:00, with a camera, is shown in a scene in the video\\n# PTI : Loaded + dataset\\n# PTI : Running training\\n# PTI : Num examples = 1\\n# PTI : + \ Num batches each epoch = 1\\n# PTI : Num Epochs = 1000\\n# PTI : Instantaneous + batch size per device = 4\\nTotal train batch size (w. parallel, distributed + & accumulation) = 4\\n# PTI : Gradient Accumulation steps = 1\\n# PTI : Total + optimization steps = 1000\\n 0%| | 0/1000 [00:00, this is the image of a man, with a microphone and a suit\\n# PTI + : Loaded dataset\\n# PTI : Running training\\n# PTI : Num examples = 1\\n# + PTI : Num batches each epoch = 1\\n# PTI : Num Epochs = 1000\\n# PTI : Instantaneous + batch size per device = 4\\nTotal train batch size (w. parallel, distributed + & accumulation) = 4\\n# PTI : Gradient Accumulation steps = 1\\n# PTI : Total + optimization steps = 1000\\n 0%| | 0/1000 [00:00 0 + assert output[0].startswith("https://") + + +@pytest.mark.vcr +def test_run_with_invalid_identifier(mock_replicate_api_token): + with pytest.raises(ReplicateError): + replicate.run("invalid") diff --git a/tests/test_training.py b/tests/test_training.py index b74938db..4b7bb3b9 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -1,175 +1,73 @@ -import responses -from responses import matchers - -from .factories import create_client, create_version - - -@responses.activate -def test_create_works_with_webhooks(): - client = create_client() - version = create_version(client) - - rsp = responses.post( - "https://api.replicate.com/v1/models/owner/model/versions/v1/trainings", - match=[ - matchers.json_params_matcher( - { - "input": {"data": "..."}, - "destination": "new_owner/new_model", - "webhook": "https://example.com/webhook", - "webhook_events_filter": ["completed"], - } - ), - ], - json={ - "id": "t1", - "version": "v1", - "urls": { - "get": "https://api.replicate.com/v1/trainings/t1", - "cancel": "https://api.replicate.com/v1/trainings/t1/cancel", - }, - "created_at": "2022-04-26T20:00:40.658234Z", - "completed_at": "2022-04-26T20:02:27.648305Z", - "source": "api", - "status": "processing", - "input": {"data": "..."}, - "output": None, - "error": None, - "logs": "", - }, - ) +import pytest - client.trainings.create( - version=f"owner/model:{version.id}", - input={"data": "..."}, - destination="new_owner/new_model", - webhook="https://example.com/webhook", - webhook_events_filter=["completed"], - ) +import replicate +from replicate.exceptions import ReplicateException - assert rsp.call_count == 1 - - -@responses.activate -def test_cancel(): - client = create_client() - version = create_version(client) - - responses.post( - "https://api.replicate.com/v1/models/owner/model/versions/v1/trainings", - match=[ - matchers.json_params_matcher( - { - "input": {"data": "..."}, - "destination": "new_owner/new_model", - "webhook": "https://example.com/webhook", - "webhook_events_filter": ["completed"], - } - ), - ], - json={ - "id": "t1", - "version": "v1", - "urls": { - "get": "https://api.replicate.com/v1/trainings/t1", - "cancel": "https://api.replicate.com/v1/trainings/t1/cancel", - }, - "created_at": "2022-04-26T20:00:40.658234Z", - "completed_at": "2022-04-26T20:02:27.648305Z", - "source": "api", - "status": "processing", - "input": {"data": "..."}, - "output": None, - "error": None, - "logs": "", - }, - ) +input_images_url = "https://replicate.delivery/pbxt/JMV5OrEWpBAC5gO8rre0tPOyJIOkaXvG0TWfVJ9b4zhLeEUY/data.zip" - training = client.trainings.create( - version=f"owner/model:{version.id}", - input={"data": "..."}, - destination="new_owner/new_model", - webhook="https://example.com/webhook", - webhook_events_filter=["completed"], - ) - rsp = responses.post("https://api.replicate.com/v1/trainings/t1/cancel", json={}) - training.cancel() - assert rsp.call_count == 1 - - -@responses.activate -def test_async_timings(): - client = create_client() - version = create_version(client) - - responses.post( - "https://api.replicate.com/v1/models/owner/model/versions/v1/trainings", - match=[ - matchers.json_params_matcher( - { - "input": {"data": "..."}, - "destination": "new_owner/new_model", - "webhook": "https://example.com/webhook", - "webhook_events_filter": ["completed"], - } - ), - ], - json={ - "id": "t1", - "version": "v1", - "urls": { - "get": "https://api.replicate.com/v1/trainings/t1", - "cancel": "https://api.replicate.com/v1/trainings/t1/cancel", - }, - "created_at": "2022-04-26T20:00:40.658234Z", - "source": "api", - "status": "processing", - "input": {"data": "..."}, - "output": None, - "error": None, - "logs": "", +@pytest.mark.vcr("trainings-create.yaml") +@pytest.mark.asyncio +async def test_trainings_create(mock_replicate_api_token): + training = replicate.trainings.create( + "stability-ai/sdxl:a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5", + input={ + "input_images": input_images_url, + "use_face_detection_instead": True, }, + destination="replicate/dreambooth-sdxl", ) - responses.get( - "https://api.replicate.com/v1/trainings/t1", - json={ - "id": "t1", - "version": "v1", - "urls": { - "get": "https://api.replicate.com/v1/trainings/t1", - "cancel": "https://api.replicate.com/v1/trainings/t1/cancel", - }, - "created_at": "2022-04-26T20:00:40.658234Z", - "completed_at": "2022-04-26T20:02:27.648305Z", - "source": "api", - "status": "succeeded", - "input": {"data": "..."}, - "output": { - "weights": "https://delivery.replicate.com/weights.tgz", - "version": "v2", + assert training.id is not None + assert training.status == "starting" + + +@pytest.mark.vcr("trainings-create__invalid-destination.yaml") +@pytest.mark.asyncio +async def test_trainings_create_with_invalid_destination(mock_replicate_api_token): + with pytest.raises(ReplicateException): + replicate.trainings.create( + "stability-ai/sdxl:a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5", + input={ + "input_images": input_images_url, }, - "error": None, - "logs": "", - }, - ) + destination="", + ) + + +@pytest.mark.vcr("trainings-get.yaml") +@pytest.mark.asyncio +async def test_trainings_get(mock_replicate_api_token): + id = "ckcbvmtbvg6di3b3uhvccytnfm" + + training = replicate.trainings.get(id) - training = client.trainings.create( - version=f"owner/model:{version.id}", - input={"data": "..."}, - destination="new_owner/new_model", - webhook="https://example.com/webhook", - webhook_events_filter=["completed"], + assert training.id == id + assert training.status == "processing" + + +@pytest.mark.vcr("trainings-cancel.yaml") +@pytest.mark.asyncio +async def test_trainings_cancel(mock_replicate_api_token): + input = { + "input_images": input_images_url, + "use_face_detection_instead": True, + } + + destination = "replicate/dreambooth-sdxl" + + training = replicate.trainings.create( + "stability-ai/sdxl:a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5", + destination=destination, + input=input, ) - assert training.created_at == "2022-04-26T20:00:40.658234Z" - assert training.completed_at is None - assert training.output is None + id = training.id + assert training.status == "starting" - # trainings don't have a wait method, so simulate it by calling reload + # training = replicate.trainings.cancel(training) + training.cancel() training.reload() - assert training.created_at == "2022-04-26T20:00:40.658234Z" - assert training.completed_at == "2022-04-26T20:02:27.648305Z" - assert training.output["weights"] == "https://delivery.replicate.com/weights.tgz" - assert training.output["version"] == "v2" + + assert training.id == id + assert training.status == "canceled" diff --git a/tests/test_version.py b/tests/test_version.py deleted file mode 100644 index fb08ec50..00000000 --- a/tests/test_version.py +++ /dev/null @@ -1,265 +0,0 @@ -from collections.abc import Iterable - -import pytest -import responses -from responses import matchers - -from replicate.exceptions import ModelError - -from .factories import ( - create_version, - create_version_with_iterator_output, - create_version_with_iterator_output_backwards_compatibility_0_3_8, - create_version_with_list_output, -) - - -@responses.activate -def test_predict(): - version = create_version() - - responses.post( - "https://api.replicate.com/v1/predictions", - match=[ - matchers.json_params_matcher({"version": "v1", "input": {"text": "world"}}) - ], - json={ - "id": "p1", - "version": "v1", - "urls": { - "get": "https://api.replicate.com/v1/predictions/p1", - "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", - }, - "created_at": "2022-04-26T20:00:40.658234Z", - "completed_at": "2022-04-26T20:02:27.648305Z", - "source": "api", - "status": "processing", - "input": {"text": "world"}, - "output": None, - "error": None, - "logs": "", - }, - ) - responses.get( - "https://api.replicate.com/v1/predictions/p1", - json={ - "id": "p1", - "version": "v1", - "urls": { - "get": "https://api.replicate.com/v1/predictions/p1", - "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", - }, - "created_at": "2022-04-26T20:00:40.658234Z", - "completed_at": "2022-04-26T20:02:27.648305Z", - "source": "api", - "status": "succeeded", - "input": {"text": "world"}, - "output": "hello world", - "error": None, - "logs": "", - }, - ) - - assert version.predict(text="world") == "hello world" - - -@responses.activate -def test_predict_with_iterator(): - version = create_version_with_iterator_output() - responses.post( - "https://api.replicate.com/v1/predictions", - match=[ - matchers.json_params_matcher({"version": "v1", "input": {"text": "world"}}) - ], - json={ - "id": "p1", - "version": "v1", - "urls": { - "get": "https://api.replicate.com/v1/predictions/p1", - "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", - }, - "created_at": "2022-04-26T20:00:40.658234Z", - "completed_at": "2022-04-26T20:02:27.648305Z", - "source": "api", - "status": "processing", - "input": {"text": "world"}, - "output": None, - "error": None, - "logs": "", - }, - ) - responses.get( - "https://api.replicate.com/v1/predictions/p1", - json={ - "id": "p1", - "version": "v1", - "urls": { - "get": "https://api.replicate.com/v1/predictions/p1", - "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", - }, - "created_at": "2022-04-26T20:00:40.658234Z", - "completed_at": "2022-04-26T20:02:27.648305Z", - "source": "api", - "status": "succeeded", - "input": {"text": "world"}, - "output": ["hello world"], - "error": None, - "logs": "", - }, - ) - - output = version.predict(text="world") - assert isinstance(output, Iterable) - assert list(output) == ["hello world"] - - -@responses.activate -def test_predict_with_list(): - version = create_version_with_list_output() - responses.post( - "https://api.replicate.com/v1/predictions", - match=[ - matchers.json_params_matcher({"version": "v1", "input": {"text": "world"}}) - ], - json={ - "id": "p1", - "version": "v1", - "urls": { - "get": "https://api.replicate.com/v1/predictions/p1", - "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", - }, - "created_at": "2022-04-26T20:00:40.658234Z", - "completed_at": "2022-04-26T20:02:27.648305Z", - "source": "api", - "status": "processing", - "input": {"text": "world"}, - "output": None, - "error": None, - "logs": "", - }, - ) - responses.get( - "https://api.replicate.com/v1/predictions/p1", - json={ - "id": "p1", - "version": "v1", - "urls": { - "get": "https://api.replicate.com/v1/predictions/p1", - "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", - }, - "created_at": "2022-04-26T20:00:40.658234Z", - "completed_at": "2022-04-26T20:02:27.648305Z", - "source": "api", - "status": "succeeded", - "input": {"text": "world"}, - "output": ["hello world"], - "error": None, - "logs": "", - }, - ) - - output = version.predict(text="world") - assert isinstance(output, list) - assert output == ["hello world"] - - -@responses.activate -def test_predict_with_iterator_backwards_compatibility_cog_0_3_8(): - version = create_version_with_iterator_output_backwards_compatibility_0_3_8() - responses.post( - "https://api.replicate.com/v1/predictions", - match=[ - matchers.json_params_matcher({"version": "v1", "input": {"text": "world"}}) - ], - json={ - "id": "p1", - "version": "v1", - "urls": { - "get": "https://api.replicate.com/v1/predictions/p1", - "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", - }, - "created_at": "2022-04-26T20:00:40.658234Z", - "completed_at": "2022-04-26T20:02:27.648305Z", - "source": "api", - "status": "processing", - "input": {"text": "world"}, - "output": None, - "error": None, - "logs": "", - }, - ) - responses.get( - "https://api.replicate.com/v1/predictions/p1", - json={ - "id": "p1", - "version": "v1", - "urls": { - "get": "https://api.replicate.com/v1/predictions/p1", - "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", - }, - "created_at": "2022-04-26T20:00:40.658234Z", - "completed_at": "2022-04-26T20:02:27.648305Z", - "source": "api", - "status": "succeeded", - "input": {"text": "world"}, - "output": ["hello world"], - "error": None, - "logs": "", - }, - ) - - output = version.predict(text="world") - assert isinstance(output, Iterable) - assert list(output) == ["hello world"] - - -@responses.activate -def test_predict_with_iterator_with_failed_prediction(): - version = create_version_with_iterator_output() - responses.post( - "https://api.replicate.com/v1/predictions", - match=[ - matchers.json_params_matcher({"version": "v1", "input": {"text": "world"}}) - ], - json={ - "id": "p1", - "version": "v1", - "urls": { - "get": "https://api.replicate.com/v1/predictions/p1", - "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", - }, - "created_at": "2022-04-26T20:00:40.658234Z", - "completed_at": "2022-04-26T20:02:27.648305Z", - "source": "api", - "status": "processing", - "input": {"text": "world"}, - "output": None, - "error": None, - "logs": "", - }, - ) - responses.get( - "https://api.replicate.com/v1/predictions/p1", - json={ - "id": "p1", - "version": "v1", - "urls": { - "get": "https://api.replicate.com/v1/predictions/p1", - "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", - }, - "created_at": "2022-04-26T20:00:40.658234Z", - "completed_at": "2022-04-26T20:02:27.648305Z", - "source": "api", - "status": "failed", - "input": {"text": "world"}, - "output": None, - "error": "it broke", - "logs": "", - }, - ) - - output = version.predict(text="world") - assert isinstance(output, Iterable) - with pytest.raises(ModelError) as excinfo: - list(output) - assert "it broke" in str(excinfo.value)