@@ -42,6 +42,21 @@ def ref_constant(
42
42
)
43
43
else :
44
44
raise Exception ("only support splat value now" )
45
+ elif isinstance (value , gc_mlir ._mlir_libs ._mlir .ir .IntegerAttr ):
46
+ return (
47
+ torch .full (size = tuple (), fill_value = value .__int__ (), dtype = torch .int ),
48
+ )
49
+ elif isinstance (value , gc_mlir ._mlir_libs ._mlir .ir .DenseIntElementsAttr ):
50
+ if value .is_splat :
51
+ return (
52
+ torch .full (
53
+ size = tuple (value .type .shape ),
54
+ fill_value = value .get_splat_value ().value ,
55
+ dtype = benchgc .util .get_dtype (str (value .get_splat_value ().type )),
56
+ ),
57
+ )
58
+ else :
59
+ raise Exception ("only support splat value now" )
45
60
else :
46
61
raise Exception ("Not support constant type %s" , type (value ))
47
62
@@ -56,3 +71,34 @@ def ref_addf(
56
71
cache : MLIRCache , op : gc_mlir .ir .OpView , var : Dict [str , torch .Tensor ]
57
72
) -> Tuple [torch .Tensor , ...]:
58
73
return (var [cache .opr [0 ]] + var [cache .opr [1 ]],)
74
+
75
+ def ref_maxf (
76
+ cache : MLIRCache , op : gc_mlir .ir .OpView , var : Dict [str , torch .Tensor ]
77
+ ) -> Tuple [torch .Tensor , ...]:
78
+ return (torch .max (var [cache .opr [0 ]], var [cache .opr [1 ]]),)
79
+
80
+ def ref_minf (
81
+ cache : MLIRCache , op : gc_mlir .ir .OpView , var : Dict [str , torch .Tensor ]
82
+ ) -> Tuple [torch .Tensor , ...]:
83
+ return (torch .min (var [cache .opr [0 ]], var [cache .opr [1 ]]),)
84
+
85
+ def ref_muli (
86
+ cache : MLIRCache , op : gc_mlir .ir .OpView , var : Dict [str , torch .Tensor ]
87
+ ) -> Tuple [torch .Tensor , ...]:
88
+ return (var [cache .opr [0 ]] * var [cache .opr [1 ]],)
89
+
90
+
91
+ def ref_addi (
92
+ cache : MLIRCache , op : gc_mlir .ir .OpView , var : Dict [str , torch .Tensor ]
93
+ ) -> Tuple [torch .Tensor , ...]:
94
+ return (var [cache .opr [0 ]] + var [cache .opr [1 ]],)
95
+
96
+ def ref_maxsi (
97
+ cache : MLIRCache , op : gc_mlir .ir .OpView , var : Dict [str , torch .Tensor ]
98
+ ) -> Tuple [torch .Tensor , ...]:
99
+ return (torch .max (var [cache .opr [0 ]], var [cache .opr [1 ]]),)
100
+
101
+ def ref_minsi (
102
+ cache : MLIRCache , op : gc_mlir .ir .OpView , var : Dict [str , torch .Tensor ]
103
+ ) -> Tuple [torch .Tensor , ...]:
104
+ return (torch .min (var [cache .opr [0 ]], var [cache .opr [1 ]]),)
0 commit comments