55import copy
66import time
77import types
8- from typing import List , Literal , Optional , Tuple , Union , overload
8+ from typing import Callable , List , Literal , Optional , Tuple , Union , cast , overload
99
1010import httpx
1111
1717 _bedrock_converse_messages_pt ,
1818 _bedrock_tools_pt ,
1919)
20+ from litellm .llms .base_llm .chat .transformation import BaseConfig , BaseLLMException
2021from litellm .types .llms .bedrock import *
2122from litellm .types .llms .openai import (
2223 AllMessageValues ,
4243all_global_regions = global_config .get_all_regions ()
4344
4445
45- class AmazonConverseConfig :
46+ class AmazonConverseConfig ( BaseConfig ) :
4647 """
4748 Reference - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html
4849 #2 - https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html#conversation-inference-supported-models-features
@@ -193,9 +194,9 @@ def _create_json_tool_call_for_response_format(
193194
194195 def map_openai_params (
195196 self ,
196- model : str ,
197197 non_default_params : dict ,
198198 optional_params : dict ,
199+ model : str ,
199200 drop_params : bool ,
200201 messages : Optional [List [AllMessageValues ]] = None ,
201202 ) -> dict :
@@ -254,25 +255,6 @@ def map_openai_params(
254255 if _tool_choice_value is not None :
255256 optional_params ["tool_choice" ] = _tool_choice_value
256257
257- ## VALIDATE REQUEST
258- """
259- Bedrock doesn't support tool calling without `tools=` param specified.
260- """
261- if (
262- "tools" not in non_default_params
263- and messages is not None
264- and has_tool_call_blocks (messages )
265- ):
266- if litellm .modify_params :
267- optional_params ["tools" ] = add_dummy_tool (
268- custom_llm_provider = "bedrock_converse"
269- )
270- else :
271- raise litellm .UnsupportedParamsError (
272- message = "Bedrock doesn't support tool calling without `tools=` param specified. Pass `tools=` param OR set `litellm.modify_params = True` // `litellm_settings::modify_params: True` to add dummy tool to the request." ,
273- model = "" ,
274- llm_provider = "bedrock" ,
275- )
276258 return optional_params
277259
278260 @overload
@@ -352,8 +334,32 @@ def _transform_inference_params(self, inference_params: dict) -> InferenceConfig
352334 return InferenceConfig (** inference_params )
353335
354336 def _transform_request_helper (
355- self , system_content_blocks : List [SystemContentBlock ], optional_params : dict
337+ self ,
338+ system_content_blocks : List [SystemContentBlock ],
339+ optional_params : dict ,
340+ messages : Optional [List [AllMessageValues ]] = None ,
356341 ) -> CommonRequestObject :
342+
343+ ## VALIDATE REQUEST
344+ """
345+ Bedrock doesn't support tool calling without `tools=` param specified.
346+ """
347+ if (
348+ "tools" not in optional_params
349+ and messages is not None
350+ and has_tool_call_blocks (messages )
351+ ):
352+ if litellm .modify_params :
353+ optional_params ["tools" ] = add_dummy_tool (
354+ custom_llm_provider = "bedrock_converse"
355+ )
356+ else :
357+ raise litellm .UnsupportedParamsError (
358+ message = "Bedrock doesn't support tool calling without `tools=` param specified. Pass `tools=` param OR set `litellm.modify_params = True` // `litellm_settings::modify_params: True` to add dummy tool to the request." ,
359+ model = "" ,
360+ llm_provider = "bedrock" ,
361+ )
362+
357363 inference_params = copy .deepcopy (optional_params )
358364 additional_request_keys = []
359365 additional_request_params = {}
@@ -429,14 +435,12 @@ async def _async_transform_request(
429435 ) -> RequestObject :
430436 messages , system_content_blocks = self ._transform_system_message (messages )
431437 ## TRANSFORMATION ##
432- # bedrock_messages: List[MessageBlock] = await asyncify(
433- # _bedrock_converse_messages_pt
434- # )(
435- # messages=messages,
436- # model=model,
437- # llm_provider="bedrock_converse",
438- # user_continue_message=litellm_params.pop("user_continue_message", None),
439- # )
438+
439+ _data : CommonRequestObject = self ._transform_request_helper (
440+ system_content_blocks = system_content_blocks ,
441+ optional_params = optional_params ,
442+ messages = messages ,
443+ )
440444
441445 bedrock_messages = (
442446 await BedrockConverseMessagesProcessor ._bedrock_converse_messages_pt_async (
@@ -447,15 +451,28 @@ async def _async_transform_request(
447451 )
448452 )
449453
450- _data : CommonRequestObject = self ._transform_request_helper (
451- system_content_blocks = system_content_blocks ,
452- optional_params = optional_params ,
453- )
454-
455454 data : RequestObject = {"messages" : bedrock_messages , ** _data }
456455
457456 return data
458457
458+ def transform_request (
459+ self ,
460+ model : str ,
461+ messages : List [AllMessageValues ],
462+ optional_params : dict ,
463+ litellm_params : dict ,
464+ headers : dict ,
465+ ) -> dict :
466+ return cast (
467+ dict ,
468+ self ._transform_request (
469+ model = model ,
470+ messages = messages ,
471+ optional_params = optional_params ,
472+ litellm_params = litellm_params ,
473+ ),
474+ )
475+
459476 def _transform_request (
460477 self ,
461478 model : str ,
@@ -464,6 +481,13 @@ def _transform_request(
464481 litellm_params : dict ,
465482 ) -> RequestObject :
466483 messages , system_content_blocks = self ._transform_system_message (messages )
484+
485+ _data : CommonRequestObject = self ._transform_request_helper (
486+ system_content_blocks = system_content_blocks ,
487+ optional_params = optional_params ,
488+ messages = messages ,
489+ )
490+
467491 ## TRANSFORMATION ##
468492 bedrock_messages : List [MessageBlock ] = _bedrock_converse_messages_pt (
469493 messages = messages ,
@@ -472,15 +496,38 @@ def _transform_request(
472496 user_continue_message = litellm_params .pop ("user_continue_message" , None ),
473497 )
474498
475- _data : CommonRequestObject = self ._transform_request_helper (
476- system_content_blocks = system_content_blocks ,
477- optional_params = optional_params ,
478- )
479-
480499 data : RequestObject = {"messages" : bedrock_messages , ** _data }
481500
482501 return data
483502
503+ def transform_response (
504+ self ,
505+ model : str ,
506+ raw_response : httpx .Response ,
507+ model_response : ModelResponse ,
508+ logging_obj : Logging ,
509+ request_data : dict ,
510+ messages : List [AllMessageValues ],
511+ optional_params : dict ,
512+ litellm_params : dict ,
513+ encoding : Any ,
514+ api_key : Optional [str ] = None ,
515+ json_mode : Optional [bool ] = None ,
516+ ) -> ModelResponse :
517+ return self ._transform_response (
518+ model = model ,
519+ response = raw_response ,
520+ model_response = model_response ,
521+ stream = optional_params .get ("stream" , False ),
522+ logging_obj = logging_obj ,
523+ optional_params = optional_params ,
524+ api_key = api_key ,
525+ data = request_data ,
526+ messages = messages ,
527+ print_verbose = None ,
528+ encoding = encoding ,
529+ )
530+
484531 def _transform_response (
485532 self ,
486533 model : str ,
@@ -489,12 +536,12 @@ def _transform_response(
489536 stream : bool ,
490537 logging_obj : Optional [Logging ],
491538 optional_params : dict ,
492- api_key : str ,
539+ api_key : Optional [ str ] ,
493540 data : Union [dict , str ],
494541 messages : List ,
495- print_verbose ,
542+ print_verbose : Optional [ Callable ] ,
496543 encoding ,
497- ) -> Union [ ModelResponse , CustomStreamWrapper ] :
544+ ) -> ModelResponse :
498545 ## LOGGING
499546 if logging_obj is not None :
500547 logging_obj .post_call (
@@ -503,7 +550,7 @@ def _transform_response(
503550 original_response = response .text ,
504551 additional_args = {"complete_input_dict" : data },
505552 )
506- print_verbose ( f"raw model_response: { response . text } " )
553+
507554 json_mode : Optional [bool ] = optional_params .pop ("json_mode" , None )
508555 ## RESPONSE OBJECT
509556 try :
@@ -652,3 +699,25 @@ def _get_base_model(self, model: str) -> str:
652699 return model .split ("/" , 1 )[1 ]
653700
654701 return model
702+
703+ def get_error_class (
704+ self , error_message : str , status_code : int , headers : Union [dict , httpx .Headers ]
705+ ) -> BaseLLMException :
706+ return BedrockError (
707+ message = error_message ,
708+ status_code = status_code ,
709+ headers = headers ,
710+ )
711+
712+ def validate_environment (
713+ self ,
714+ headers : dict ,
715+ model : str ,
716+ messages : List [AllMessageValues ],
717+ optional_params : dict ,
718+ api_key : Optional [str ] = None ,
719+ api_base : Optional [str ] = None ,
720+ ) -> dict :
721+ if api_key :
722+ headers ["Authorization" ] = f"Bearer { api_key } "
723+ return headers
0 commit comments