Skip to content

Commit c58d7d7

Browse files
jiqing-fengSunMarc
authored andcommitted
fix pipeline dtype (#40638)
Signed-off-by: jiqing-feng <[email protected]> Co-authored-by: Marc Sun <[email protected]>
1 parent ad6b898 commit c58d7d7

File tree

1 file changed

+22
-22
lines changed

1 file changed

+22
-22
lines changed

src/transformers/pipelines/__init__.py

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -995,29 +995,29 @@ def pipeline(
995995
)
996996
model_kwargs["device_map"] = device_map
997997

998-
# BC for the `torch_dtype` argument
999-
if (torch_dtype := kwargs.get("torch_dtype")) is not None:
998+
# BC for the `torch_dtype` argument
999+
if (torch_dtype := kwargs.get("torch_dtype")) is not None:
1000+
logger.warning_once("`torch_dtype` is deprecated! Use `dtype` instead!")
1001+
# If both are provided, keep `dtype`
1002+
dtype = torch_dtype if dtype == "auto" else dtype
1003+
if "torch_dtype" in model_kwargs or "dtype" in model_kwargs:
1004+
if "torch_dtype" in model_kwargs:
10001005
logger.warning_once("`torch_dtype` is deprecated! Use `dtype` instead!")
1001-
# If both are provided, keep `dtype`
1002-
dtype = torch_dtype if dtype == "auto" else dtype
1003-
if "torch_dtype" in model_kwargs or "dtype" in model_kwargs:
1004-
if "torch_dtype" in model_kwargs:
1005-
logger.warning_once("`torch_dtype` is deprecated! Use `dtype` instead!")
1006-
# If the user did not explicitly provide `dtype` (i.e. the function default "auto" is still
1007-
# present) but a value is supplied inside `model_kwargs`, we silently defer to the latter instead of
1008-
# raising. This prevents false positives like providing `dtype` only via `model_kwargs` while the
1009-
# top-level argument keeps its default value "auto".
1010-
if dtype == "auto":
1011-
dtype = None
1012-
else:
1013-
raise ValueError(
1014-
'You cannot use both `pipeline(... dtype=..., model_kwargs={"dtype":...})` as those'
1015-
" arguments might conflict, use only one.)"
1016-
)
1017-
if dtype is not None:
1018-
if isinstance(dtype, str) and hasattr(torch, dtype):
1019-
dtype = getattr(torch, dtype)
1020-
model_kwargs["dtype"] = dtype
1006+
# If the user did not explicitly provide `dtype` (i.e. the function default "auto" is still
1007+
# present) but a value is supplied inside `model_kwargs`, we silently defer to the latter instead of
1008+
# raising. This prevents false positives like providing `dtype` only via `model_kwargs` while the
1009+
# top-level argument keeps its default value "auto".
1010+
if dtype == "auto":
1011+
dtype = None
1012+
else:
1013+
raise ValueError(
1014+
'You cannot use both `pipeline(... dtype=..., model_kwargs={"dtype":...})` as those'
1015+
" arguments might conflict, use only one.)"
1016+
)
1017+
if dtype is not None:
1018+
if isinstance(dtype, str) and hasattr(torch, dtype):
1019+
dtype = getattr(torch, dtype)
1020+
model_kwargs["dtype"] = dtype
10211021

10221022
model_name = model if isinstance(model, str) else None
10231023

0 commit comments

Comments
 (0)