@@ -995,29 +995,29 @@ def pipeline(
995995 )
996996 model_kwargs ["device_map" ] = device_map
997997
998- # BC for the `torch_dtype` argument
999- if (torch_dtype := kwargs .get ("torch_dtype" )) is not None :
998+ # BC for the `torch_dtype` argument
999+ if (torch_dtype := kwargs .get ("torch_dtype" )) is not None :
1000+ logger .warning_once ("`torch_dtype` is deprecated! Use `dtype` instead!" )
1001+ # If both are provided, keep `dtype`
1002+ dtype = torch_dtype if dtype == "auto" else dtype
1003+ if "torch_dtype" in model_kwargs or "dtype" in model_kwargs :
1004+ if "torch_dtype" in model_kwargs :
10001005 logger .warning_once ("`torch_dtype` is deprecated! Use `dtype` instead!" )
1001- # If both are provided, keep `dtype`
1002- dtype = torch_dtype if dtype == "auto" else dtype
1003- if "torch_dtype" in model_kwargs or "dtype" in model_kwargs :
1004- if "torch_dtype" in model_kwargs :
1005- logger .warning_once ("`torch_dtype` is deprecated! Use `dtype` instead!" )
1006- # If the user did not explicitly provide `dtype` (i.e. the function default "auto" is still
1007- # present) but a value is supplied inside `model_kwargs`, we silently defer to the latter instead of
1008- # raising. This prevents false positives like providing `dtype` only via `model_kwargs` while the
1009- # top-level argument keeps its default value "auto".
1010- if dtype == "auto" :
1011- dtype = None
1012- else :
1013- raise ValueError (
1014- 'You cannot use both `pipeline(... dtype=..., model_kwargs={"dtype":...})` as those'
1015- " arguments might conflict, use only one.)"
1016- )
1017- if dtype is not None :
1018- if isinstance (dtype , str ) and hasattr (torch , dtype ):
1019- dtype = getattr (torch , dtype )
1020- model_kwargs ["dtype" ] = dtype
1006+ # If the user did not explicitly provide `dtype` (i.e. the function default "auto" is still
1007+ # present) but a value is supplied inside `model_kwargs`, we silently defer to the latter instead of
1008+ # raising. This prevents false positives like providing `dtype` only via `model_kwargs` while the
1009+ # top-level argument keeps its default value "auto".
1010+ if dtype == "auto" :
1011+ dtype = None
1012+ else :
1013+ raise ValueError (
1014+ 'You cannot use both `pipeline(... dtype=..., model_kwargs={"dtype":...})` as those'
1015+ " arguments might conflict, use only one.)"
1016+ )
1017+ if dtype is not None :
1018+ if isinstance (dtype , str ) and hasattr (torch , dtype ):
1019+ dtype = getattr (torch , dtype )
1020+ model_kwargs ["dtype" ] = dtype
10211021
10221022 model_name = model if isinstance (model , str ) else None
10231023
0 commit comments