diff --git a/backends/cadence/aot/functions_fusion_g3.yaml b/backends/cadence/aot/functions_fusion_g3.yaml index dbe44772a2f..0feb1e47891 100644 --- a/backends/cadence/aot/functions_fusion_g3.yaml +++ b/backends/cadence/aot/functions_fusion_g3.yaml @@ -51,7 +51,7 @@ - op: clamp.Tensor_out kernels: - arg_meta: null - kernel_name: cadence::impl::G3::clamp_tensor_out + kernel_name: cadence::impl::G3::clamp_Tensor_out - op: clone.out kernels: @@ -81,12 +81,12 @@ - op: lt.Scalar_out kernels: - arg_meta: null - kernel_name: cadence::impl::G3::lt_scalar_out + kernel_name: cadence::impl::G3::lt_Scalar_out - op: lt.Tensor_out kernels: - arg_meta: null - kernel_name: cadence::impl::G3::lt_tensor_out + kernel_name: cadence::impl::G3::lt_Tensor_out - op: mul.out kernels: @@ -155,7 +155,7 @@ - op: where.self_out kernels: - arg_meta: null - kernel_name: cadence::impl::G3::where_out + kernel_name: cadence::impl::G3::where_self_out - op: native_layer_norm.out kernels: diff --git a/backends/cadence/fusion_g3/operators/op_clamp.cpp b/backends/cadence/fusion_g3/operators/op_clamp.cpp index 52754231176..fa8424e15eb 100644 --- a/backends/cadence/fusion_g3/operators/op_clamp.cpp +++ b/backends/cadence/fusion_g3/operators/op_clamp.cpp @@ -330,7 +330,7 @@ Tensor& clamp_out( return out; } -Tensor& clamp_tensor_out( +Tensor& clamp_Tensor_out( KernelRuntimeContext& ctx, const Tensor& in, const optional& min_opt, diff --git a/backends/cadence/fusion_g3/operators/op_lt.cpp b/backends/cadence/fusion_g3/operators/op_lt.cpp index 5b986946e51..3f6cdbe3505 100644 --- a/backends/cadence/fusion_g3/operators/op_lt.cpp +++ b/backends/cadence/fusion_g3/operators/op_lt.cpp @@ -24,7 +24,7 @@ namespace impl { namespace G3 { namespace native { -Tensor& lt_tensor_out( +Tensor& lt_Tensor_out( KernelRuntimeContext& ctx, const Tensor& a, const Tensor& b, @@ -141,7 +141,7 @@ Tensor& lt_tensor_out( return out; } -Tensor& lt_scalar_out( +Tensor& lt_Scalar_out( KernelRuntimeContext& ctx, const Tensor& a, const Scalar& b, diff --git a/backends/cadence/fusion_g3/operators/op_permute_copy.cpp b/backends/cadence/fusion_g3/operators/op_permute_copy.cpp index 76b9130c6a7..204882f3da9 100644 --- a/backends/cadence/fusion_g3/operators/op_permute_copy.cpp +++ b/backends/cadence/fusion_g3/operators/op_permute_copy.cpp @@ -157,4 +157,4 @@ Tensor& permute_copy_out( } // namespace native } // namespace G3 } // namespace impl -} // namespace cadence \ No newline at end of file +} // namespace cadence diff --git a/backends/cadence/fusion_g3/operators/op_where.cpp b/backends/cadence/fusion_g3/operators/op_where.cpp index e1047b8cd58..54966c4574b 100644 --- a/backends/cadence/fusion_g3/operators/op_where.cpp +++ b/backends/cadence/fusion_g3/operators/op_where.cpp @@ -24,7 +24,7 @@ namespace impl { namespace G3 { namespace native { -Tensor& where_out( +Tensor& where_self_out( KernelRuntimeContext& ctx, const Tensor& cond, const Tensor& a, diff --git a/backends/cadence/fusion_g3/operators/targets.bzl b/backends/cadence/fusion_g3/operators/targets.bzl index e1e7c9a8491..fffeee0d7b3 100644 --- a/backends/cadence/fusion_g3/operators/targets.bzl +++ b/backends/cadence/fusion_g3/operators/targets.bzl @@ -35,6 +35,14 @@ def define_operator(name: str, deps: list[str] | None = None) -> None: OPERATORS = [ "add", "cat", + "clamp", + "lt", + "rsqrt", + "sigmoid", + "sqrt", + "tanh", + "transpose_copy", + "where", "dequantize", "mul", "native_layer_norm", diff --git a/backends/cadence/utils/facto_util.py b/backends/cadence/utils/facto_util.py index 74f846aa706..099ee15eb2f 100644 --- a/backends/cadence/utils/facto_util.py +++ b/backends/cadence/utils/facto_util.py @@ -22,9 +22,29 @@ def apply_tensor_contraints(op_name: str, tensor_constraints: list[object]) -> None: + additional_tensor_constraints = [ + cp.Dtype.In(lambda deps: [torch.int, torch.float]), + cp.Dtype.NotIn(lambda deps: [torch.int64, torch.float64]), + cp.Value.Ge(lambda deps, dtype, struct: -(2**4)), + cp.Value.Le(lambda deps, dtype, struct: 2**4), + cp.Rank.Ge(lambda deps: 1), + cp.Size.Ge(lambda deps, r, d: 1), + cp.Size.Le(lambda deps, r, d: 2**9), + ] + match op_name: + case "where.self": + additional_tensor_constraints = [ + cp.Dtype.In(lambda deps: [torch.float, torch.int, torch.bool]), + cp.Dtype.NotIn(lambda deps: [torch.int64, torch.float64]), + cp.Value.Ge(lambda deps, dtype, struct: -(2**4)), + cp.Value.Le(lambda deps, dtype, struct: 2**4), + cp.Rank.Ge(lambda deps: 1), + cp.Size.Ge(lambda deps, r, d: 1), + cp.Size.Le(lambda deps, r, d: 2**9), + ] case "sigmoid.default" | "rsqrt.default": - tensor_constraints.extend( + additional_tensor_constraints.extend( [ cp.Dtype.In(lambda deps: [torch.float]), cp.Rank.Le(lambda deps: 2**2), @@ -33,14 +53,14 @@ def apply_tensor_contraints(op_name: str, tensor_constraints: list[object]) -> N ] ) case "mean.dim": - tensor_constraints.extend( + additional_tensor_constraints.extend( [ cp.Dtype.In(lambda deps: [torch.float]), cp.Rank.Le(lambda deps: 2**2), ] ) case "exp.default": - tensor_constraints.extend( + additional_tensor_constraints.extend( [ cp.Rank.Le(lambda deps: 2**3), cp.Value.Ge(lambda deps, dtype, struct: -(2**2)), @@ -48,7 +68,7 @@ def apply_tensor_contraints(op_name: str, tensor_constraints: list[object]) -> N ] ) case "slice_copy.Tensor": - tensor_constraints.extend( + additional_tensor_constraints.extend( [ cp.Rank.Le(lambda deps: 2), cp.Value.Ge(lambda deps, dtype, struct: 1), @@ -56,22 +76,12 @@ def apply_tensor_contraints(op_name: str, tensor_constraints: list[object]) -> N ] ) case _: - tensor_constraints.extend( + additional_tensor_constraints.extend( [ cp.Rank.Le(lambda deps: 2**2), ] ) - tensor_constraints.extend( - [ - cp.Dtype.In(lambda deps: [torch.int, torch.float]), - cp.Dtype.NotIn(lambda deps: [torch.int64, torch.float64]), - cp.Value.Ge(lambda deps, dtype, struct: -(2**4)), - cp.Value.Le(lambda deps, dtype, struct: 2**4), - cp.Rank.Ge(lambda deps: 1), - cp.Size.Ge(lambda deps, r, d: 1), - cp.Size.Le(lambda deps, r, d: 2**9), - ] - ) + tensor_constraints.extend(additional_tensor_constraints) def apply_scalar_contraints(op_name: str) -> list[ScalarDtype]: