diff --git a/src/ATen/native/xpu/Activation.cpp b/src/ATen/native/xpu/Activation.cpp
index d744d5c4b..06e87641f 100644
--- a/src/ATen/native/xpu/Activation.cpp
+++ b/src/ATen/native/xpu/Activation.cpp
@@ -85,7 +85,7 @@ std::tuple<Tensor&, Tensor&> log_sigmoid_forward_out_xpu(
 
 std::tuple<Tensor, Tensor> log_sigmoid_forward_xpu(const Tensor& input) {
   auto result = at::empty_like(input);
-  auto buffer = at::empty({0}, input.options());
+  auto buffer = at::empty_like(input);
   log_sigmoid_forward_out_xpu(input, result, buffer);
   return std::forward_as_tuple(result, buffer);
 }
diff --git a/test/xpu/functorch/test_ops_xpu.py b/test/xpu/functorch/test_ops_xpu.py
index 4966480a5..81551ae9a 100644
--- a/test/xpu/functorch/test_ops_xpu.py
+++ b/test/xpu/functorch/test_ops_xpu.py
@@ -3026,6 +3026,43 @@ def test_ordered_complex_raises(self, device, dtype, op):
                 **sample_input.kwargs,
             )
 
+    def test_vmap_jvp_vjp_logsigmoid(self):
+        primal_in = torch.tensor(12.34, requires_grad=True)
+        cotangent_in = torch.tensor([1.0, 3.0, 2.0, 4.5])
+        primal_tangent_in = torch.tensor(1.0)
+        cotangent_tangent_in = torch.tensor(1.0)
+
+        def push_vjp(primal_in, cotangent_in):
+            _, vjp_fn = vjp(torch.nn.functional.logsigmoid, primal_in)
+            (grad,) = vjp_fn(cotangent_in)
+            return grad
+
+        def jvp_of_vjp(
+            primal_in, cotangent_in, primal_tangent_in, cotangent_tangent_in
+        ):
+            return jvp(
+                push_vjp,
+                (primal_in, cotangent_in),
+                (primal_tangent_in, cotangent_tangent_in),
+            )
+
+        cpu_results = vmap(jvp_of_vjp, in_dims=(None, 0, None, None))(
+            primal_in, cotangent_in, primal_tangent_in, cotangent_tangent_in
+        )
+
+        xpu_results = vmap(jvp_of_vjp, in_dims=(None, 0, None, None))(
+            primal_in.xpu(),
+            cotangent_in.xpu(),
+            primal_tangent_in.xpu(),
+            cotangent_tangent_in.xpu(),
+        )
+
+        primal_cpu, tangent_cpu = cpu_results
+        primal_xpu, tangent_xpu = xpu_results
+
+        self.assertEqual(primal_cpu, primal_xpu.cpu())
+        self.assertEqual(tangent_cpu, tangent_xpu.cpu())
+
 
 instantiate_device_type_tests(
     TestOperators, globals(), only_for=("xpu"), allow_xpu=True
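
For context on the one-line C++ change: `log_sigmoid_forward` returns an `(output, buffer)` pair, and the buffer is saved for `log_sigmoid_backward`. On CPU the buffer has the same shape as the input, while the old XPU path returned a zero-sized buffer (`at::empty({0}, ...)`); when functorch rewrites the jvp-of-vjp composition under `vmap`, the saved buffer is transformed alongside the primal, and a zero-sized buffer does not survive that, which is what the new test exercises. Below is a minimal sketch (not part of the diff) of the operator pair as seen from Python; the XPU lines are omitted, so the snippet runs on any build:

```python
# A minimal sketch of the aten pair behind the fix; runs on CPU with any
# recent PyTorch. Not part of the diff.
import torch

x = torch.randn(4)

# aten::log_sigmoid_forward returns (output, buffer); the buffer holds
# intermediates that aten::log_sigmoid_backward consumes.
out, buffer = torch.ops.aten.log_sigmoid_forward(x)
print(buffer.shape)  # torch.Size([4]) on CPU -- same shape as the input

# The backward explicitly takes the saved buffer, so its shape must stay
# compatible with the (possibly batched) input when functorch transforms
# these calls under vmap/jvp/vjp.
grad_in = torch.ops.aten.log_sigmoid_backward(torch.ones_like(x), x, buffer)
print(grad_in)  # equals sigmoid(-x), the derivative of logsigmoid
```

With the change, the XPU functional variant allocates the buffer with `at::empty_like(input)`, matching the CPU convention, and the composed transforms in `test_vmap_jvp_vjp_logsigmoid` are expected to produce the same primal/tangent pair on both devices.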