From 8b79c1bfc5f6e3f30c870f63ceb4d3a8bd9866e8 Mon Sep 17 00:00:00 2001 From: Sunoru Date: Thu, 27 Apr 2023 11:33:22 -0400 Subject: [PATCH] Avoid type piracy related to `__m128i`. (Fix #16) --- src/aesni.jl | 18 +++++++++--------- src/aesni_common.jl | 37 ++++++++++++++++++++----------------- 2 files changed, 29 insertions(+), 26 deletions(-) diff --git a/src/aesni.jl b/src/aesni.jl index 0ecad5d..f4f93fe 100644 --- a/src/aesni.jl +++ b/src/aesni.jl @@ -28,17 +28,17 @@ copy(src::AESNIKey) = copyto!(AESNIKey(), src) """ Assistant function for AES128. Compiled from the C++ source code: ```cpp -R123_STATIC_INLINE __m128i AES_128_ASSIST (__m128i temp1, __m128i temp2) { - __m128i temp3; - temp2 = _mm_shuffle_epi32 (temp2 ,0xff); +R123_STATIC_INLINE __m128i AES_128_ASSIST (__m128i temp1, __m128i temp2) { + __m128i temp3; + temp2 = _mm_shuffle_epi32 (temp2 ,0xff); temp3 = _mm_slli_si128 (temp1, 0x4); temp1 = _mm_xor_si128 (temp1, temp3); temp3 = _mm_slli_si128 (temp3, 0x4); temp1 = _mm_xor_si128 (temp1, temp3); temp3 = _mm_slli_si128 (temp3, 0x4); temp1 = _mm_xor_si128 (temp1, temp3); - temp1 = _mm_xor_si128 (temp1, temp2); - return temp1; + temp1 = _mm_xor_si128 (temp1, temp2); + return temp1; } ``` """ @@ -58,9 +58,9 @@ _aes_128_assist(a::__m128i, b::__m128i) = llvmcall( %15 = xor <2 x i64> %12, %5 %16 = xor <2 x i64> %15, %14 ret <2 x i64> %16""", - __m128i, Tuple{__m128i, __m128i}, - a, b -) + __m128i_lvec, Tuple{__m128i_lvec, __m128i_lvec}, + a.data, b.data +) |> __m128i function _aesni_expand!(k::AESNIKey, rkey::__m128i) k.key1 = rkey @@ -230,7 +230,7 @@ end """ aesni(key::NTuple{11,UInt128}, ctr::Tuple{UInt128})::Tuple{UInt128} -Functional variant of [`AESNI1x`](@ref) and [`AESNI4x`](@ref). +Functional variant of [`AESNI1x`](@ref) and [`AESNI4x`](@ref). This function if free of mutability and side effects. """ @inline function aesni(key::NTuple{11,UInt128}, ctr::Tuple{UInt128})::Tuple{UInt128} diff --git a/src/aesni_common.jl b/src/aesni_common.jl index 4079cdd..fbb440d 100644 --- a/src/aesni_common.jl +++ b/src/aesni_common.jl @@ -4,7 +4,10 @@ import Base.(+) using ..Random123: R123Generator1x, R123Generator4x import ..Random123: random123_r, set_counter! -const __m128i = NTuple{2, VecElement{UInt64}} +const __m128i_lvec = NTuple{2, VecElement{UInt64}} +struct __m128i + data::__m128i_lvec +end Base.convert(::Type{__m128i}, x::UInt128) = unsafe_load(Ptr{__m128i}(pointer_from_objref(Ref(x)))) Base.convert(::Type{UInt128}, x::__m128i) = unsafe_load(Ptr{UInt128}(pointer_from_objref(Ref(x)))) UInt128(x::__m128i) = convert(UInt128, x) @@ -13,42 +16,42 @@ Base.convert(::Type{__m128i}, x::Union{Signed, Unsigned}) = convert(__m128i, UIn Base.convert(::Type{T}, x::__m128i) where T <: Union{Signed, Unsigned} = convert(T, UInt128(x)) const LITTLE_ENDIAN = ENDIAN_BOM ≡ 0x04030201 -__m128i(hi::UInt64, lo::UInt64) = LITTLE_ENDIAN ? (VecElement(lo), VecElement(hi)) : (VecElement(hi), VecElement(lo)) +__m128i(hi::UInt64, lo::UInt64) = LITTLE_ENDIAN ? __m128i((VecElement(lo), VecElement(hi))) : __m128i((VecElement(hi), VecElement(lo))) Base.zero(::Type{__m128i}) = __m128i(zero(UInt64), zero(UInt64)) Base.one(::Type{__m128i}) = __m128i(zero(UInt64), one(UInt64)) Base.xor(a::__m128i, b::__m128i) = llvmcall( """%3 = xor <2 x i64> %1, %0 ret <2 x i64> %3""", - __m128i, Tuple{__m128i, __m128i}, - a, b -) + __m128i_lvec, Tuple{__m128i_lvec, __m128i_lvec}, + a.data, b.data +) |> __m128i (+)(a::__m128i, b::__m128i) = llvmcall( """%3 = add <2 x i64> %1, %0 ret <2 x i64> %3""", - __m128i, Tuple{__m128i, __m128i}, - a, b -) + __m128i_lvec, Tuple{__m128i_lvec, __m128i_lvec}, + a.data, b.data +) |> __m128i (+)(a::__m128i, b::Integer) = a + __m128i(UInt128(b)) _aes_enc(a::__m128i, round_key::__m128i) = ccall( "llvm.x86.aesni.aesenc", llvmcall, - __m128i, (__m128i, __m128i), - a, round_key -) + __m128i_lvec, (__m128i_lvec, __m128i_lvec), + a.data, round_key.data +) |> __m128i _aes_enc_last(a::__m128i, round_key::__m128i) = ccall( "llvm.x86.aesni.aesenclast", llvmcall, - __m128i, (__m128i, __m128i), - a, round_key -) + __m128i_lvec, (__m128i_lvec, __m128i_lvec), + a.data, round_key.data +) |> __m128i _aes_key_gen_assist(a::__m128i, ::Val{R}) where R = ccall( "llvm.x86.aesni.aeskeygenassist", llvmcall, - __m128i, (__m128i, UInt8), - a, R -) + __m128i_lvec, (__m128i_lvec, UInt8), + a.data, R +) |> __m128i "Abstract RNG that generates one number at a time and is based on AESNI." abstract type AbstractAESNI1x <: R123Generator1x{UInt128} end