Skip to content
Closed
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 153 additions & 0 deletions test/onnx/expect/TestOperators.test_baddbmm.expect
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
ir_version: 4
producer_name: "pytorch"
producer_version: "1.2"
graph {
node {
input: "1"
input: "2"
output: "3"
op_type: "MatMul"
}
node {
output: "4"
op_type: "Constant"
attribute {
name: "value"
t {
data_type: 7
raw_data: "\001\000\000\000\000\000\000\000"
}
type: TENSOR
}
}
node {
input: "4"
output: "5"
op_type: "Cast"
attribute {
name: "to"
i: 1
type: INT
}
}
node {
input: "3"
input: "5"
output: "6"
op_type: "Mul"
}
node {
output: "7"
op_type: "Constant"
attribute {
name: "value"
t {
data_type: 7
raw_data: "\001\000\000\000\000\000\000\000"
}
type: TENSOR
}
}
node {
input: "7"
output: "8"
op_type: "Cast"
attribute {
name: "to"
i: 1
type: INT
}
}
node {
input: "0"
input: "8"
output: "9"
op_type: "Mul"
}
node {
input: "6"
input: "9"
output: "10"
op_type: "Add"
}
name: "torch-jit-export"
input {
name: "0"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 10
}
dim {
dim_value: 3
}
dim {
dim_value: 5
}
}
}
}
}
input {
name: "1"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 10
}
dim {
dim_value: 3
}
dim {
dim_value: 4
}
}
}
}
}
input {
name: "2"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 10
}
dim {
dim_value: 4
}
dim {
dim_value: 5
}
}
}
}
}
output {
name: "10"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 10
}
dim {
dim_value: 3
}
dim {
dim_value: 5
}
}
}
}
}
}
opset_import {
version: 9
}
6 changes: 6 additions & 0 deletions test/onnx/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -743,6 +743,12 @@ def test_unique(self):
self.assertONNX(lambda x: torch.unique(x, dim=0, sorted=True, return_inverse=False, return_counts=True), x,
opset_version=11)

def test_baddbmm(self):
x = torch.randn(10, 3, 5)
b1 = torch.randn(10, 3, 4)
b2 = torch.randn(10, 4, 5)
self.assertONNX(lambda x, b1, b2: torch.baddbmm(x, b1, b2), (x, b1, b2))

def test_round(self):
x = torch.tensor([0.9920, -1.0362, -1.5000, 2.5000], requires_grad=True)
self.assertONNX(lambda x: torch.round(x), x, opset_version=11)
Expand Down
9 changes: 9 additions & 0 deletions test/onnx/test_pytorch_onnx_caffe2.py
Original file line number Diff line number Diff line change
Expand Up @@ -2186,6 +2186,15 @@ def forward(self, x):
x = torch.arange(16).view(2, 2, 4).to(torch.float32)
self.run_model_test(MaskedFillModel2(), input=(x, ), train=False, batch_size=BATCH_SIZE)

def test_baddbmm(self):
class MyModule(torch.nn.Module):
def forward(self, input, batch1, batch2):
return torch.baddbmm(input, batch1, batch2, alpha=torch.tensor(5), beta=3.5)
x = torch.randn(10, 3, 5)
batch1 = torch.randn(10, 3, 4)
batch2 = torch.randn(10, 4, 5)
self.run_model_test(MyModule(), input=(x, batch1, batch2), train=False, batch_size=BATCH_SIZE)

@skipIfUnsupportedMinOpsetVersion(9)
def test_gelu(self):
class GeluModel(torch.nn.Module):
Expand Down
22 changes: 22 additions & 0 deletions test/onnx/test_pytorch_onnx_onnxruntime.py
Original file line number Diff line number Diff line change
Expand Up @@ -964,6 +964,28 @@ def forward(self, input):
model = CumSum()
self.run_test(model, x)

def test_baddbmm(self):
class MyModule(torch.nn.Module):
def forward(self, input, batch1, batch2):
return torch.baddbmm(input, batch1, batch2, alpha=torch.tensor(5), beta=3.5)
x = torch.randn(10, 3, 5)
batch1 = torch.randn(10, 3, 4)
batch2 = torch.randn(10, 4, 5)
model = MyModule()
self.run_test(model, (x, batch1, batch2))

def test_baddbmm_dynamic(self):
class MyModule(torch.nn.Module):
def forward(self, input, batch1, batch2, alpha, beta):
return torch.baddbmm(input, batch1, batch2, alpha=alpha, beta=beta)
x = torch.randn(10, 3, 5)
batch1 = torch.randn(10, 3, 4)
batch2 = torch.randn(10, 4, 5)
alpha = torch.tensor(5)
beta = torch.tensor(3.5)
model = MyModule()
self.run_test(model, (x, batch1, batch2, alpha, beta))

def test_log(self):
class Log(torch.nn.Module):
def forward(self, input):
Expand Down
7 changes: 7 additions & 0 deletions torch/onnx/symbolic_opset9.py
Original file line number Diff line number Diff line change
Expand Up @@ -1904,6 +1904,13 @@ def multinomial(g, input, num_samples, replacement=False, generator=None):
dtype_i=sym_help.cast_pytorch_to_onnx['Long'],
sample_size_i=num_samples)

@parse_args('v', 'v', 'v', 't', 't')
def baddbmm(g, self, batch1, batch2, beta, alpha):
dtype = self.type().scalarType()
batch_mul = matmul(g, batch1, batch2)
mul_a = mul(g, batch_mul, g.op("Cast", alpha, to_i=sym_help.cast_pytorch_to_onnx[dtype]))
mul_b = mul(g, self, g.op("Cast", beta, to_i=sym_help.cast_pytorch_to_onnx[dtype]))
return add(g, mul_a, mul_b)

def gelu(g, self):
_sqrt2 = 1.4142135623730951
Expand Down