diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 00000000..88f170ab --- /dev/null +++ b/.gitattributes @@ -0,0 +1 @@ +tests/cassettes/** binary diff --git a/pyproject.toml b/pyproject.toml index 5de320dc..b2799750 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,12 +10,14 @@ readme = "README.md" license = { file = "LICENSE" } authors = [{ name = "Replicate, Inc." }] requires-python = ">=3.8" -dependencies = ["packaging", "pydantic>1", "requests>2"] +dependencies = ["packaging", "pydantic>1", "httpx>=0.21.0,<1"] optional-dependencies = { dev = [ "black", "mypy", "pytest", - "responses", + "pytest-asyncio", + "pytest-recording", + "respx", "ruff", ] } diff --git a/replicate/client.py b/replicate/client.py index e78296a4..6768aefd 100644 --- a/replicate/client.py +++ b/replicate/client.py @@ -1,106 +1,77 @@ import os +import random import re -from json import JSONDecodeError -from typing import Any, Dict, Iterator, Optional, Union - -import requests -from requests.adapters import HTTPAdapter, Retry -from requests.cookies import RequestsCookieJar - -from replicate.__about__ import __version__ -from replicate.deployment import DeploymentCollection -from replicate.exceptions import ModelError, ReplicateError -from replicate.model import ModelCollection -from replicate.prediction import PredictionCollection -from replicate.training import TrainingCollection +import time +from datetime import datetime +from typing import ( + Any, + Iterable, + Iterator, + Mapping, + Optional, + Union, +) + +import httpx + +from .__about__ import __version__ +from .deployment import DeploymentCollection +from .exceptions import ModelError, ReplicateError +from .model import ModelCollection +from .prediction import PredictionCollection +from .training import TrainingCollection class Client: - def __init__(self, api_token: Optional[str] = None) -> None: + """A Replicate API client library""" + + def __init__( + self, + api_token: Optional[str] = None, + *, + base_url: Optional[str] = None, + timeout: Optional[httpx.Timeout] = None, + **kwargs, + ) -> None: super().__init__() - # Client is instantiated at import time, so do as little as possible. - # This includes resolving environment variables -- they might be set programmatically. - self.api_token = api_token - self.base_url = os.environ.get( + + api_token = api_token or os.environ.get("REPLICATE_API_TOKEN") + + base_url = base_url or os.environ.get( "REPLICATE_API_BASE_URL", "https://api.replicate.com" ) - self.poll_interval = float(os.environ.get("REPLICATE_POLL_INTERVAL", "0.5")) - # TODO: make thread safe - self.read_session = _create_session() - read_retries = Retry( - total=5, - backoff_factor=2, - # Only retry 500s on GET so we don't unintionally mutute data - allowed_methods=["GET"], - # https://support.cloudflare.com/hc/en-us/articles/115003011431-Troubleshooting-Cloudflare-5XX-errors - status_forcelist=[ - 429, - 500, - 502, - 503, - 504, - 520, - 521, - 522, - 523, - 524, - 526, - 527, - ], - ) - self.read_session.mount("http://", HTTPAdapter(max_retries=read_retries)) - self.read_session.mount("https://", HTTPAdapter(max_retries=read_retries)) - - self.write_session = _create_session() - write_retries = Retry( - total=5, - backoff_factor=2, - allowed_methods=["POST", "PUT"], - # Only retry POST/PUT requests on rate limits, so we don't unintionally mutute data - status_forcelist=[429], + timeout = timeout or httpx.Timeout( + 5.0, read=30.0, write=30.0, connect=5.0, pool=10.0 ) - self.write_session.mount("http://", HTTPAdapter(max_retries=write_retries)) - self.write_session.mount("https://", HTTPAdapter(max_retries=write_retries)) - - def _request(self, method: str, path: str, **kwargs) -> requests.Response: - # from requests.Session - if method in ["GET", "OPTIONS"]: - kwargs.setdefault("allow_redirects", True) - if method in ["HEAD"]: - kwargs.setdefault("allow_redirects", False) - kwargs.setdefault("headers", {}) - kwargs["headers"].update(self._headers()) - session = self.read_session - if method in ["POST", "PUT", "DELETE", "PATCH"]: - session = self.write_session - resp = session.request(method, self.base_url + path, **kwargs) - if 400 <= resp.status_code < 600: - try: - raise ReplicateError(resp.json()["detail"]) - except (JSONDecodeError, KeyError): - pass - raise ReplicateError(f"HTTP error: {resp.status_code, resp.reason}") - return resp - def _headers(self) -> Dict[str, str]: - return { - "Authorization": f"Token {self._api_token()}", + self.poll_interval = float(os.environ.get("REPLICATE_POLL_INTERVAL", "0.5")) + + headers = { + "Authorization": f"Token {api_token}", "User-Agent": f"replicate-python/{__version__}", } - def _api_token(self) -> str: - token = self.api_token - # Evaluate lazily in case environment variable is set with dotenv, or something - if token is None: - token = os.environ.get("REPLICATE_API_TOKEN") - if not token: - raise ReplicateError( - """No API token provided. You need to set the REPLICATE_API_TOKEN environment variable or create a client with `replicate.Client(api_token=...)`. + transport = kwargs.pop("transport", httpx.HTTPTransport()) -You can find your API key on https://replicate.com""" - ) - return token + self._client = self._build_client( + **kwargs, + base_url=base_url, + headers=headers, + timeout=timeout, + transport=RetryTransport(wrapped_transport=transport), + ) + + def _build_client(self, **kwargs) -> httpx.Client: + return httpx.Client(**kwargs) + + def _request(self, method: str, path: str, **kwargs) -> httpx.Response: + resp = self._client.request(method, path, **kwargs) + + if 400 <= resp.status_code < 600: + raise ReplicateError(resp.json()["detail"]) + + return resp @property def models(self) -> ModelCollection: @@ -152,19 +123,129 @@ def run(self, model_version: str, **kwargs) -> Union[Any, Iterator[Any]]: return prediction.output -class _NonpersistentCookieJar(RequestsCookieJar): - """ - A cookie jar that doesn't persist cookies between requests. +# Adapted from https://github.com/encode/httpx/issues/108#issuecomment-1132753155 +class RetryTransport(httpx.AsyncBaseTransport, httpx.BaseTransport): + """A custom HTTP transport that automatically retries requests using an exponential backoff strategy + for specific HTTP status codes and request methods. """ - def set(self, name, value, **kwargs) -> None: - return + RETRYABLE_METHODS = frozenset(["HEAD", "GET", "PUT", "DELETE", "OPTIONS", "TRACE"]) + RETRYABLE_STATUS_CODES = frozenset( + [ + 429, # Too Many Requests + 503, # Service Unavailable + 504, # Gateway Timeout + ] + ) + MAX_BACKOFF_WAIT = 60 + + def __init__( + self, + wrapped_transport: Union[httpx.BaseTransport, httpx.AsyncBaseTransport], + max_attempts: int = 10, + max_backoff_wait: float = MAX_BACKOFF_WAIT, + backoff_factor: float = 0.1, + jitter_ratio: float = 0.1, + retryable_methods: Optional[Iterable[str]] = None, + retry_status_codes: Optional[Iterable[int]] = None, + ) -> None: + self._wrapped_transport = wrapped_transport + + if jitter_ratio < 0 or jitter_ratio > 0.5: + raise ValueError( + f"jitter ratio should be between 0 and 0.5, actual {jitter_ratio}" + ) + + self.max_attempts = max_attempts + self.backoff_factor = backoff_factor + self.retryable_methods = ( + frozenset(retryable_methods) + if retryable_methods + else self.RETRYABLE_METHODS + ) + self.retry_status_codes = ( + frozenset(retry_status_codes) + if retry_status_codes + else self.RETRYABLE_STATUS_CODES + ) + self.jitter_ratio = jitter_ratio + self.max_backoff_wait = max_backoff_wait + + def _calculate_sleep( + self, attempts_made: int, headers: Union[httpx.Headers, Mapping[str, str]] + ) -> float: + retry_after_header = (headers.get("Retry-After") or "").strip() + if retry_after_header: + if retry_after_header.isdigit(): + return float(retry_after_header) + + try: + parsed_date = datetime.fromisoformat(retry_after_header).astimezone() + diff = (parsed_date - datetime.now().astimezone()).total_seconds() + if diff > 0: + return min(diff, self.max_backoff_wait) + except ValueError: + pass + + backoff = self.backoff_factor * (2 ** (attempts_made - 1)) + jitter = (backoff * self.jitter_ratio) * random.choice([1, -1]) # noqa: S311 + total_backoff = backoff + jitter + return min(total_backoff, self.max_backoff_wait) + + def handle_request(self, request: httpx.Request) -> httpx.Response: + response = self._wrapped_transport.handle_request(request) # type: ignore + + if request.method not in self.retryable_methods: + return response + + remaining_attempts = self.max_attempts - 1 + attempts_made = 1 + + while True: + if ( + remaining_attempts < 1 + or response.status_code not in self.retry_status_codes + ): + return response + + response.close() + + sleep_for = self._calculate_sleep(attempts_made, response.headers) + time.sleep(sleep_for) + + response = self._wrapped_transport.handle_request(request) # type: ignore + + attempts_made += 1 + remaining_attempts -= 1 + + async def handle_async_request(self, request: httpx.Request) -> httpx.Response: + response = await self._wrapped_transport.handle_async_request(request) # type: ignore + + if request.method not in self.retryable_methods: + return response + + remaining_attempts = self.max_attempts - 1 + attempts_made = 1 + + while True: + if ( + remaining_attempts < 1 + or response.status_code not in self.retry_status_codes + ): + return response + + response.close() + + sleep_for = self._calculate_sleep(attempts_made, response.headers) + time.sleep(sleep_for) + + response = await self._wrapped_transport.handle_async_request(request) # type: ignore - def set_cookie(self, cookie, *args, **kwargs) -> None: - return + attempts_made += 1 + remaining_attempts -= 1 + async def aclose(self) -> None: + await self._wrapped_transport.aclose() # type: ignore -def _create_session() -> requests.Session: - s = requests.Session() - s.cookies = _NonpersistentCookieJar() - return s + def close(self) -> None: + self._wrapped_transport.close() # type: ignore diff --git a/replicate/files.py b/replicate/files.py index 27dbb6db..394d589c 100644 --- a/replicate/files.py +++ b/replicate/files.py @@ -4,7 +4,7 @@ import os from typing import Optional -import requests +import httpx def upload_file(fh: io.IOBase, output_file_prefix: Optional[str] = None) -> str: @@ -24,7 +24,7 @@ def upload_file(fh: io.IOBase, output_file_prefix: Optional[str] = None) -> str: if output_file_prefix is not None: name = getattr(fh, "name", "output") url = output_file_prefix + os.path.basename(name) - resp = requests.put(url, files={"file": fh}, timeout=None) + resp = httpx.put(url, files={"file": fh}, timeout=None) # type: ignore resp.raise_for_status() return url diff --git a/requirements-dev.txt b/requirements-dev.txt index 3f170d83..5906f987 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -6,18 +6,33 @@ # annotated-types==0.5.0 # via pydantic +anyio==3.7.1 + # via httpcore black==23.7.0 # via replicate (pyproject.toml) certifi==2023.7.22 - # via requests -charset-normalizer==3.2.0 - # via requests + # via + # httpcore + # httpx click==8.1.6 # via black +h11==0.14.0 + # via httpcore +httpcore==0.17.3 + # via httpx +httpx==0.24.1 + # via + # replicate (pyproject.toml) + # respx idna==3.4 - # via requests + # via + # anyio + # httpx + # yarl iniconfig==2.0.0 # via pytest +multidict==6.0.4 + # via yarl mypy==1.4.1 # via replicate (pyproject.toml) mypy-extensions==1.0.0 @@ -40,25 +55,33 @@ pydantic==2.0.3 pydantic-core==2.3.0 # via pydantic pytest==7.4.0 - # via replicate (pyproject.toml) -pyyaml==6.0.1 - # via responses -requests==2.31.0 # via + # pytest-asyncio + # pytest-recording # replicate (pyproject.toml) - # responses -responses==0.23.1 +pytest-asyncio==0.21.1 + # via replicate (pyproject.toml) +pytest-recording==0.13.0 + # via replicate (pyproject.toml) +pyyaml==6.0.1 + # via vcrpy +respx==0.20.2 # via replicate (pyproject.toml) ruff==0.0.278 # via replicate (pyproject.toml) -types-pyyaml==6.0.12.10 - # via responses +sniffio==1.3.0 + # via + # anyio + # httpcore + # httpx typing-extensions==4.7.1 # via # mypy # pydantic # pydantic-core -urllib3==2.0.3 - # via - # requests - # responses +vcrpy==5.1.0 + # via pytest-recording +wrapt==1.15.0 + # via vcrpy +yarl==1.9.2 + # via vcrpy diff --git a/requirements.txt b/requirements.txt index 1a221196..6fac3e28 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,23 +6,34 @@ # annotated-types==0.5.0 # via pydantic +anyio==3.7.1 + # via httpcore certifi==2023.7.22 - # via requests -charset-normalizer==3.2.0 - # via requests + # via + # httpcore + # httpx +h11==0.14.0 + # via httpcore +httpcore==0.17.3 + # via httpx +httpx==0.24.1 + # via replicate (pyproject.toml) idna==3.4 - # via requests + # via + # anyio + # httpx packaging==23.1 # via replicate (pyproject.toml) pydantic==2.0.3 # via replicate (pyproject.toml) pydantic-core==2.3.0 # via pydantic -requests==2.31.0 - # via replicate (pyproject.toml) +sniffio==1.3.0 + # via + # anyio + # httpcore + # httpx typing-extensions==4.7.1 # via # pydantic # pydantic-core -urllib3==2.0.3 - # via requests diff --git a/tests/cassettes/predictions-cancel.yaml b/tests/cassettes/predictions-cancel.yaml new file mode 100644 index 00000000..decc5a92 --- /dev/null +++ b/tests/cassettes/predictions-cancel.yaml @@ -0,0 +1,512 @@ +interactions: +- request: + body: '' + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + host: + - api.replicate.com + user-agent: + - replicate-python/0.11.0 + method: GET + uri: https://api.replicate.com/v1/models/stability-ai/sdxl + response: + content: "{\"url\":\"https://replicate.com/stability-ai/sdxl\",\"owner\":\"stability-ai\",\"name\":\"sdxl\",\"description\":\"A + text-to-image generative AI model that creates beautiful 1024x1024 images\",\"visibility\":\"public\",\"github_url\":\"https://github.com/Stability-AI/generative-models\",\"paper_url\":\"https://arxiv.org/abs/2307.01952\",\"license_url\":\"https://github.com/Stability-AI/generative-models/blob/main/model_licenses/LICENSE-SDXL1.0\",\"run_count\":918101,\"cover_image_url\":\"https://tjzk.replicate.delivery/models_models_cover_image/61004930-fb88-4e09-9bd4-74fd8b4aa677/sdxl_cover.png\",\"default_example\":{\"completed_at\":\"2023-07-26T21:04:37.933562Z\",\"created_at\":\"2023-07-26T21:04:23.762683Z\",\"error\":null,\"id\":\"vu42q7dbkm6iicbpal4v6uvbqm\",\"input\":{\"width\":1024,\"height\":1024,\"prompt\":\"An + astronaut riding a rainbow unicorn, cinematic, dramatic\",\"refine\":\"expert_ensemble_refiner\",\"scheduler\":\"DDIM\",\"num_outputs\":1,\"guidance_scale\":7.5,\"high_noise_frac\":0.8,\"prompt_strength\":0.8,\"num_inference_steps\":50},\"logs\":\"Using + seed: 12103\\ntxt2img mode\\n 0%| | 0/40 [00:00"}' + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '148' + content-type: + - application/json + host: + - api.replicate.com + user-agent: + - replicate-python/0.11.0 + method: POST + uri: https://api.replicate.com/v1/models/stability-ai/sdxl/versions/a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5/trainings + response: + content: '{"detail":"The specified training destination does not exist","status":404} + + ' + headers: + CF-Cache-Status: + - DYNAMIC + CF-RAY: + - 7f7c2190ed8c281a-SEA + Connection: + - keep-alive + Content-Length: + - '76' + Content-Type: + - application/problem+json + Date: + - Wed, 16 Aug 2023 19:37:18 GMT + NEL: + - '{"success_fraction":0,"report_to":"cf-nel","max_age":604800}' + Report-To: + - '{"endpoints":[{"url":"https:\/\/a.nel.cloudflare.com\/report\/v3?s=0vMWFlGDyffyF0A%2FL4%2FH830OVHnZd0gZDww4oocSSHq7eMAt327ut6v%2B2qAda7fThmH4WcElLTM%2B3PFyrsa1w1SHgfEdWyJSv8TYYi2nWXMqeP5EJc1SDjV958HGKSKDnjH5"}],"group":"cf-nel","max_age":604800}' + Server: + - cloudflare + Strict-Transport-Security: + - max-age=15552000 + ratelimit-remaining: + - '2999' + ratelimit-reset: + - '1' + via: + - 1.1 google + http_version: HTTP/1.1 + status_code: 404 +version: 1 diff --git a/tests/cassettes/trainings-get.yaml b/tests/cassettes/trainings-get.yaml new file mode 100644 index 00000000..77a0e8d5 --- /dev/null +++ b/tests/cassettes/trainings-get.yaml @@ -0,0 +1,73 @@ +interactions: +- request: + body: '' + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + host: + - api.replicate.com + user-agent: + - replicate-python/0.11.0 + method: GET + uri: https://api.replicate.com/v1/trainings/medrnz3bm5dd6ultvad2tejrte + response: + content: '{"completed_at":null,"created_at":"2023-08-16T19:33:26.906823Z","error":null,"id":"medrnz3bm5dd6ultvad2tejrte","input":{"input_images":"https://replicate.delivery/pbxt/JMV5OrEWpBAC5gO8rre0tPOyJIOkaXvG0TWfVJ9b4zhLeEUY/data.zip","use_face_detection_instead":true},"logs":null,"metrics":{},"output":null,"started_at":"2023-08-16T19:33:42.114513Z","status":"processing","urls":{"get":"https://api.replicate.com/v1/trainings/medrnz3bm5dd6ultvad2tejrte","cancel":"https://api.replicate.com/v1/trainings/medrnz3bm5dd6ultvad2tejrte/cancel"},"version":"a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5","webhook_completed":null}' + headers: + CF-Cache-Status: + - DYNAMIC + CF-RAY: + - 7f7c1beaedff279c-SEA + Connection: + - keep-alive + Content-Encoding: + - gzip + Content-Type: + - application/json + Date: + - Wed, 16 Aug 2023 19:33:26 GMT + NEL: + - '{"success_fraction":0,"report_to":"cf-nel","max_age":604800}' + Report-To: + - '{"endpoints":[{"url":"https:\/\/a.nel.cloudflare.com\/report\/v3?s=SntiwLHCR4wiv49Qmn%2BR1ZblcX%2FgoVlIgsek4yZliZiWts2SqPjqTjrSkB%2Bwch8oHqR%2BBNVs1cSbihlHd8MWPXsbwC2uShz0c6tD4nclaecblb3FnEp4Mccy9hlZ39izF9Tm"}],"group":"cf-nel","max_age":604800}' + Server: + - cloudflare + Strict-Transport-Security: + - max-age=15552000 + Transfer-Encoding: + - chunked + allow: + - OPTIONS, GET + content-security-policy-report-only: + - 'style-src ''report-sample'' ''self'' ''unsafe-inline'' https://fonts.googleapis.com; + img-src ''report-sample'' ''self'' data: https://replicate.delivery https://*.replicate.delivery + https://*.githubusercontent.com https://github.com; worker-src ''none''; media-src + ''report-sample'' ''self'' https://replicate.delivery https://*.replicate.delivery + https://*.mux.com https://*.gstatic.com https://*.sentry.io; connect-src ''report-sample'' + ''self'' https://replicate.delivery https://*.replicate.delivery https://*.rudderlabs.com + https://*.rudderstack.com https://*.mux.com https://*.sentry.io; script-src + ''report-sample'' ''self'' https://cdn.rudderlabs.com/v1.1/rudder-analytics.min.js; + font-src ''report-sample'' ''self'' data: https://fonts.replicate.ai https://fonts.gstatic.com; + default-src ''self''; report-uri' + cross-origin-opener-policy: + - same-origin + ratelimit-remaining: + - '2999' + ratelimit-reset: + - '1' + referrer-policy: + - same-origin + vary: + - Cookie, origin + via: + - 1.1 vegur, 1.1 google + x-content-type-options: + - nosniff + x-frame-options: + - DENY + http_version: HTTP/1.1 + status_code: 200 +version: 1 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..a29ed640 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,27 @@ +import os +from unittest import mock + +import pytest + + +@pytest.fixture(scope="session") +def mock_replicate_api_token(scope="class"): + if os.environ.get("REPLICATE_API_TOKEN", "") != "": + yield + else: + with mock.patch.dict( + os.environ, + {"REPLICATE_API_TOKEN": "test-token", "REPLICATE_POLL_INTERVAL": "0.0"}, + ): + yield + + +@pytest.fixture(scope="module") +def vcr_config(): + return {"allowed_hosts": ["api.replicate.com"], "filter_headers": ["authorization"]} + + +@pytest.fixture(scope="module") +def vcr_cassette_dir(request): + module = request.node.fspath + return os.path.join(module.dirname, "cassettes") diff --git a/tests/factories.py b/tests/factories.py deleted file mode 100644 index 349fb7f4..00000000 --- a/tests/factories.py +++ /dev/null @@ -1,236 +0,0 @@ -import datetime - -import responses -from responses import matchers - -from replicate.client import Client -from replicate.version import Version - - -def create_client(): - client = Client(api_token="abc123") - return client - - -def get_mock_schema(): - return { - "info": {"title": "Cog", "version": "0.1.0"}, - "paths": { - "/": { - "get": { - "summary": "Root", - "responses": { - "200": { - "content": {"application/json": {"schema": {}}}, - "description": "Successful Response", - } - }, - "operationId": "root__get", - } - }, - "/predictions": { - "post": { - "summary": "Predict", - "responses": { - "200": { - "content": { - "application/json": { - "schema": {"$ref": "#/components/schemas/Response"} - } - }, - "description": "Successful Response", - }, - "422": { - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/HTTPValidationError" - } - } - }, - "description": "Validation Error", - }, - }, - "description": "Run a single prediction on the model", - "operationId": "predict_predictions_post", - "requestBody": { - "content": { - "application/json": { - "schema": {"$ref": "#/components/schemas/Request"} - } - } - }, - } - }, - }, - "openapi": "3.0.2", - "components": { - "schemas": { - "Input": { - "type": "object", - "title": "Input", - "required": ["text"], - "properties": { - "text": { - "type": "string", - "title": "Text", - "x-order": 0, - "description": "Text to prefix with 'hello '", - } - }, - }, - "Output": {"type": "string", "title": "Output"}, - "Status": { - "enum": ["processing", "succeeded", "failed"], - "type": "string", - "title": "Status", - "description": "An enumeration.", - }, - "Request": { - "type": "object", - "title": "Request", - "properties": { - "input": {"$ref": "#/components/schemas/Input"}, - "output_file_prefix": { - "type": "string", - "title": "Output File Prefix", - }, - }, - "description": "The request body for a prediction", - }, - "Response": { - "type": "object", - "title": "Response", - "required": ["status"], - "properties": { - "error": {"type": "string", "title": "Error"}, - "output": {"$ref": "#/components/schemas/Output"}, - "status": {"$ref": "#/components/schemas/Status"}, - }, - "description": "The response body for a prediction", - }, - "ValidationError": { - "type": "object", - "title": "ValidationError", - "required": ["loc", "msg", "type"], - "properties": { - "loc": { - "type": "array", - "items": { - "anyOf": [ - {"type": "string"}, - {"type": "integer"}, - ] - }, - "title": "Location", - }, - "msg": {"type": "string", "title": "Message"}, - "type": {"type": "string", "title": "Error Type"}, - }, - }, - "HTTPValidationError": { - "type": "object", - "title": "HTTPValidationError", - "properties": { - "detail": { - "type": "array", - "items": {"$ref": "#/components/schemas/ValidationError"}, - "title": "Detail", - } - }, - }, - } - }, - } - - -def mock_version_get( - owner="test", model="model", version="v1", openapi_schema=None, cog_version="0.3.9" -): - responses.get( - f"https://api.replicate.com/v1/models/{owner}/{model}/versions/{version}", - match=[ - matchers.header_matcher({"Authorization": "Token abc123"}), - ], - json={ - "id": version, - "created_at": "2022-04-26T19:29:04.418669Z", - "cog_version": "0.3.9", - "openapi_schema": openapi_schema or get_mock_schema(), - }, - ) - - -def mock_version_get_with_iterator_output(**kwargs): - schema = get_mock_schema() - schema["components"]["schemas"]["Output"] = { - "type": "array", - "items": {"type": "string"}, - "title": "Output", - "x-cog-array-type": "iterator", - } - mock_version_get(openapi_schema=schema, cog_version="0.3.9", **kwargs) - - -def mock_version_get_with_list_output(**kwargs): - schema = get_mock_schema() - schema["components"]["schemas"]["Output"] = { - "type": "array", - "items": {"type": "string"}, - "title": "Output", - } - mock_version_get(openapi_schema=schema, cog_version="0.3.9", **kwargs) - - -def mock_version_get_with_iterator_output_backwards_compatibility_0_3_8(**kwargs): - schema = get_mock_schema() - schema["components"]["schemas"]["Output"] = { - "type": "array", - "items": {"type": "string"}, - "title": "Output", - } - mock_version_get(openapi_schema=schema, cog_version="0.3.8", **kwargs) - - -def create_version(client=None, openapi_schema=None, cog_version="0.3.0"): - if client is None: - client = create_client() - version = Version( - id="v1", - created_at=datetime.datetime.now(), - cog_version=cog_version, - openapi_schema=openapi_schema or get_mock_schema(), - ) - version._client = client - return version - - -def create_version_with_iterator_output(): - version = create_version(cog_version="0.3.9") - version.openapi_schema["components"]["schemas"]["Output"] = { - "type": "array", - "items": {"type": "string"}, - "title": "Output", - "x-cog-array-type": "iterator", - } - return version - - -def create_version_with_list_output(): - version = create_version(cog_version="0.3.9") - version.openapi_schema["components"]["schemas"]["Output"] = { - "type": "array", - "items": {"type": "string"}, - "title": "Output", - } - return version - - -def create_version_with_iterator_output_backwards_compatibility_0_3_8(): - version = create_version(cog_version="0.3.8") - version.openapi_schema["components"]["schemas"]["Output"] = { - "type": "array", - "items": {"type": "string"}, - "title": "Output", - } - return version diff --git a/tests/test_client.py b/tests/test_client.py deleted file mode 100644 index 05fffe08..00000000 --- a/tests/test_client.py +++ /dev/null @@ -1,290 +0,0 @@ -from collections.abc import Iterable - -import pytest -import responses -from responses import matchers - -from replicate.__about__ import __version__ -from replicate.client import Client -from replicate.exceptions import ModelError - -from .factories import ( - mock_version_get, - mock_version_get_with_iterator_output, - mock_version_get_with_iterator_output_backwards_compatibility_0_3_8, - mock_version_get_with_list_output, -) - - -@responses.activate -def test_client_sets_authorization_token_and_user_agent_headers(): - client = Client(api_token="abc123") - model = client.models.get("test/model") - - responses.get( - "https://api.replicate.com/v1/models/test/model/versions", - match=[ - matchers.header_matcher({"Authorization": "Token abc123"}), - matchers.header_matcher({"User-Agent": f"replicate-python/{__version__}"}), - ], - json={"results": []}, - ) - - model.versions.list() - - -@responses.activate -def test_run(): - mock_version_get(owner="test", model="model", version="v1") - responses.post( - "https://api.replicate.com/v1/predictions", - match=[ - matchers.json_params_matcher({"version": "v1", "input": {"text": "world"}}) - ], - json={ - "id": "p1", - "version": "v1", - "urls": { - "get": "https://api.replicate.com/v1/predictions/p1", - "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", - }, - "created_at": "2022-04-26T20:00:40.658234Z", - "completed_at": "2022-04-26T20:02:27.648305Z", - "source": "api", - "status": "processing", - "input": {"text": "world"}, - "output": None, - "error": None, - "logs": "", - }, - ) - responses.get( - "https://api.replicate.com/v1/predictions/p1", - json={ - "id": "p1", - "version": "v1", - "urls": { - "get": "https://api.replicate.com/v1/predictions/p1", - "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", - }, - "created_at": "2022-04-26T20:00:40.658234Z", - "completed_at": "2022-04-26T20:02:27.648305Z", - "source": "api", - "status": "succeeded", - "input": {"text": "world"}, - "output": "hello world", - "error": None, - "logs": "", - }, - ) - - client = Client(api_token="abc123") - assert client.run("test/model:v1", input={"text": "world"}) == "hello world" - - -@responses.activate -def test_run_with_iterator(): - mock_version_get_with_iterator_output(owner="test", model="model", version="v1") - responses.post( - "https://api.replicate.com/v1/predictions", - match=[ - matchers.json_params_matcher({"version": "v1", "input": {"text": "world"}}) - ], - json={ - "id": "p1", - "version": "v1", - "urls": { - "get": "https://api.replicate.com/v1/predictions/p1", - "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", - }, - "created_at": "2022-04-26T20:00:40.658234Z", - "completed_at": "2022-04-26T20:02:27.648305Z", - "source": "api", - "status": "processing", - "input": {"text": "world"}, - "output": None, - "error": None, - "logs": "", - }, - ) - responses.get( - "https://api.replicate.com/v1/predictions/p1", - json={ - "id": "p1", - "version": "v1", - "urls": { - "get": "https://api.replicate.com/v1/predictions/p1", - "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", - }, - "created_at": "2022-04-26T20:00:40.658234Z", - "completed_at": "2022-04-26T20:02:27.648305Z", - "source": "api", - "status": "succeeded", - "input": {"text": "world"}, - "output": ["hello world"], - "error": None, - "logs": "", - }, - ) - - client = Client(api_token="abc123") - output = client.run("test/model:v1", input={"text": "world"}) - assert isinstance(output, Iterable) - assert list(output) == ["hello world"] - - -@responses.activate -def test_run_with_list(): - mock_version_get_with_list_output(owner="test", model="model", version="v1") - responses.post( - "https://api.replicate.com/v1/predictions", - match=[ - matchers.json_params_matcher({"version": "v1", "input": {"text": "world"}}) - ], - json={ - "id": "p1", - "version": "v1", - "urls": { - "get": "https://api.replicate.com/v1/predictions/p1", - "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", - }, - "created_at": "2022-04-26T20:00:40.658234Z", - "completed_at": "2022-04-26T20:02:27.648305Z", - "source": "api", - "status": "processing", - "input": {"text": "world"}, - "output": None, - "error": None, - "logs": "", - }, - ) - responses.get( - "https://api.replicate.com/v1/predictions/p1", - json={ - "id": "p1", - "version": "v1", - "urls": { - "get": "https://api.replicate.com/v1/predictions/p1", - "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", - }, - "created_at": "2022-04-26T20:00:40.658234Z", - "completed_at": "2022-04-26T20:02:27.648305Z", - "source": "api", - "status": "succeeded", - "input": {"text": "world"}, - "output": ["hello world"], - "error": None, - "logs": "", - }, - ) - - client = Client(api_token="abc123") - output = client.run("test/model:v1", input={"text": "world"}) - assert isinstance(output, list) - assert output == ["hello world"] - - -@responses.activate -def test_run_with_iterator_backwards_compatibility_cog_0_3_8(): - mock_version_get_with_iterator_output_backwards_compatibility_0_3_8( - owner="test", model="model", version="v1" - ) - responses.post( - "https://api.replicate.com/v1/predictions", - match=[ - matchers.json_params_matcher({"version": "v1", "input": {"text": "world"}}) - ], - json={ - "id": "p1", - "version": "v1", - "urls": { - "get": "https://api.replicate.com/v1/predictions/p1", - "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", - }, - "created_at": "2022-04-26T20:00:40.658234Z", - "completed_at": "2022-04-26T20:02:27.648305Z", - "source": "api", - "status": "processing", - "input": {"text": "world"}, - "output": None, - "error": None, - "logs": "", - }, - ) - responses.get( - "https://api.replicate.com/v1/predictions/p1", - json={ - "id": "p1", - "version": "v1", - "urls": { - "get": "https://api.replicate.com/v1/predictions/p1", - "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", - }, - "created_at": "2022-04-26T20:00:40.658234Z", - "completed_at": "2022-04-26T20:02:27.648305Z", - "source": "api", - "status": "succeeded", - "input": {"text": "world"}, - "output": ["hello world"], - "error": None, - "logs": "", - }, - ) - - client = Client(api_token="abc123") - output = client.run("test/model:v1", input={"text": "world"}) - assert isinstance(output, Iterable) - assert list(output) == ["hello world"] - - -@responses.activate -def test_predict_with_iterator_with_failed_prediction(): - mock_version_get_with_iterator_output(owner="test", model="model", version="v1") - responses.post( - "https://api.replicate.com/v1/predictions", - match=[ - matchers.json_params_matcher({"version": "v1", "input": {"text": "world"}}) - ], - json={ - "id": "p1", - "version": "v1", - "urls": { - "get": "https://api.replicate.com/v1/predictions/p1", - "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", - }, - "created_at": "2022-04-26T20:00:40.658234Z", - "completed_at": "2022-04-26T20:02:27.648305Z", - "source": "api", - "status": "processing", - "input": {"text": "world"}, - "output": None, - "error": None, - "logs": "", - }, - ) - responses.get( - "https://api.replicate.com/v1/predictions/p1", - json={ - "id": "p1", - "version": "v1", - "urls": { - "get": "https://api.replicate.com/v1/predictions/p1", - "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", - }, - "created_at": "2022-04-26T20:00:40.658234Z", - "completed_at": "2022-04-26T20:02:27.648305Z", - "source": "api", - "status": "failed", - "input": {"text": "world"}, - "output": None, - "error": "it broke", - "logs": "", - }, - ) - - client = Client(api_token="abc123") - output = client.run("test/model:v1", input={"text": "world"}) - assert isinstance(output, Iterable) - with pytest.raises(ModelError) as excinfo: - list(output) - assert "it broke" in str(excinfo.value) diff --git a/tests/test_deployment.py b/tests/test_deployment.py index 8ec63a77..1b37444a 100644 --- a/tests/test_deployment.py +++ b/tests/test_deployment.py @@ -1,26 +1,16 @@ -import responses -from responses import matchers +import httpx +import respx from replicate.client import Client - -@responses.activate -def test_deployment_predictions_create(): - client = Client(api_token="abc123") - - deployment = client.deployments.get("test/model") - - rsp = responses.post( - "https://api.replicate.com/v1/deployments/test/model/predictions", - match=[ - matchers.json_params_matcher( - { - "input": {"text": "world"}, - "webhook": "https://example.com/webhook", - "webhook_events_filter": ["completed"], - } - ), - ], +router = respx.Router(base_url="https://api.replicate.com/v1") +router.route( + method="POST", + path="/deployments/test/model/predictions", + name="deployments.predictions.create", +).mock( + return_value=httpx.Response( + 201, json={ "id": "p1", "version": "v1", @@ -31,17 +21,28 @@ def test_deployment_predictions_create(): "created_at": "2022-04-26T20:00:40.658234Z", "source": "api", "status": "processing", - "input": {"text": "hello"}, + "input": {"text": "world"}, "output": None, "error": None, "logs": "", }, ) +) +router.route(host="api.replicate.com").pass_through() + + +def test_deployment_predictions_create(): + client = Client( + api_token="test-token", transport=httpx.MockTransport(router.handler) + ) + deployment = client.deployments.get("test/model") - deployment.predictions.create( + prediction = deployment.predictions.create( input={"text": "world"}, webhook="https://example.com/webhook", webhook_events_filter=["completed"], ) - assert rsp.call_count == 1 + assert router["deployments.predictions.create"].called + assert prediction.id == "p1" + assert prediction.input == {"text": "world"} diff --git a/tests/test_model.py b/tests/test_model.py index 37da743f..e9d6f313 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1,35 +1,13 @@ -import responses +import pytest -from replicate.client import Client +import replicate -@responses.activate -def test_versions(): - client = Client(api_token="abc123") +@pytest.mark.vcr("models-get.yaml") +@pytest.mark.asyncio +async def test_models_get(mock_replicate_api_token): + model = replicate.models.get("stability-ai/sdxl") - model = client.models.get("test/model") - - responses.get( - "https://api.replicate.com/v1/models/test/model/versions", - json={ - "results": [ - { - "id": "v1", - "created_at": "2022-04-26T19:29:04.418669Z", - "cog_version": "0.3.0", - "openapi_schema": {}, - }, - { - "id": "v2", - "created_at": "2022-03-21T13:01:04.418669Z", - "cog_version": "0.3.0", - "openapi_schema": {}, - }, - ] - }, - ) - - versions = model.versions.list() - assert len(versions) == 2 - assert versions[0].id == "v1" - assert versions[1].id == "v2" + assert model is not None + assert model.username == "stability-ai" + assert model.name == "sdxl" diff --git a/tests/test_prediction.py b/tests/test_prediction.py index 4b330015..dd315701 100644 --- a/tests/test_prediction.py +++ b/tests/test_prediction.py @@ -1,278 +1,245 @@ -import responses -from responses import matchers - -from replicate.prediction import Prediction - -from .factories import create_client, create_version - - -@responses.activate -def test_create_works_with_webhooks(): - client = create_client() - version = create_version(client) - - rsp = responses.post( - "https://api.replicate.com/v1/predictions", - match=[ - matchers.json_params_matcher( - { - "version": "v1", - "input": {"text": "world"}, - "webhook": "https://example.com/webhook", - "webhook_events_filter": ["completed"], - } - ), - ], - json={ - "id": "p1", - "version": "v1", - "urls": { - "get": "https://api.replicate.com/v1/predictions/p1", - "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", - }, - "created_at": "2022-04-26T20:00:40.658234Z", - "completed_at": "2022-04-26T20:02:27.648305Z", - "source": "api", - "status": "processing", - "input": {"text": "world"}, - "output": None, - "error": None, - "logs": "", - }, - ) +import pytest - client.predictions.create( - version=version, - input={"text": "world"}, - webhook="https://example.com/webhook", - webhook_events_filter=["completed"], - ) +import replicate - assert rsp.call_count == 1 - - -@responses.activate -def test_cancel(): - client = create_client() - version = create_version(client) - - responses.post( - "https://api.replicate.com/v1/predictions", - match=[ - matchers.json_params_matcher( - { - "version": "v1", - "input": {"text": "world"}, - "webhook_completed": "https://example.com/webhook", - } - ), - ], - json={ - "id": "p1", - "version": "v1", - "urls": { - "get": "https://api.replicate.com/v1/predictions/p1", - "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", - }, - "created_at": "2022-04-26T20:00:40.658234Z", - "completed_at": "2022-04-26T20:02:27.648305Z", - "source": "api", - "status": "processing", - "input": {"text": "world"}, - "output": None, - "error": None, - "logs": "", - }, - ) - prediction = client.predictions.create( - version=version, - input={"text": "world"}, - webhook_completed="https://example.com/webhook", - ) +@pytest.mark.vcr("predictions-create.yaml") +@pytest.mark.asyncio +async def test_predictions_create(mock_replicate_api_token): + input = { + "prompt": "a studio photo of a rainbow colored corgi", + "width": 512, + "height": 512, + "seed": 42069, + } - rsp = responses.post("https://api.replicate.com/v1/predictions/p1/cancel", json={}) - prediction.cancel() - assert rsp.call_count == 1 - - -@responses.activate -def test_stream(): - client = create_client() - version = create_version(client) - - rsp = responses.post( - "https://api.replicate.com/v1/predictions", - match=[ - matchers.json_params_matcher( - { - "version": "v1", - "input": {"text": "world"}, - "stream": "true", - } - ), - ], - json={ - "id": "p1", - "version": "v1", - "urls": { - "get": "https://api.replicate.com/v1/predictions/p1", - "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", - "stream": "https://streaming.api.replicate.com/v1/predictions/p1", - }, - "created_at": "2022-04-26T20:00:40.658234Z", - "completed_at": "2022-04-26T20:02:27.648305Z", - "source": "api", - "status": "processing", - "input": {"text": "world"}, - "output": None, - "error": None, - "logs": "", - }, + model = replicate.models.get("stability-ai/sdxl") + version = model.versions.get( + "a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5" ) - - prediction = client.predictions.create( + prediction = replicate.predictions.create( version=version, - input={"text": "world"}, - stream=True, + input=input, ) - assert rsp.call_count == 1 + assert prediction.id is not None + assert prediction.version == version + assert prediction.status == "starting" - assert ( - prediction.urls["stream"] - == "https://streaming.api.replicate.com/v1/predictions/p1" - ) +@pytest.mark.vcr("predictions-get.yaml") +@pytest.mark.asyncio +async def test_predictions_get(mock_replicate_api_token): + id = "vgcm4plb7tgzlyznry5d5jkgvu" -@responses.activate -def test_async_timings(): - client = create_client() - version = create_version(client) - - responses.post( - "https://api.replicate.com/v1/predictions", - match=[ - matchers.json_params_matcher( - { - "version": "v1", - "input": {"text": "hello"}, - "webhook_completed": "https://example.com/webhook", - } - ), - ], - json={ - "id": "p1", - "version": "v1", - "urls": { - "get": "https://api.replicate.com/v1/predictions/p1", - "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", - }, - "created_at": "2022-04-26T20:00:40.658234Z", - "source": "api", - "status": "processing", - "input": {"text": "hello"}, - "output": None, - "error": None, - "logs": "", - }, - ) + prediction = replicate.predictions.get(id) + + assert prediction.id == id - responses.get( - "https://api.replicate.com/v1/predictions/p1", - json={ - "id": "p1", - "version": "v1", - "urls": { - "get": "https://api.replicate.com/v1/predictions/p1", - "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", - }, - "created_at": "2022-04-26T20:00:40.658234Z", - "completed_at": "2022-04-26T20:02:27.648305Z", - "source": "api", - "status": "succeeded", - "input": {"text": "hello"}, - "output": "hello world", - "error": None, - "logs": "", - "metrics": { - "predict_time": 1.2345, - }, - }, - ) - prediction = client.predictions.create( +@pytest.mark.vcr("predictions-cancel.yaml") +@pytest.mark.asyncio +async def test_predictions_cancel(mock_replicate_api_token): + input = { + "prompt": "a studio photo of a rainbow colored corgi", + "width": 512, + "height": 512, + "seed": 42069, + } + + model = replicate.models.get("stability-ai/sdxl") + version = model.versions.get( + "a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5" + ) + prediction = replicate.predictions.create( version=version, - input={"text": "hello"}, - webhook_completed="https://example.com/webhook", + input=input, ) - assert prediction.created_at == "2022-04-26T20:00:40.658234Z" - assert prediction.completed_at is None - assert prediction.output is None - assert prediction.urls["get"] == "https://api.replicate.com/v1/predictions/p1" - prediction.wait() - assert prediction.created_at == "2022-04-26T20:00:40.658234Z" - assert prediction.completed_at == "2022-04-26T20:02:27.648305Z" - assert prediction.output == "hello world" - assert prediction.metrics["predict_time"] == 1.2345 - - -def test_prediction_progress(): - client = create_client() - version = create_version(client) - prediction = Prediction( - id="ufawqhfynnddngldkgtslldrkq", version=version, status="starting" - ) + # id = prediction.id + assert prediction.status == "starting" + + # prediction = replicate.predictions.cancel(prediction) + prediction.cancel() + - lines = [ - "Using seed: 12345", - "0%| | 0/5 [00:00 0 + assert output[0].startswith("https://") + + +@pytest.mark.vcr +def test_run_with_invalid_identifier(mock_replicate_api_token): + with pytest.raises(ReplicateError): + replicate.run("invalid") diff --git a/tests/test_training.py b/tests/test_training.py index b74938db..da215749 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -1,175 +1,68 @@ -import responses -from responses import matchers - -from .factories import create_client, create_version - - -@responses.activate -def test_create_works_with_webhooks(): - client = create_client() - version = create_version(client) - - rsp = responses.post( - "https://api.replicate.com/v1/models/owner/model/versions/v1/trainings", - match=[ - matchers.json_params_matcher( - { - "input": {"data": "..."}, - "destination": "new_owner/new_model", - "webhook": "https://example.com/webhook", - "webhook_events_filter": ["completed"], - } - ), - ], - json={ - "id": "t1", - "version": "v1", - "urls": { - "get": "https://api.replicate.com/v1/trainings/t1", - "cancel": "https://api.replicate.com/v1/trainings/t1/cancel", - }, - "created_at": "2022-04-26T20:00:40.658234Z", - "completed_at": "2022-04-26T20:02:27.648305Z", - "source": "api", - "status": "processing", - "input": {"data": "..."}, - "output": None, - "error": None, - "logs": "", - }, - ) +import pytest - client.trainings.create( - version=f"owner/model:{version.id}", - input={"data": "..."}, - destination="new_owner/new_model", - webhook="https://example.com/webhook", - webhook_events_filter=["completed"], - ) +import replicate +from replicate.exceptions import ReplicateException - assert rsp.call_count == 1 - - -@responses.activate -def test_cancel(): - client = create_client() - version = create_version(client) - - responses.post( - "https://api.replicate.com/v1/models/owner/model/versions/v1/trainings", - match=[ - matchers.json_params_matcher( - { - "input": {"data": "..."}, - "destination": "new_owner/new_model", - "webhook": "https://example.com/webhook", - "webhook_events_filter": ["completed"], - } - ), - ], - json={ - "id": "t1", - "version": "v1", - "urls": { - "get": "https://api.replicate.com/v1/trainings/t1", - "cancel": "https://api.replicate.com/v1/trainings/t1/cancel", - }, - "created_at": "2022-04-26T20:00:40.658234Z", - "completed_at": "2022-04-26T20:02:27.648305Z", - "source": "api", - "status": "processing", - "input": {"data": "..."}, - "output": None, - "error": None, - "logs": "", - }, - ) +input_images_url = "https://replicate.delivery/pbxt/JMV5OrEWpBAC5gO8rre0tPOyJIOkaXvG0TWfVJ9b4zhLeEUY/data.zip" - training = client.trainings.create( - version=f"owner/model:{version.id}", - input={"data": "..."}, - destination="new_owner/new_model", - webhook="https://example.com/webhook", - webhook_events_filter=["completed"], - ) - rsp = responses.post("https://api.replicate.com/v1/trainings/t1/cancel", json={}) - training.cancel() - assert rsp.call_count == 1 - - -@responses.activate -def test_async_timings(): - client = create_client() - version = create_version(client) - - responses.post( - "https://api.replicate.com/v1/models/owner/model/versions/v1/trainings", - match=[ - matchers.json_params_matcher( - { - "input": {"data": "..."}, - "destination": "new_owner/new_model", - "webhook": "https://example.com/webhook", - "webhook_events_filter": ["completed"], - } - ), - ], - json={ - "id": "t1", - "version": "v1", - "urls": { - "get": "https://api.replicate.com/v1/trainings/t1", - "cancel": "https://api.replicate.com/v1/trainings/t1/cancel", - }, - "created_at": "2022-04-26T20:00:40.658234Z", - "source": "api", - "status": "processing", - "input": {"data": "..."}, - "output": None, - "error": None, - "logs": "", +@pytest.mark.vcr("trainings-create.yaml") +@pytest.mark.asyncio +async def test_trainings_create(mock_replicate_api_token): + training = replicate.trainings.create( + "stability-ai/sdxl:a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5", + input={ + "input_images": input_images_url, + "use_face_detection_instead": True, }, + destination="replicate/dreambooth-sdxl", ) - responses.get( - "https://api.replicate.com/v1/trainings/t1", - json={ - "id": "t1", - "version": "v1", - "urls": { - "get": "https://api.replicate.com/v1/trainings/t1", - "cancel": "https://api.replicate.com/v1/trainings/t1/cancel", - }, - "created_at": "2022-04-26T20:00:40.658234Z", - "completed_at": "2022-04-26T20:02:27.648305Z", - "source": "api", - "status": "succeeded", - "input": {"data": "..."}, - "output": { - "weights": "https://delivery.replicate.com/weights.tgz", - "version": "v2", + assert training.id is not None + assert training.status == "starting" + + +@pytest.mark.vcr("trainings-create__invalid-destination.yaml") +@pytest.mark.asyncio +async def test_trainings_create_with_invalid_destination(mock_replicate_api_token): + with pytest.raises(ReplicateException): + replicate.trainings.create( + "stability-ai/sdxl:a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5", + input={ + "input_images": input_images_url, }, - "error": None, - "logs": "", - }, - ) + destination="", + ) + + +@pytest.mark.vcr("trainings-get.yaml") +@pytest.mark.asyncio +async def test_trainings_get(mock_replicate_api_token): + id = "medrnz3bm5dd6ultvad2tejrte" + + training = replicate.trainings.get(id) - training = client.trainings.create( - version=f"owner/model:{version.id}", - input={"data": "..."}, - destination="new_owner/new_model", - webhook="https://example.com/webhook", - webhook_events_filter=["completed"], + assert training.id == id + assert training.status == "processing" + + +@pytest.mark.vcr("trainings-cancel.yaml") +@pytest.mark.asyncio +async def test_trainings_cancel(mock_replicate_api_token): + input = { + "input_images": input_images_url, + "use_face_detection_instead": True, + } + + destination = "replicate/dreambooth-sdxl" + + training = replicate.trainings.create( + "stability-ai/sdxl:a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5", + destination=destination, + input=input, ) - assert training.created_at == "2022-04-26T20:00:40.658234Z" - assert training.completed_at is None - assert training.output is None + assert training.status == "starting" - # trainings don't have a wait method, so simulate it by calling reload - training.reload() - assert training.created_at == "2022-04-26T20:00:40.658234Z" - assert training.completed_at == "2022-04-26T20:02:27.648305Z" - assert training.output["weights"] == "https://delivery.replicate.com/weights.tgz" - assert training.output["version"] == "v2" + # training = replicate.trainings.cancel(training) + training.cancel() diff --git a/tests/test_version.py b/tests/test_version.py deleted file mode 100644 index fb08ec50..00000000 --- a/tests/test_version.py +++ /dev/null @@ -1,265 +0,0 @@ -from collections.abc import Iterable - -import pytest -import responses -from responses import matchers - -from replicate.exceptions import ModelError - -from .factories import ( - create_version, - create_version_with_iterator_output, - create_version_with_iterator_output_backwards_compatibility_0_3_8, - create_version_with_list_output, -) - - -@responses.activate -def test_predict(): - version = create_version() - - responses.post( - "https://api.replicate.com/v1/predictions", - match=[ - matchers.json_params_matcher({"version": "v1", "input": {"text": "world"}}) - ], - json={ - "id": "p1", - "version": "v1", - "urls": { - "get": "https://api.replicate.com/v1/predictions/p1", - "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", - }, - "created_at": "2022-04-26T20:00:40.658234Z", - "completed_at": "2022-04-26T20:02:27.648305Z", - "source": "api", - "status": "processing", - "input": {"text": "world"}, - "output": None, - "error": None, - "logs": "", - }, - ) - responses.get( - "https://api.replicate.com/v1/predictions/p1", - json={ - "id": "p1", - "version": "v1", - "urls": { - "get": "https://api.replicate.com/v1/predictions/p1", - "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", - }, - "created_at": "2022-04-26T20:00:40.658234Z", - "completed_at": "2022-04-26T20:02:27.648305Z", - "source": "api", - "status": "succeeded", - "input": {"text": "world"}, - "output": "hello world", - "error": None, - "logs": "", - }, - ) - - assert version.predict(text="world") == "hello world" - - -@responses.activate -def test_predict_with_iterator(): - version = create_version_with_iterator_output() - responses.post( - "https://api.replicate.com/v1/predictions", - match=[ - matchers.json_params_matcher({"version": "v1", "input": {"text": "world"}}) - ], - json={ - "id": "p1", - "version": "v1", - "urls": { - "get": "https://api.replicate.com/v1/predictions/p1", - "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", - }, - "created_at": "2022-04-26T20:00:40.658234Z", - "completed_at": "2022-04-26T20:02:27.648305Z", - "source": "api", - "status": "processing", - "input": {"text": "world"}, - "output": None, - "error": None, - "logs": "", - }, - ) - responses.get( - "https://api.replicate.com/v1/predictions/p1", - json={ - "id": "p1", - "version": "v1", - "urls": { - "get": "https://api.replicate.com/v1/predictions/p1", - "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", - }, - "created_at": "2022-04-26T20:00:40.658234Z", - "completed_at": "2022-04-26T20:02:27.648305Z", - "source": "api", - "status": "succeeded", - "input": {"text": "world"}, - "output": ["hello world"], - "error": None, - "logs": "", - }, - ) - - output = version.predict(text="world") - assert isinstance(output, Iterable) - assert list(output) == ["hello world"] - - -@responses.activate -def test_predict_with_list(): - version = create_version_with_list_output() - responses.post( - "https://api.replicate.com/v1/predictions", - match=[ - matchers.json_params_matcher({"version": "v1", "input": {"text": "world"}}) - ], - json={ - "id": "p1", - "version": "v1", - "urls": { - "get": "https://api.replicate.com/v1/predictions/p1", - "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", - }, - "created_at": "2022-04-26T20:00:40.658234Z", - "completed_at": "2022-04-26T20:02:27.648305Z", - "source": "api", - "status": "processing", - "input": {"text": "world"}, - "output": None, - "error": None, - "logs": "", - }, - ) - responses.get( - "https://api.replicate.com/v1/predictions/p1", - json={ - "id": "p1", - "version": "v1", - "urls": { - "get": "https://api.replicate.com/v1/predictions/p1", - "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", - }, - "created_at": "2022-04-26T20:00:40.658234Z", - "completed_at": "2022-04-26T20:02:27.648305Z", - "source": "api", - "status": "succeeded", - "input": {"text": "world"}, - "output": ["hello world"], - "error": None, - "logs": "", - }, - ) - - output = version.predict(text="world") - assert isinstance(output, list) - assert output == ["hello world"] - - -@responses.activate -def test_predict_with_iterator_backwards_compatibility_cog_0_3_8(): - version = create_version_with_iterator_output_backwards_compatibility_0_3_8() - responses.post( - "https://api.replicate.com/v1/predictions", - match=[ - matchers.json_params_matcher({"version": "v1", "input": {"text": "world"}}) - ], - json={ - "id": "p1", - "version": "v1", - "urls": { - "get": "https://api.replicate.com/v1/predictions/p1", - "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", - }, - "created_at": "2022-04-26T20:00:40.658234Z", - "completed_at": "2022-04-26T20:02:27.648305Z", - "source": "api", - "status": "processing", - "input": {"text": "world"}, - "output": None, - "error": None, - "logs": "", - }, - ) - responses.get( - "https://api.replicate.com/v1/predictions/p1", - json={ - "id": "p1", - "version": "v1", - "urls": { - "get": "https://api.replicate.com/v1/predictions/p1", - "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", - }, - "created_at": "2022-04-26T20:00:40.658234Z", - "completed_at": "2022-04-26T20:02:27.648305Z", - "source": "api", - "status": "succeeded", - "input": {"text": "world"}, - "output": ["hello world"], - "error": None, - "logs": "", - }, - ) - - output = version.predict(text="world") - assert isinstance(output, Iterable) - assert list(output) == ["hello world"] - - -@responses.activate -def test_predict_with_iterator_with_failed_prediction(): - version = create_version_with_iterator_output() - responses.post( - "https://api.replicate.com/v1/predictions", - match=[ - matchers.json_params_matcher({"version": "v1", "input": {"text": "world"}}) - ], - json={ - "id": "p1", - "version": "v1", - "urls": { - "get": "https://api.replicate.com/v1/predictions/p1", - "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", - }, - "created_at": "2022-04-26T20:00:40.658234Z", - "completed_at": "2022-04-26T20:02:27.648305Z", - "source": "api", - "status": "processing", - "input": {"text": "world"}, - "output": None, - "error": None, - "logs": "", - }, - ) - responses.get( - "https://api.replicate.com/v1/predictions/p1", - json={ - "id": "p1", - "version": "v1", - "urls": { - "get": "https://api.replicate.com/v1/predictions/p1", - "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", - }, - "created_at": "2022-04-26T20:00:40.658234Z", - "completed_at": "2022-04-26T20:02:27.648305Z", - "source": "api", - "status": "failed", - "input": {"text": "world"}, - "output": None, - "error": "it broke", - "logs": "", - }, - ) - - output = version.predict(text="world") - assert isinstance(output, Iterable) - with pytest.raises(ModelError) as excinfo: - list(output) - assert "it broke" in str(excinfo.value)