@@ -108,7 +108,7 @@ def _jacobian_loss_wrt_inputs(
108108 batch).
109109
110110 Args:
111- loss_fn (torch.nn.Module, Callable, or None ): The loss function. If a library
111+ loss_fn (torch.nn.Module, Callable): The loss function. If a library
112112 defined loss function is provided, it would be expected to be a
113113 torch.nn.Module. If a custom loss is provided, it can be either type,
114114 but must behave as a library loss function would if `reduction='sum'`
@@ -131,24 +131,21 @@ def _jacobian_loss_wrt_inputs(
131131 in the batch represented by `out`. This is a 2D tensor, where the
132132 first dimension is the batch dimension.
133133 """
134- # TODO: allow loss_fn to be Callable
135- if isinstance (loss_fn , Module ) and hasattr (loss_fn , "reduction" ):
136- msg0 = "Please ensure that loss_fn.reduction is set to `sum` or `mean`"
137-
138- assert loss_fn .reduction != "none" , msg0
139- msg1 = (
140- f"loss_fn.reduction ({ loss_fn .reduction } ) does not match"
141- f"reduction type ({ reduction_type } ). Please ensure they are"
142- " matching."
143- )
144- assert loss_fn .reduction == reduction_type , msg1
145-
146134 if reduction_type != "sum" and reduction_type != "mean" :
147135 raise ValueError (
148- f"{ reduction_type } is not a valid value for reduction_type. "
136+ f"` { reduction_type } ` is not a valid value for reduction_type. "
149137 "Must be either 'sum' or 'mean'."
150138 )
151139
140+ # TODO: allow loss_fn to be Callable
141+ if isinstance (loss_fn , Module ) and hasattr (loss_fn , "reduction" ):
142+ msg = (
143+ f"loss_fn.reduction `{ loss_fn .reduction } ` does not match"
144+ f"reduction type `{ reduction_type } `. Please ensure they are"
145+ " matching."
146+ )
147+ assert loss_fn .reduction == reduction_type , msg
148+
152149 if _parse_version (torch .__version__ ) >= (1 , 8 , 0 ):
153150 input_jacobians = torch .autograd .functional .jacobian (
154151 lambda out : loss_fn (out , targets ), out , vectorize = vectorize
0 commit comments