1 change: 1 addition & 0 deletions replicate/__init__.py
@@ -5,3 +5,4 @@
models = default_client.models
predictions = default_client.predictions
trainings = default_client.trainings
deployments = default_client.deployments
5 changes: 5 additions & 0 deletions replicate/client.py
@@ -8,6 +8,7 @@
from requests.cookies import RequestsCookieJar

from replicate.__about__ import __version__
from replicate.deployment import DeploymentCollection
from replicate.exceptions import ModelError, ReplicateError
from replicate.model import ModelCollection
from replicate.prediction import PredictionCollection
@@ -113,6 +114,10 @@ def predictions(self) -> PredictionCollection:
def trainings(self) -> TrainingCollection:
return TrainingCollection(client=self)

@property
def deployments(self) -> DeploymentCollection:
return DeploymentCollection(client=self)

def run(self, model_version: str, **kwargs) -> Union[Any, Iterator[Any]]:
"""
Run a model and wait for its output.
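A minimal usage sketch of the new `deployments` property on `Client`; the API token and deployment name below are placeholders, not values from this PR:

from replicate.client import Client

client = Client(api_token="r8_...")  # placeholder token
deployment = client.deployments.get("acme/my-deployment")  # hypothetical owner/deployment-name
print(deployment.username, deployment.name)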
140 changes: 140 additions & 0 deletions replicate/deployment.py
@@ -0,0 +1,140 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

from replicate.base_model import BaseModel
from replicate.collection import Collection
from replicate.files import upload_file
from replicate.json import encode_json
from replicate.prediction import Prediction

if TYPE_CHECKING:
from replicate.client import Client


class Deployment(BaseModel):
"""
A deployment of a model hosted on Replicate.
"""

username: str
"""
The name of the user or organization that owns the deployment.
"""

name: str
"""
The name of the deployment.
"""

@property
def predictions(self) -> "DeploymentPredictionCollection":
"""
Get the predictions for this deployment.
"""

return DeploymentPredictionCollection(client=self._client, deployment=self)


class DeploymentCollection(Collection):
model = Deployment

def list(self) -> List[Deployment]:
raise NotImplementedError()

def get(self, name: str) -> Deployment:
"""
Get a deployment by name.

Args:
name: The name of the deployment, in the format `owner/deployment-name`.
Returns:
The deployment.
"""

# TODO: fetch model from server
# TODO: support permanent IDs
username, name = name.split("/")
return self.prepare_model({"username": username, "name": name})

def create(self, **kwargs) -> Deployment:
raise NotImplementedError()

def prepare_model(self, attrs: Union[Deployment, Dict]) -> Deployment:
if isinstance(attrs, BaseModel):
attrs.id = f"{attrs.username}/{attrs.name}"
elif isinstance(attrs, dict):
attrs["id"] = f"{attrs['username']}/{attrs['name']}"
return super().prepare_model(attrs)


class DeploymentPredictionCollection(Collection):
model = Prediction

def __init__(self, client: "Client", deployment: Deployment) -> None:
super().__init__(client=client)
self._deployment = deployment

def list(self) -> List[Prediction]:
raise NotImplementedError()

def get(self, id: str) -> Prediction:
"""
Get a prediction by ID.

Args:
id: The ID of the prediction.
Returns:
Prediction: The prediction object.
"""

resp = self._client._request("GET", f"/v1/predictions/{id}")
obj = resp.json()
# HACK: resolve this? make it lazy somehow?
del obj["version"]
return self.prepare_model(obj)

def create( # type: ignore
self,
input: Dict[str, Any],
webhook: Optional[str] = None,
webhook_completed: Optional[str] = None,
webhook_events_filter: Optional[List[str]] = None,
*,
stream: Optional[bool] = None,
**kwargs,
) -> Prediction:
"""
Create a new prediction with the deployment.

Args:
input: The input data for the prediction.
webhook: The URL to receive a POST request with prediction updates.
webhook_completed: The URL to receive a POST request when the prediction is completed.
webhook_events_filter: List of events to trigger webhooks.
stream: Set to True to enable streaming of prediction output.

Returns:
Prediction: The created prediction object.
"""

input = encode_json(input, upload_file=upload_file)
body: Dict[str, Any] = {
"input": input,
}
if webhook is not None:
body["webhook"] = webhook
if webhook_completed is not None:
body["webhook_completed"] = webhook_completed
if webhook_events_filter is not None:
body["webhook_events_filter"] = webhook_events_filter
if stream is True:
body["stream"] = "true"

resp = self._client._request(
"POST",
f"/v1/deployments/{self._deployment.username}/{self._deployment.name}/predictions",
json=body,
)
obj = resp.json()
obj["deployment"] = self._deployment
del obj["version"]
return self.prepare_model(obj)
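Taken together, the new collection is reachable either from a `Client` instance or through the module-level `replicate.deployments` attribute added in `__init__.py`. A sketch of the intended flow, assuming `REPLICATE_API_TOKEN` is set in the environment and using an illustrative deployment name and input schema:

import replicate

deployment = replicate.deployments.get("acme/my-deployment")  # illustrative owner/deployment-name
prediction = deployment.predictions.create(
    input={"prompt": "hello"},  # illustrative input; depends on the deployed model's schema
    webhook="https://example.com/webhook",
    webhook_events_filter=["completed"],
)
print(prediction.id, prediction.status)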
47 changes: 47 additions & 0 deletions tests/test_deployment.py
@@ -0,0 +1,47 @@
import responses
from responses import matchers

from replicate.client import Client


@responses.activate
def test_deployment_predictions_create():
client = Client(api_token="abc123")

deployment = client.deployments.get("test/model")

rsp = responses.post(
"https://api.replicate.com/v1/deployments/test/model/predictions",
match=[
matchers.json_params_matcher(
{
"input": {"text": "world"},
"webhook": "https://example.com/webhook",
"webhook_events_filter": ["completed"],
}
),
],
json={
"id": "p1",
"version": "v1",
"urls": {
"get": "https://api.replicate.com/v1/predictions/p1",
"cancel": "https://api.replicate.com/v1/predictions/p1/cancel",
},
"created_at": "2022-04-26T20:00:40.658234Z",
"source": "api",
"status": "processing",
"input": {"text": "hello"},
"output": None,
"error": None,
"logs": "",
},
)

deployment.predictions.create(
input={"text": "world"},
webhook="https://example.com/webhook",
webhook_events_filter=["completed"],
)

assert rsp.call_count == 1