Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 29 additions & 13 deletions replicate/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from .model import ModelCollection
from .prediction import PredictionCollection
from .training import TrainingCollection
from .version import Version


class Client:
Expand Down Expand Up @@ -100,26 +101,41 @@ def run(self, model_version: str, **kwargs) -> Union[Any, Iterator[Any]]:
The output of the model
"""
# Split model_version into owner, name, version in format owner/name:version
m = re.match(r"^(?P<model>[^/]+/[^:]+):(?P<version>.+)$", model_version)
if not m:
match = re.match(
r"^(?P<owner>[^/]+)/(?P<name>[^:]+):(?P<version>.+)$", model_version
)
if not match:
raise ReplicateError(
f"Invalid model_version: {model_version}. Expected format: owner/name:version"
)
model = self.models.get(m.group("model"))
version = model.versions.get(m.group("version"))
prediction = self.predictions.create(version=version, **kwargs)
# Return an iterator of the output
schema = version.get_transformed_schema()
output = schema["components"]["schemas"]["Output"]
if (
output.get("type") == "array"
and output.get("x-cog-array-type") == "iterator"
):
return prediction.output_iterator()

owner = match.group("owner")
name = match.group("name")
version_id = match.group("version")

prediction = self.predictions.create(version=version_id, **kwargs)

if owner and name:
# FIXME: There should be a method for fetching a version without first fetching its model
resp = self._request(
"GET", f"/v1/models/{owner}/{name}/versions/{version_id}"
)
version = Version(**resp.json())

# Return an iterator of the output
schema = version.get_transformed_schema()
output = schema["components"]["schemas"]["Output"]
if (
output.get("type") == "array"
and output.get("x-cog-array-type") == "iterator"
):
return prediction.output_iterator()

prediction.wait()

if prediction.status == "failed":
raise ModelError(prediction.error)

return prediction.output


Expand Down
113 changes: 99 additions & 14 deletions replicate/model.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,93 @@
from typing import Dict, List, Union
from typing import Dict, List, Optional, Union

from typing_extensions import deprecated

from replicate.base_model import BaseModel
from replicate.collection import Collection
from replicate.exceptions import ReplicateException
from replicate.version import VersionCollection
from replicate.prediction import Prediction
from replicate.version import Version, VersionCollection


class Model(BaseModel):
"""
A machine learning model hosted on Replicate.
"""

username: str
url: str
"""
The URL of the model.
"""

owner: str
"""
The name of the user or organization that owns the model.
The owner of the model.
"""

name: str
"""
The name of the model.
"""

description: Optional[str]
"""
The description of the model.
"""

visibility: str
"""
The visibility of the model. Can be 'public' or 'private'.
"""

github_url: Optional[str]
"""
The GitHub URL of the model.
"""

paper_url: Optional[str]
"""
The URL of the paper related to the model.
"""

license_url: Optional[str]
"""
The URL of the license for the model.
"""

run_count: int
"""
The number of runs of the model.
"""

cover_image_url: Optional[str]
"""
The URL of the cover image for the model.
"""

default_example: Optional[Prediction]
"""
The default example of the model.
"""

latest_version: Optional[Version]
"""
The latest version of the model.
"""

@property
@deprecated("Use `model.owner` instead.")
def username(self) -> str:
"""
The name of the user or organization that owns the model.
This attribute is deprecated and will be removed in future versions.
"""
return self.owner

@username.setter
@deprecated("Use `model.owner` instead.")
def username(self, value: str) -> None:
self.owner = value

def predict(self, *args, **kwargs) -> None:
"""
DEPRECATED: Use `replicate.run()` instead.
Expand All @@ -43,29 +110,47 @@ class ModelCollection(Collection):
model = Model

def list(self) -> List[Model]:
raise NotImplementedError()
"""
List all public models.

def get(self, name: str) -> Model:
Returns:
A list of models.
"""

resp = self._client._request("GET", "/v1/models")
# TODO: paginate
models = resp.json()["results"]
return [self.prepare_model(obj) for obj in models]

def get(self, key: str) -> Model:
"""
Get a model by name.

Args:
name: The name of the model, in the format `owner/model-name`.
key: The qualified name of the model, in the format `owner/model-name`.
Returns:
The model.
"""

# TODO: fetch model from server
# TODO: support permanent IDs
username, name = name.split("/")
return self.prepare_model({"username": username, "name": name})
resp = self._client._request("GET", f"/v1/models/{key}")
return self.prepare_model(resp.json())

def create(self, **kwargs) -> Model:
raise NotImplementedError()

def prepare_model(self, attrs: Union[Model, Dict]) -> Model:
if isinstance(attrs, BaseModel):
attrs.id = f"{attrs.username}/{attrs.name}"
attrs.id = f"{attrs.owner}/{attrs.name}"
elif isinstance(attrs, dict):
attrs["id"] = f"{attrs['username']}/{attrs['name']}"
return super().prepare_model(attrs)
attrs["id"] = f"{attrs['owner']}/{attrs['name']}"
attrs.get("default_example", {}).pop("version", None)

model = super().prepare_model(attrs)

if model.default_example is not None:
model.default_example._client = self._client

if model.latest_version is not None:
model.latest_version._client = self._client

return model
12 changes: 8 additions & 4 deletions replicate/prediction.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import re
import time
from dataclasses import dataclass
from typing import Any, Dict, Iterator, List, Optional
from typing import Any, Dict, Iterator, List, Optional, Union

from replicate.base_model import BaseModel
from replicate.collection import Collection
Expand Down Expand Up @@ -169,7 +169,7 @@ def get(self, id: str) -> Prediction:

def create( # type: ignore
self,
version: Version,
version: Union[Version, str],
input: Dict[str, Any],
webhook: Optional[str] = None,
webhook_completed: Optional[str] = None,
Expand All @@ -195,7 +195,7 @@ def create( # type: ignore

input = encode_json(input, upload_file=upload_file)
body = {
"version": version.id,
"version": version if isinstance(version, str) else version.id,
"input": input,
}
if webhook is not None:
Expand All @@ -213,5 +213,9 @@ def create( # type: ignore
json=body,
)
obj = resp.json()
obj["version"] = version
if isinstance(version, Version):
obj["version"] = version
else:
del obj["version"]

return self.prepare_model(obj)
4 changes: 2 additions & 2 deletions replicate/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def get(self, id: str) -> Version:
The model version.
"""
resp = self._client._request(
"GET", f"/v1/models/{self._model.username}/{self._model.name}/versions/{id}"
"GET", f"/v1/models/{self._model.owner}/{self._model.name}/versions/{id}"
)
return self.prepare_model(resp.json())

Expand All @@ -102,6 +102,6 @@ def list(self) -> List[Version]:
List[Version]: A list of version objects.
"""
resp = self._client._request(
"GET", f"/v1/models/{self._model.username}/{self._model.name}/versions"
"GET", f"/v1/models/{self._model.owner}/{self._model.name}/versions"
)
return [self.prepare_model(obj) for obj in resp.json()["results"]]
Loading