File tree Expand file tree Collapse file tree 2 files changed +64
-0
lines changed
Expand file tree Collapse file tree 2 files changed +64
-0
lines changed Original file line number Diff line number Diff line change @@ -207,6 +207,46 @@ async def async_list(
207207
208208 return Page [Model ](** obj )
209209
210+ def search (self , query : str ) -> Page [Model ]:
211+ """
212+ Search for public models.
213+
214+ Parameters:
215+ query: The search query.
216+ Returns:
217+ Page[Model]: A page of models matching the search query.
218+ """
219+ resp = self ._client ._request (
220+ "QUERY" , "/v1/models" , content = query , headers = {"Content-Type" : "text/plain" }
221+ )
222+
223+ obj = resp .json ()
224+ obj ["results" ] = [
225+ _json_to_model (self ._client , result ) for result in obj ["results" ]
226+ ]
227+
228+ return Page [Model ](** obj )
229+
230+ async def async_search (self , query : str ) -> Page [Model ]:
231+ """
232+ Asynchronously search for public models.
233+
234+ Parameters:
235+ query: The search query.
236+ Returns:
237+ Page[Model]: A page of models matching the search query.
238+ """
239+ resp = await self ._client ._async_request (
240+ "QUERY" , "/v1/models" , content = query , headers = {"Content-Type" : "text/plain" }
241+ )
242+
243+ obj = resp .json ()
244+ obj ["results" ] = [
245+ _json_to_model (self ._client , result ) for result in obj ["results" ]
246+ ]
247+
248+ return Page [Model ](** obj )
249+
210250 @overload
211251 def get (self , key : str ) -> Model : ...
212252
Original file line number Diff line number Diff line change 11import pytest
22
33import replicate
4+ from replicate .model import Model , Page
45
56
67@pytest .mark .vcr ("models-get.yaml" )
@@ -130,3 +131,26 @@ async def test_models_predictions_create(async_flag):
130131 # assert prediction.model == "meta/llama-2-70b-chat"
131132 assert prediction .model == "replicate/lifeboat-70b" # FIXME: this is temporary
132133 assert prediction .status == "starting"
134+
135+
136+ @pytest .mark .vcr ("models-search.yaml" )
137+ @pytest .mark .asyncio
138+ @pytest .mark .parametrize ("async_flag" , [True , False ])
139+ async def test_models_search (async_flag ):
140+ query = "llama"
141+
142+ if async_flag :
143+ page = await replicate .models .async_search (query )
144+ else :
145+ page = replicate .models .search (query )
146+
147+ assert isinstance (page , Page )
148+ assert len (page .results ) > 0
149+
150+ for model in page .results :
151+ assert isinstance (model , Model )
152+ assert model .id is not None
153+ assert model .owner is not None
154+ assert model .name is not None
155+
156+ assert any ("meta" in model .name .lower () for model in page .results )
You can’t perform that action at this time.
0 commit comments