44#include " core/providers/iree/iree_ep_runtime.h"
55
66#include " core/session/onnxruntime_cxx_api.h"
7+ #include < iostream>
78
89namespace onnxruntime ::iree_ep_rt {
910
@@ -57,10 +58,18 @@ Session::~Session() {
5758}
5859
5960iree_status_t Session::Initialize () {
60- return iree_runtime_session_create_with_device (
61+ iree_status_t res_status = iree_runtime_session_create_with_device (
6162 instance->instance , &session_options, instance->device ,
6263 iree_runtime_instance_host_allocator (instance->instance ),
6364 &session);
65+ iree_vm_module_t * custom_module = NULL ;
66+ iree_allocator_t host_allocator = iree_allocator_system ();
67+ IREE_CHECK_OK (iree_custom_module_async_create (
68+ iree_runtime_instance_vm_instance (instance->instance ), instance->device ,
69+ host_allocator, &custom_module));
70+ IREE_CHECK_OK (iree_runtime_session_append_module (session, custom_module));
71+ iree_vm_module_release (custom_module);
72+ return res_status;
6473}
6574
6675iree_status_t Session::AppendBytecodeModule (fs::path vmfb_path, std::function<void ()> dispose_callback) {
@@ -147,6 +156,13 @@ iree_hal_element_type_t ConvertOrtElementType(ONNXTensorElementDataType et) {
147156common::Status Session::Call (const char * entrypoint_name, const OrtApi* ort_api, OrtKernelContext* ort_context_c) {
148157 // TODO: This is far from the most efficient way to make a call. Synchronous and copying. We can do
149158 // better but this gets points for simplicity and lets us bootstrap the tests.
159+ iree_vm_list_t * inputs = NULL ;
160+ iree_allocator_t host_allocator = iree_allocator_system ();
161+ IREE_CHECK_OK (iree_vm_list_create (iree_vm_make_undefined_type_def (), 1 ,
162+ host_allocator, &inputs));
163+ iree_vm_list_t * outputs = NULL ;
164+ IREE_CHECK_OK (iree_vm_list_create (iree_vm_make_undefined_type_def (), 1 ,
165+ host_allocator, &outputs));
150166 Ort::KernelContext context (ort_context_c);
151167 SynchronousCall call (session);
152168 ORT_RETURN_IF_ERROR (HandleIREEStatus (call.InitializeByName (entrypoint_name)));
@@ -161,8 +177,10 @@ common::Status Session::Call(const char* entrypoint_name, const OrtApi* ort_api,
161177
162178 // Process inputs. We could be smarter about this in a lot of ways, including carrying
163179 // more state from compilation so we are doing less munging here.
164- for (size_t i = 0 ; i < context.GetInputCount (); ++i) {
165- auto input_tensor = context.GetInput (i);
180+
181+ std::cout<<" input count: " <<context.GetInputCount ()<<" \n " ;
182+ // for (size_t i = 0; i < context.GetInputCount(); ++i) {
183+ auto input_tensor = context.GetInput (0 );
166184 ORT_ENFORCE (input_tensor.IsTensor ());
167185
168186 // The device type is rather... sparse... CPU, GPU and FPGA. Not sure how that
@@ -207,13 +225,45 @@ common::Status Session::Call(const char* entrypoint_name, const OrtApi* ort_api,
207225 // Buffer view + storage are returned and owned by the caller:
208226 &arg.bv )));
209227
228+ iree_vm_ref_t input_view_ref = iree_hal_buffer_view_move_ref (arg.bv );
229+ IREE_CHECK_OK (iree_vm_list_push_ref_move (inputs, &input_view_ref));
230+
231+ iree_hal_semaphore_t * semaphore = NULL ;
232+ IREE_CHECK_OK (iree_hal_semaphore_create (
233+ device, 0ull , IREE_HAL_SEMAPHORE_FLAG_NONE, &semaphore));
234+ iree_hal_fence_t * fence_t1 = NULL ;
235+ IREE_CHECK_OK (
236+ iree_hal_fence_create_at (semaphore, 1ull , host_allocator, &fence_t1));
237+ iree_hal_fence_t * fence_t2 = NULL ;
238+ IREE_CHECK_OK (
239+ iree_hal_fence_create_at (semaphore, 2ull , host_allocator, &fence_t2));
240+ iree_hal_semaphore_release (semaphore);
241+ std::cout<<" \n semaphore released" ;
242+ iree_vm_ref_t fence_t1_ref = iree_hal_fence_retain_ref (fence_t1);
243+ std::cout<<" \n semaphore released1" ;
244+ IREE_CHECK_OK (iree_vm_list_push_ref_move (inputs, &fence_t1_ref));
245+ std::cout<<" \n semaphore released2" ;
246+ iree_vm_ref_t fence_t2_ref = iree_hal_fence_retain_ref (fence_t2);
247+ std::cout<<" \n semaphore released3" ;
248+ IREE_CHECK_OK (iree_vm_list_push_ref_move (inputs, &fence_t2_ref));
249+ std::cout<<" \n semaphore released4" ;
250+ IREE_CHECK_OK (iree_hal_fence_signal (fence_t1));
251+ std::cout<<" \n T=1 reached" ;
210252 // Add it to the call.
211- iree_status_t status = iree_runtime_call_inputs_push_back_buffer_view (&call.call , arg.bv );
212- ORT_RETURN_IF_ERROR (HandleIREEStatus (status));
213- }
253+ iree_string_view_t entry_point = iree_make_cstring_view (entrypoint_name);
254+ IREE_CHECK_OK (
255+ iree_runtime_session_call_by_name (session, entry_point, inputs, outputs));
256+ // We could go do other things now while the async work progresses. Here we
257+ // just immediately wait.
258+ IREE_CHECK_OK (iree_hal_fence_wait (fence_t2, iree_infinite_timeout ()));
259+ std::cout<<" \n T=2 reached" ;
260+ // iree_status_t status = iree_runtime_call_inputs_push_back_buffer_view(&call.call, arg.bv);
261+ // ORT_RETURN_IF_ERROR(HandleIREEStatus(status));
262+ // }
263+ // Read back the tensor<?xi32> result:
214264
215265 // Invoke.
216- ORT_RETURN_IF_ERROR (HandleIREEStatus (iree_runtime_call_invoke (&call.call , /* flags=*/ 0 )));
266+ // ORT_RETURN_IF_ERROR(HandleIREEStatus(iree_runtime_call_invoke(&call.call, [> flags=<] 0)));
217267
218268 // Marshal the outputs.
219269 // TODO: Accessing the ORT output requires the shape and then we could get zero copy
@@ -222,16 +272,19 @@ common::Status Session::Call(const char* entrypoint_name, const OrtApi* ort_api,
222272 // convention, which allows passing in slabs of result buffers. Further, that would
223273 // run the host-side computation (which would compute output metadata) inline.
224274 // For static cases, we could also side-load the shape from the compile time.
225- std::vector<int64_t > shape;
226- for (size_t i = 0 ; i < context.GetOutputCount (); ++i) {
275+ // std::vector<int64_t> shape;
276+ std::cout<<" output count: " <<context.GetOutputCount ()<<" \n " ;
277+ // for (size_t i = 0; i < context.GetOutputCount(); ++i) {
227278 HalBufferView ret;
228- ORT_RETURN_IF_ERROR (HandleIREEStatus (
229- iree_runtime_call_outputs_pop_front_buffer_view (&call.call , &ret.bv )));
279+ ret.bv = iree_vm_list_get_buffer_view_assign (outputs, 0 );
280+ // ORT_RETURN_IF_ERROR(HandleIREEStatus(
281+ // iree_runtime_call_outputs_pop_front_buffer_view(&call.call, &ret.bv)));
230282 size_t ret_rank = iree_hal_buffer_view_shape_rank (ret.bv );
231283 const iree_hal_dim_t * ret_dims = iree_hal_buffer_view_shape_dims (ret.bv );
284+ shape.clear ();
232285 shape.resize (ret_rank);
233286 std::copy (ret_dims, ret_dims + ret_rank, shape.begin ());
234- auto output_tensor = context.GetOutput (i , shape.data (), shape.size ());
287+ auto output_tensor = context.GetOutput (0 , shape.data (), shape.size ());
235288 ORT_ENFORCE (output_tensor.IsTensor ());
236289
237290 iree_hal_buffer_t * ret_buffer = iree_hal_buffer_view_buffer (ret.bv );
@@ -250,8 +303,12 @@ common::Status Session::Call(const char* entrypoint_name, const OrtApi* ort_api,
250303 ORT_RETURN_IF_ERROR (HandleIREEStatus (iree_hal_buffer_map_read (ret_buffer, /* source_offset=*/ 0 ,
251304 output_tensor.GetTensorMutableRawData (),
252305 iree_hal_buffer_view_byte_length (ret.bv ))));
253- }
306+ // }
254307
308+ iree_vm_list_release (inputs);
309+ iree_vm_list_release (outputs);
310+ iree_hal_fence_release (fence_t1);
311+ iree_hal_fence_release (fence_t2);
255312 return common::Status::OK ();
256313}
257314
0 commit comments