|
9 | 9 | import onnx |
10 | 10 |
|
11 | 11 | import onnxscript |
| 12 | +from onnxscript import script |
12 | 13 | from onnxscript.rewriter import pattern |
| 14 | +from onnxscript.values import Opset |
13 | 15 |
|
| 16 | +# Create an opset for the custom domain |
| 17 | +opset = Opset("custom.domain", 1) |
14 | 18 |
|
15 | | -def create_model_with_custom_domain(): |
| 19 | + |
| 20 | +@script(opset) |
| 21 | +def create_model_with_custom_domain(input: onnxscript.FLOAT[2, 2]) -> onnxscript.FLOAT[2, 2]: |
16 | 22 | """Create a model with a Relu operation in a custom domain.""" |
17 | | - import onnx |
18 | | - from onnx import helper, TensorProto |
19 | | - |
20 | | - # Create input |
21 | | - input_tensor = helper.make_tensor_value_info('A', TensorProto.FLOAT, [2, 2]) |
22 | | - |
23 | | - # Create output |
24 | | - output_tensor = helper.make_tensor_value_info('result', TensorProto.FLOAT, [2, 2]) |
25 | | - |
26 | | - # Create Relu node with custom domain |
27 | | - relu_node = helper.make_node( |
28 | | - 'Relu', |
29 | | - inputs=['A'], |
30 | | - outputs=['result'], |
31 | | - domain='custom.domain' # Set the custom domain |
32 | | - ) |
33 | | - |
34 | | - # Create the graph |
35 | | - graph = helper.make_graph( |
36 | | - [relu_node], # nodes |
37 | | - 'custom_domain_model', # name |
38 | | - [input_tensor], # inputs |
39 | | - [output_tensor] # outputs |
40 | | - ) |
41 | | - |
42 | | - # Create the model with opset for custom domain |
43 | | - opset_imports = [ |
44 | | - helper.make_opsetid("", 18), # Standard ONNX opset |
45 | | - helper.make_opsetid("custom.domain", 1) # Custom domain opset |
46 | | - ] |
47 | | - |
48 | | - model = helper.make_model(graph, opset_imports=opset_imports) |
49 | | - return model |
50 | | - |
51 | | - |
52 | | -_model = create_model_with_custom_domain() |
| 23 | + return opset.Relu(input) |
| 24 | + |
| 25 | + |
| 26 | +_model = create_model_with_custom_domain.to_model_proto() |
| 27 | +_model = onnx.shape_inference.infer_shapes(_model) |
53 | 28 | onnx.checker.check_model(_model) |
54 | 29 |
|
55 | 30 |
|
|
0 commit comments