Skip to content

Commit fa01d3e

Browse files
committed
Provide safe function wrappers
This tries to address concerns raised in rust-lang#56 in different ways with different tradeoffs: 1. Implicit `Deref` to the underlying function pointer is replace with an explicit unsafe method `into_inner`, but in most cases it shouldn't be necessary to use. 2. Provides function `Symbol::call` which guarantees that (a) backing storage will be alive as long as it's required for calling and (b) preserves `unsafe`ty for the public interface. This is done at the cost of less convenient invocation syntax (`symbol.call(a,b,c,...)` instead of `symbol(a,b,c,...)`). 3. Provides `Symbol::into_closure` which allows to cast a `Symbol` into an opaque Rust closure. This also guarantees that the backing storage will be kept alive as long as it's required for calling, and provides a more convenient syntax (you can use normal function calls like `closure(a,b,c,...)`) at the cost of losing `unsafe` marker in the interface of the resulting function (although conversion itself still has `unsafe`). You might argue that this is unacceptable tradeoff for a public API, but, unfortunately, we don't have `UnsafeFn` family of traits yet, and also there is lots of existing precedents in stdlib that do similar "pretend that it's safe" conversions like `Box::from_raw`, `Rc::from_raw`, `slice::from_raw_parts`, `str::from_utf8_unchecked`, etc.
1 parent e185481 commit fa01d3e

File tree

5 files changed

+48
-19
lines changed

5 files changed

+48
-19
lines changed

examples/jit.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,8 @@ fn run() -> Result<(), Box<Error>> {
5454
let z = 3u64;
5555

5656
unsafe {
57-
println!("{} + {} + {} = {}", x, y, z, sum(x, y, z));
58-
assert_eq!(sum(x, y, z), x + y + z);
57+
println!("{} + {} + {} = {}", x, y, z, sum.call(x, y, z));
58+
assert_eq!(sum.call(x, y, z), x + y + z);
5959
}
6060

6161
Ok(())

examples/kaleidoscope/main.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -1330,7 +1330,7 @@ pub fn main() {
13301330
};
13311331

13321332
unsafe {
1333-
println!("=> {}", compiled_fn());
1333+
println!("=> {}", compiled_fn.call());
13341334
}
13351335
}
13361336
}

src/execution_engine.rs

+35-6
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,7 @@ impl ExecutionEngine {
270270
/// // fetch our JIT'd function and execute it
271271
/// unsafe {
272272
/// let test_fn = ee.get_function::<unsafe extern "C" fn() -> f64>("test_fn").unwrap();
273-
/// let return_value = test_fn();
273+
/// let return_value = test_fn.call();
274274
/// assert_eq!(return_value, 64.0);
275275
/// }
276276
/// ```
@@ -440,11 +440,17 @@ pub struct Symbol<F> {
440440
inner: F,
441441
}
442442

443-
impl<F: UnsafeFunctionPointer> Deref for Symbol<F> {
444-
type Target = F;
445-
446-
fn deref(&self) -> &Self::Target {
447-
&self.inner
443+
impl<F: UnsafeFunctionPointer> Symbol<F> {
444+
/// This method allows to retrieve the internal function pointer.
445+
///
446+
/// This is highly unsafe because as soon as you get it, it's up to you
447+
/// to make sure that the [`ExecutionEngine`] is not dropped earlier than
448+
/// the returned function.
449+
///
450+
/// It's almost always better to use either [`Symbol::call`] or
451+
/// [`Symbol::to_closure`] which will uphold the memory guarantee for you.
452+
pub unsafe fn into_inner(&self) -> F {
453+
self.inner
448454
}
449455
}
450456

@@ -470,7 +476,30 @@ mod private {
470476
macro_rules! impl_unsafe_fn {
471477
($( $param:ident ),*) => {
472478
impl<Output, $( $param ),*> private::Sealed for unsafe extern "C" fn($( $param ),*) -> Output {}
479+
473480
impl<Output, $( $param ),*> UnsafeFunctionPointer for unsafe extern "C" fn($( $param ),*) -> Output {}
481+
482+
impl<Output, $( $param ),*> Symbol<unsafe extern "C" fn($( $param ),*) -> Output> {
483+
/// This method allows to call the underlying function while making
484+
/// sure that the backing storage is not dropped too early and
485+
/// preserves the `unsafe` marker for any calls.
486+
#[allow(non_snake_case)]
487+
#[inline(always)]
488+
pub unsafe fn call(&self, $( $param: $param ),*) -> Output {
489+
(self.inner)($( $param ),*)
490+
}
491+
492+
/// This method allows to cast [`Symbol`] to an opaque Rust closure.
493+
///
494+
/// Unlike [`Symbol::into_inner`], it still automatically keeps the backing
495+
/// storage alive as long as it's needed, however it "forgets"
496+
/// that function is actually `unsafe`.
497+
pub unsafe fn into_closure(self) -> impl Fn($( $param ),*) -> Output {
498+
#[allow(non_snake_case)]
499+
#[inline(always)]
500+
move |$( $param ),*| (self.inner)($( $param ),*)
501+
}
502+
}
474503
};
475504
}
476505

tests/all/test_builder.rs

+9-9
Original file line numberDiff line numberDiff line change
@@ -137,14 +137,14 @@ fn test_null_checked_ptr_ops() {
137137
let execution_engine = module.create_jit_execution_engine(OptimizationLevel::None).unwrap();
138138

139139
unsafe {
140-
let check_null_index1: Symbol<unsafe extern "C" fn(*const i8) -> i8> = execution_engine.get_function("check_null_index1").unwrap();
140+
let check_null_index1 = execution_engine.get_function::<unsafe extern "C" fn(*const i8) -> i8>("check_null_index1").unwrap().into_closure();
141141

142142
let array = &[100i8, 42i8];
143143

144144
assert_eq!(check_null_index1(null()), -1i8);
145145
assert_eq!(check_null_index1(array.as_ptr()), 42i8);
146146

147-
let check_null_index2: Symbol<unsafe extern "C" fn(*const i8) -> i8> = execution_engine.get_function("check_null_index2").unwrap();
147+
let check_null_index2 = execution_engine.get_function::<unsafe extern "C" fn(*const i8) -> i8>("check_null_index2").unwrap().into_closure();
148148

149149
assert_eq!(check_null_index2(null()), -1i8);
150150
assert_eq!(check_null_index2(array.as_ptr()), 42i8);
@@ -216,9 +216,9 @@ fn test_binary_ops() {
216216
unsafe {
217217
type BoolFunc = unsafe extern "C" fn(bool, bool) -> bool;
218218

219-
let and: Symbol<BoolFunc> = execution_engine.get_function("and").unwrap();
220-
let or: Symbol<BoolFunc> = execution_engine.get_function("or").unwrap();
221-
let xor: Symbol<BoolFunc> = execution_engine.get_function("xor").unwrap();
219+
let and = execution_engine.get_function::<BoolFunc>("and").unwrap().into_closure();
220+
let or = execution_engine.get_function::<BoolFunc>("or").unwrap().into_closure();
221+
let xor = execution_engine.get_function::<BoolFunc>("xor").unwrap().into_closure();
222222

223223
assert!(!and(false, false));
224224
assert!(!and(true, false));
@@ -287,7 +287,7 @@ fn test_switch() {
287287
builder.build_return(Some(&double));
288288

289289
unsafe {
290-
let switch: Symbol<unsafe extern "C" fn(u8) -> u8> = execution_engine.get_function("switch").unwrap();
290+
let switch = execution_engine.get_function::<unsafe extern "C" fn(u8) -> u8>("switch").unwrap().into_closure();
291291

292292
assert_eq!(switch(0), 1);
293293
assert_eq!(switch(1), 2);
@@ -357,9 +357,9 @@ fn test_bit_shifts() {
357357
builder.build_return(Some(&shift));
358358

359359
unsafe {
360-
let left_shift: Symbol<unsafe extern "C" fn(u8, u8) -> u8> = execution_engine.get_function("left_shift").unwrap();
361-
let right_shift: Symbol<unsafe extern "C" fn(u8, u8) -> u8> = execution_engine.get_function("right_shift").unwrap();
362-
let right_shift_sign_extend: Symbol<unsafe extern "C" fn(i8, u8) -> i8> = execution_engine.get_function("right_shift_sign_extend").unwrap();
360+
let left_shift = execution_engine.get_function::<unsafe extern "C" fn(u8, u8) -> u8>("left_shift").unwrap().into_closure();
361+
let right_shift = execution_engine.get_function::<unsafe extern "C" fn(u8, u8) -> u8>("right_shift").unwrap().into_closure();
362+
let right_shift_sign_extend = execution_engine.get_function::<unsafe extern "C" fn(i8, u8) -> i8>("right_shift_sign_extend").unwrap().into_closure();
363363

364364
assert_eq!(left_shift(0, 0), 0);
365365
assert_eq!(left_shift(0, 4), 0);

tests/all/test_tari_example.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,6 @@ fn test_tari_example() {
3939
let y = 2u64;
4040
let z = 3u64;
4141

42-
assert_eq!(sum(x, y, z), x + y + z);
42+
assert_eq!(sum.call(x, y, z), x + y + z);
4343
}
4444
}

0 commit comments

Comments
 (0)