55
66import httpx
77
8+ from cogstack_model_gateway_client .exceptions import retry_if_network_error
9+
810
911class GatewayClient :
1012 def __init__ (
@@ -41,6 +43,33 @@ async def wrapper(self, *args, **kwargs):
4143
4244 return wrapper
4345
46+ @require_client
47+ @retry_if_network_error
48+ async def _request (
49+ self ,
50+ method : str ,
51+ url : str ,
52+ * ,
53+ data = None ,
54+ json = None ,
55+ files = None ,
56+ params = None ,
57+ headers = None ,
58+ ** kwargs ,
59+ ) -> httpx .Response :
60+ """Make HTTP requests with retry logic."""
61+ resp = await self ._client .request (
62+ method = method ,
63+ url = url ,
64+ data = data ,
65+ json = json ,
66+ files = files ,
67+ params = params ,
68+ headers = headers ,
69+ ** kwargs ,
70+ )
71+ return resp .raise_for_status ()
72+
4473 @require_client
4574 async def submit_task (
4675 self ,
@@ -59,11 +88,12 @@ async def submit_task(
5988 if not model_name :
6089 raise ValueError ("Please provide a model name or set a default model for the client." )
6190 url = f"{ self .base_url } /models/{ model_name } /tasks/{ task } "
62- resp = await self ._client .post (
63- url , data = data , json = json , files = files , params = params , headers = headers
91+
92+ resp = await self ._request (
93+ "POST" , url , data = data , json = json , files = files , params = params , headers = headers
6494 )
65- resp .raise_for_status ()
6695 task_info = resp .json ()
96+
6797 if wait_for_completion :
6898 task_uuid = task_info ["uuid" ]
6999 task_info = await self .wait_for_task (task_uuid )
@@ -143,8 +173,7 @@ async def _get_task(self, task_uuid: str, detail: bool = True, download: bool =
143173 """Get a Gateway task."""
144174 url = f"{ self .base_url } /tasks/{ task_uuid } "
145175 params = {"detail" : detail , "download" : download }
146- resp = await self ._client .get (url , params = params )
147- return resp .raise_for_status ()
176+ return await self ._request ("GET" , url , params = params )
148177
149178 @require_client
150179 async def get_task (self , task_uuid : str , detail : bool = True ):
@@ -207,8 +236,7 @@ async def wait_for_task(
207236 async def get_models (self , verbose : bool = False ):
208237 """Get the list of available models from the Gateway."""
209238 url = f"{ self .base_url } /models/"
210- resp = await self ._client .get (url , params = {"verbose" : verbose })
211- resp .raise_for_status ()
239+ resp = await self ._request ("GET" , url , params = {"verbose" : verbose })
212240 return resp .json ()
213241
214242 @require_client
@@ -218,8 +246,7 @@ async def get_model(self, model_name: str = None):
218246 if not model_name :
219247 raise ValueError ("Please provide a model name or set a default model for the client." )
220248 url = f"{ self .base_url } /models/{ model_name } /info"
221- resp = await self ._client .get (url )
222- resp .raise_for_status ()
249+ resp = await self ._request ("GET" , url )
223250 return resp .json ()
224251
225252 @require_client
@@ -236,8 +263,7 @@ async def deploy_model(
236263 raise ValueError ("Please provide a model name or set a default model for the client." )
237264 url = f"{ self .base_url } /models/{ model_name } "
238265 data = {"tracking_id" : tracking_id , "model_uri" : model_uri , "ttl" : ttl }
239- resp = await self ._client .post (url , json = data )
240- resp .raise_for_status ()
266+ resp = await self ._request ("POST" , url , json = data )
241267 return resp .json ()
242268
243269
0 commit comments