diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index f463934d..a7366feb 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -28,15 +28,13 @@ jobs: with: python-version: ${{ matrix.python-version }} cache: "pip" - - name: Install dependencies - run: | - python -m pip install -r requirements.txt -r requirements-dev.txt . - yes | python -m mypy --install-types replicate || true - - name: Lint - run: | - python -m mypy replicate - python -m ruff . - python -m ruff format --check . + - name: Setup + run: ./script/setup + - name: Test - run: python -m pytest + run: ./script/test + + - name: Lint + run: ./script/lint + \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 7341cb0b..b6415b88 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,6 +13,7 @@ requires-python = ">=3.8" dependencies = ["packaging", "pydantic>1", "httpx>=0.21.0,<1"] optional-dependencies = { dev = [ "mypy", + "pylint", "pytest", "pytest-asyncio", "pytest-recording", @@ -27,6 +28,10 @@ repository = "https://github.com/replicate/replicate-python" [tool.pytest.ini_options] testpaths = "tests/" +[tool.setuptools] +# See https://github.com/pypa/setuptools/issues/3197#issuecomment-1078770109 +py-modules = [] + [tool.setuptools.package-data] "replicate" = ["py.typed"] @@ -34,6 +39,16 @@ testpaths = "tests/" plugins = "pydantic.mypy" exclude = ["tests/"] +[tool.pylint.main] +disable = [ + "C0301", # Line too long + "C0413", # Import should be placed at the top of the module + "C0114", # Missing module docstring + "R0801", # Similar lines in N files + "W0212", # Access to a protected member + "W0622", # Redefining built-in +] + [tool.ruff] select = [ "E", # pycodestyle error diff --git a/replicate/base_model.py b/replicate/base_model.py index c1dc1498..3a954b2c 100644 --- a/replicate/base_model.py +++ b/replicate/base_model.py @@ -24,6 +24,7 @@ def reload(self) -> None: """ Load this object from the server again. """ - new_model = self._collection.get(self.id) - for k, v in new_model.dict().items(): + + new_model = self._collection.get(self.id) # pylint: disable=no-member + for k, v in new_model.dict().items(): # pylint: disable=invalid-name setattr(self, k, v) diff --git a/replicate/client.py b/replicate/client.py index cef9f46a..8f023e47 100644 --- a/replicate/client.py +++ b/replicate/client.py @@ -84,18 +84,30 @@ def _request(self, method: str, path: str, **kwargs) -> httpx.Response: @property def models(self) -> ModelCollection: + """ + Namespace for operations related to models. + """ return ModelCollection(client=self) @property def predictions(self) -> PredictionCollection: + """ + Namespace for operations related to predictions. + """ return PredictionCollection(client=self) @property def trainings(self) -> TrainingCollection: + """ + Namespace for operations related to trainings. + """ return TrainingCollection(client=self) @property def deployments(self) -> DeploymentCollection: + """ + Namespace for operations related to deployments. + """ return DeploymentCollection(client=self) def run(self, model_version: str, **kwargs) -> Union[Any, Iterator[Any]]: # noqa: ANN401 diff --git a/replicate/collection.py b/replicate/collection.py index 32596f89..799b7b63 100644 --- a/replicate/collection.py +++ b/replicate/collection.py @@ -5,6 +5,7 @@ from replicate.client import Client from replicate.base_model import BaseModel +from replicate.exceptions import ReplicateException Model = TypeVar("Model", bound=BaseModel) @@ -17,20 +18,21 @@ class Collection(abc.ABC, Generic[Model]): def __init__(self, client: "Client") -> None: self._client = client - @abc.abstractproperty - def model(self) -> Model: + @property + @abc.abstractmethod + def model(self) -> Model: # pylint: disable=missing-function-docstring pass @abc.abstractmethod - def list(self) -> List[Model]: + def list(self) -> List[Model]: # pylint: disable=missing-function-docstring pass @abc.abstractmethod - def get(self, key: str) -> Model: + def get(self, key: str) -> Model: # pylint: disable=missing-function-docstring pass @abc.abstractmethod - def create(self, **kwargs) -> Model: + def create(self, **kwargs) -> Model: # pylint: disable=missing-function-docstring pass def prepare_model(self, attrs: Union[Model, Dict]) -> Model: @@ -41,13 +43,12 @@ def prepare_model(self, attrs: Union[Model, Dict]) -> Model: attrs._client = self._client attrs._collection = self return cast(Model, attrs) - elif ( - isinstance(attrs, dict) and self.model is not None and callable(self.model) - ): + + if isinstance(attrs, dict) and self.model is not None and callable(self.model): model = self.model(**attrs) model._client = self._client model._collection = self return model - else: - name = self.model.__name__ if hasattr(self.model, "__name__") else "model" - raise Exception(f"Can't create {name} from {attrs}") + + name = self.model.__name__ if hasattr(self.model, "__name__") else "model" + raise ReplicateException(f"Can't create {name} from {attrs}") diff --git a/replicate/deployment.py b/replicate/deployment.py index 191511fb..1a0766c7 100644 --- a/replicate/deployment.py +++ b/replicate/deployment.py @@ -35,9 +35,19 @@ def predictions(self) -> "DeploymentPredictionCollection": class DeploymentCollection(Collection): + """ + Namespace for operations related to deployments. + """ + model = Deployment def list(self) -> List[Deployment]: + """ + List deployments. + + Raises: + NotImplementedError: This method is not implemented. + """ raise NotImplementedError() def get(self, name: str) -> Deployment: @@ -56,6 +66,12 @@ def get(self, name: str) -> Deployment: return self.prepare_model({"username": username, "name": name}) def create(self, **kwargs) -> Deployment: + """ + Create a deployment. + + Raises: + NotImplementedError: This method is not implemented. + """ raise NotImplementedError() def prepare_model(self, attrs: Union[Deployment, Dict]) -> Deployment: @@ -74,6 +90,12 @@ def __init__(self, client: "Client", deployment: Deployment) -> None: self._deployment = deployment def list(self) -> List[Prediction]: + """ + List predictions in a deployment. + + Raises: + NotImplementedError: This method is not implemented. + """ raise NotImplementedError() def get(self, id: str) -> Prediction: diff --git a/replicate/files.py b/replicate/files.py index 394d589c..e761ed3f 100644 --- a/replicate/files.py +++ b/replicate/files.py @@ -7,36 +7,34 @@ import httpx -def upload_file(fh: io.IOBase, output_file_prefix: Optional[str] = None) -> str: +def upload_file(file: io.IOBase, output_file_prefix: Optional[str] = None) -> str: """ Upload a file to the server. Args: - fh: A file handle to upload. + file: A file handle to upload. output_file_prefix: A string to prepend to the output file name. Returns: str: A URL to the uploaded file. """ # Lifted straight from cog.files - fh.seek(0) + file.seek(0) if output_file_prefix is not None: - name = getattr(fh, "name", "output") + name = getattr(file, "name", "output") url = output_file_prefix + os.path.basename(name) - resp = httpx.put(url, files={"file": fh}, timeout=None) # type: ignore + resp = httpx.put(url, files={"file": file}, timeout=None) # type: ignore resp.raise_for_status() + return url - b = fh.read() - # The file handle is strings, not bytes - if isinstance(b, str): - b = b.encode("utf-8") - encoded_body = base64.b64encode(b) - if getattr(fh, "name", None): - # despite doing a getattr check here, mypy complains that io.IOBase has no attribute name - mime_type = mimetypes.guess_type(fh.name)[0] # type: ignore - else: - mime_type = "application/octet-stream" - s = encoded_body.decode("utf-8") - return f"data:{mime_type};base64,{s}" + body = file.read() + # Ensure the file handle is in bytes + body = body.encode("utf-8") if isinstance(body, str) else body + encoded_body = base64.b64encode(body).decode("utf-8") + # Use getattr to avoid mypy complaints about io.IOBase having no attribute name + mime_type = ( + mimetypes.guess_type(getattr(file, "name", ""))[0] or "application/octet-stream" + ) + return f"data:{mime_type};base64,{encoded_body}" diff --git a/replicate/json.py b/replicate/json.py index 5964b7fc..ae5f32c8 100644 --- a/replicate/json.py +++ b/replicate/json.py @@ -6,11 +6,12 @@ try: import numpy as np # type: ignore - has_numpy = True + HAS_NUMPY = True except ImportError: - has_numpy = False + HAS_NUMPY = False +# pylint: disable=too-many-return-statements def encode_json( obj: Any, # noqa: ANN401 upload_file: Callable[[io.IOBase], str], @@ -25,11 +26,11 @@ def encode_json( if isinstance(obj, (list, set, frozenset, GeneratorType, tuple)): return [encode_json(value, upload_file) for value in obj] if isinstance(obj, Path): - with obj.open("rb") as f: - return upload_file(f) + with obj.open("rb") as file: + return upload_file(file) if isinstance(obj, io.IOBase): return upload_file(obj) - if has_numpy: + if HAS_NUMPY: if isinstance(obj, np.integer): return int(obj) if isinstance(obj, np.floating): diff --git a/replicate/model.py b/replicate/model.py index 9fa2924a..3dcc2427 100644 --- a/replicate/model.py +++ b/replicate/model.py @@ -107,6 +107,10 @@ def versions(self) -> VersionCollection: class ModelCollection(Collection): + """ + Namespace for operations related to models. + """ + model = Model def list(self) -> List[Model]: @@ -136,6 +140,12 @@ def get(self, key: str) -> Model: return self.prepare_model(resp.json()) def create(self, **kwargs) -> Model: + """ + Create a model. + + Raises: + NotImplementedError: This method is not implemented. + """ raise NotImplementedError() def prepare_model(self, attrs: Union[Model, Dict]) -> Model: diff --git a/replicate/prediction.py b/replicate/prediction.py index e0944e60..f8afa5e8 100644 --- a/replicate/prediction.py +++ b/replicate/prediction.py @@ -103,7 +103,7 @@ def wait(self) -> None: Wait for prediction to finish. """ while self.status not in ["succeeded", "failed", "canceled"]: - time.sleep(self._client.poll_interval) + time.sleep(self._client.poll_interval) # pylint: disable=no-member self.reload() def output_iterator(self) -> Iterator[Any]: @@ -114,7 +114,7 @@ def output_iterator(self) -> Iterator[Any]: new_output = output[len(previous_output) :] yield from new_output previous_output = output - time.sleep(self._client.poll_interval) + time.sleep(self._client.poll_interval) # pylint: disable=no-member self.reload() if self.status == "failed": @@ -129,10 +129,14 @@ def cancel(self) -> None: """ Cancels a running prediction. """ - self._client._request("POST", f"/v1/predictions/{self.id}/cancel") + self._client._request("POST", f"/v1/predictions/{self.id}/cancel") # pylint: disable=no-member class PredictionCollection(Collection): + """ + Namespace for operations related to predictions. + """ + model = Prediction def list(self) -> List[Prediction]: diff --git a/replicate/schema.py b/replicate/schema.py index 0fb784a3..06f9f058 100644 --- a/replicate/schema.py +++ b/replicate/schema.py @@ -15,12 +15,12 @@ def version_has_no_array_type(cog_version: str) -> Optional[bool]: def make_schema_backwards_compatible( schema: dict, - version: str, + cog_version: str, ) -> dict: """A place to add backwards compatibility logic for our openapi schema""" # If the top-level output is an array, assume it is an iterator in old versions which didn't have an array type - if version_has_no_array_type(version): + if version_has_no_array_type(cog_version): output = schema["components"]["schemas"]["Output"] if output.get("type") == "array": output["x-cog-array-type"] = "iterator" diff --git a/replicate/training.py b/replicate/training.py index d93b56ab..4499a79e 100644 --- a/replicate/training.py +++ b/replicate/training.py @@ -58,10 +58,14 @@ class Training(BaseModel): def cancel(self) -> None: """Cancel a running training""" - self._client._request("POST", f"/v1/trainings/{self.id}/cancel") + self._client._request("POST", f"/v1/trainings/{self.id}/cancel") # pylint: disable=no-member class TrainingCollection(Collection): + """ + Namespace for operations related to trainings. + """ + model = Training def list(self) -> List[Training]: diff --git a/replicate/version.py b/replicate/version.py index dbba4423..c3be8b2e 100644 --- a/replicate/version.py +++ b/replicate/version.py @@ -48,7 +48,7 @@ def predict(self, **kwargs) -> Union[Any, Iterator[Any]]: # noqa: ANN401 stacklevel=1, ) - prediction = self._client.predictions.create(version=self, input=kwargs) + prediction = self._client.predictions.create(version=self, input=kwargs) # pylint: disable=no-member # Return an iterator of the output schema = self.get_transformed_schema() output = schema["components"]["schemas"]["Output"] @@ -70,13 +70,16 @@ def get_transformed_schema(self) -> dict: class VersionCollection(Collection): + """ + Namespace for operations related to model versions. + """ + model = Version def __init__(self, client: "Client", model: "Model") -> None: super().__init__(client=client) self._model = model - # doesn't exist yet def get(self, id: str) -> Version: """ Get a specific model version. @@ -92,6 +95,12 @@ def get(self, id: str) -> Version: return self.prepare_model(resp.json()) def create(self, **kwargs) -> Version: + """ + Create a model version. + + Raises: + NotImplementedError: This method is not implemented. + """ raise NotImplementedError() def list(self) -> List[Version]: diff --git a/requirements-dev.txt b/requirements-dev.txt index 8a4a1539..0759b688 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -8,10 +8,14 @@ annotated-types==0.5.0 # via pydantic anyio==3.7.1 # via httpcore +astroid==3.0.1 + # via pylint certifi==2023.7.22 # via # httpcore # httpx +dill==0.3.7 + # via pylint h11==0.14.0 # via httpcore httpcore==0.17.3 @@ -27,6 +31,10 @@ idna==3.4 # yarl iniconfig==2.0.0 # via pytest +isort==5.12.0 + # via pylint +mccabe==0.7.0 + # via pylint multidict==6.0.4 # via yarl mypy==1.4.1 @@ -37,12 +45,16 @@ packaging==23.1 # via # pytest # replicate (pyproject.toml) +platformdirs==3.11.0 + # via pylint pluggy==1.2.0 # via pytest pydantic==2.0.3 # via replicate (pyproject.toml) pydantic-core==2.3.0 # via pydantic +pylint==3.0.2 + # via replicate (pyproject.toml) pytest==7.4.0 # via # pytest-asyncio @@ -63,6 +75,8 @@ sniffio==1.3.0 # anyio # httpcore # httpx +tomlkit==0.12.1 + # via pylint typing-extensions==4.7.1 # via # mypy diff --git a/script/format b/script/format new file mode 100755 index 00000000..2cc38a42 --- /dev/null +++ b/script/format @@ -0,0 +1,5 @@ +#!/bin/bash + +set -e + +python -m ruff format . diff --git a/script/lint b/script/lint new file mode 100755 index 00000000..7fe2d232 --- /dev/null +++ b/script/lint @@ -0,0 +1,23 @@ +#!/bin/bash + +set -e + +STATUS=0 + +echo "Running mypy" +python -m mypy replicate || STATUS=$? +echo "" + +echo "Running pylint" +python -m pylint --exit-zero replicate || STATUS=$? +echo "" + +echo "Running ruff check" +python -m ruff . || STATUS=$? +echo "" + +echo "Running ruff format check" +python -m ruff format --check . || STATUS=$? +echo "" + +exit $STATUS diff --git a/script/setup b/script/setup new file mode 100755 index 00000000..7e4313e5 --- /dev/null +++ b/script/setup @@ -0,0 +1,7 @@ +#!/bin/bash + +set -e + +python -m pip install -r requirements.txt -r requirements-dev.txt . + +yes | python -m mypy --install-types replicate || true diff --git a/script/test b/script/test new file mode 100755 index 00000000..df1e5d9d --- /dev/null +++ b/script/test @@ -0,0 +1,5 @@ +#!/bin/bash + +set -e + +python -m pytest -v