47 changes: 45 additions & 2 deletions onnxscript/function_libs/torch_lib/ops/core.py
@@ -51,6 +51,7 @@
from onnxscript.onnx_opset import opset18 as op
from onnxscript.onnx_types import TensorType

_INT32_MAX = 2147483647
_INT64_MAX = 9223372036854775807
_INT64_MIN = -9223372036854775808
_MATH_PI = math.pi
@@ -9183,15 +9184,57 @@ def aten_unfold_copy(self: TensorType, dimension: int, size: int, step: int) ->
    raise NotImplementedError()


@torch_op("aten::unique_consecutive", trace_only=True)
def aten_unique_consecutive(
    x: TensorType,
    return_inverse: bool = False,
    return_counts: bool = False,
    dim: Optional[int] = None,
) -> tuple[TensorType, TensorType, TensorType]:
"""unique_consecutive(Tensor self, bool return_inverse=False, bool return_counts=False, int? dim=None) -> (Tensor, Tensor, Tensor)"""
    assert x.dtype in {INT64.dtype, INT32.dtype}, (
        "unique_consecutive is only implemented for int32 and int64 inputs"
    )
    rank_x = len(x.shape)

    zero = op.Constant(value=ir.tensor([0], dtype=x.dtype))
    zero64 = op.Constant(value=ir.tensor([0], dtype=INT64.dtype))
    minus_one = op.Constant(value=ir.tensor([-1], dtype=INT64.dtype))

    if dim is None:
        if rank_x != 1:
            x = op.Reshape(x, minus_one)
    else:
        assert rank_x == 1 and dim == 0, (
            f"Not implemented for x={x!r} with rank={rank_x} and dim={dim}."
        )
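    # From here on x is 1-D: either flattened (dim is None) or already rank 1.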

    lag = op.Concat(
        # Hopefully this sentinel never equals the first element of x;
        # a fully robust choice is possible but at a higher cost.
        op.Constant(value=ir.tensor([_INT32_MAX], dtype=x.dtype)),
        op.Slice(x, zero64, minus_one, zero64),
        axis=0,
    )
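    # diff is True at the first element of every run, so Compress keeps exactly
    # one representative per run of equal consecutive values.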
    eq = op.Equal(x, lag)
    diff = op.Not(eq)
    res = op.Compress(x, diff, axis=0)

    zero_no_dim = op.Constant(value=ir.tensor(0, dtype=x.dtype))
    one_no_dim = op.Constant(value=ir.tensor(1, dtype=x.dtype))
    one = op.Constant(value=ir.tensor([1], dtype=x.dtype))

    inverse = op.Sub(op.CumSum(op.Cast(diff, to=x.dtype), zero), one)
    shape_x = op.Shape(x)
    indices = op.Range(zero_no_dim, op.Squeeze(shape_x), one_no_dim)
    points = op.Compress(indices, diff, axis=0)
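    # points holds the start index of each run; shifting it left and appending
    # len(x) gives the run ends, so counts = ends - starts.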
    lagp = op.Concat(
        op.Slice(points, one, op.Shape(points), zero),
        shape_x,
        axis=0,
    )
    counts = op.Sub(lagp, points)
    return res, inverse, counts


@torch_op("aten::_unique", trace_only=True)
45 changes: 45 additions & 0 deletions tests/function_libs/torch_lib/e2e_ops_tests.py
@@ -406,6 +406,51 @@ def forward(self, x):
        onnx_program = torch.onnx.export(model, (x,), dynamo=True, verbose=False)
        _testing.assert_onnx_program(onnx_program)

    def test_aten_unique_consecutive(self):
        class Model(torch.nn.Module):
            def forward(self, x):
                return torch.unique_consecutive(x)

        model = Model()
        x = torch.tensor([0, 1, 2, 2, 3, 3, 0, 0], dtype=torch.int64)
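        # Expected unique_consecutive values: [0, 1, 2, 3, 0].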
        onnx_program = torch.onnx.export(
            model,
            (x,),
            dynamic_shapes=({0: "length"},),
            dynamo=True,
        )
        _testing.assert_onnx_program(onnx_program)

    def test_aten_unique_consecutive_int32(self):
        class Model(torch.nn.Module):
            def forward(self, x):
                return torch.unique_consecutive(x)

        model = Model()
        x = torch.tensor([0, 1, 2, 2, 3, 3, 0, 0], dtype=torch.int32)
        onnx_program = torch.onnx.export(
            model,
            (x,),
            dynamic_shapes=({0: "length"},),
            dynamo=True,
        )
        _testing.assert_onnx_program(onnx_program)

    def test_aten_unique_consecutive_return(self):
        class Model(torch.nn.Module):
            def forward(self, x):
                return torch.unique_consecutive(x, return_inverse=True, return_counts=True)

        model = Model()
        x = torch.tensor([0, 1, 2, 2, 3, 3, 3, 0, 0], dtype=torch.int64)
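        # Expected: values [0, 1, 2, 3, 0], inverse [0, 1, 2, 2, 3, 3, 3, 4, 4],
        # counts [1, 1, 2, 3, 2].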
        onnx_program = torch.onnx.export(
            model,
            (x,),
            dynamic_shapes=({0: "length"},),
            dynamo=True,
        )
        _testing.assert_onnx_program(onnx_program)

    def test_aten_stft_1(self):
        class Model(torch.nn.Module):
            def forward(self, x):