
Pure Julia implementation of OpenLibm's lgamma, lgamma_r #413


Merged
merged 25 commits into JuliaMath:master on Dec 20, 2022

Conversation

andrewjradcliffe
Contributor

Why bother with this port? The short answer is that for Float64, there are considerable performance gains to be had. (See https://github.com/andrewjradcliffe/OpenLibmPorts.jl for plots produced on my machine -- I encourage you to reproduce them for yourself).

The long answer can be found in the linked package. To summarize for those unfamiliar with large probabilistic models, log-posterior computations often involve many unavoidable (i.e. they cannot be eliminated by using the unnormalized density) special function calls; in particular, log-likelihoods involving the (log)gamma function are not uncommon, e.g. negative binomial, beta-binomial. Given such a likelihood function, each term in the log-likelihood sum would necessitate a loggamma call; hence, any reduction in the latency of this particular special function can have a substantial impact.
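For concreteness, a minimal sketch of the kind of computation meant here; the negative binomial parametrization and the toy data are illustrative, not taken from the package:

# Negative binomial log-likelihood: three loggamma calls per observation,
# none of which can be eliminated by working with the unnormalized density.
using SpecialFunctions

function nbinom_loglik(ks, r, p)
    s = 0.0
    for k in ks
        kf = float(k)
        s += loggamma(kf + r) - loggamma(kf + 1.0) - loggamma(r) +
             r * log(p) + kf * log1p(-p)
    end
    return s
end

nbinom_loglik([3, 0, 7, 2], 2.5, 0.4)  # cost is dominated by the loggamma calls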

It is also a convenience to have a loggamma function which is differentiable using Enzyme (though there are some gaps in the derivative -- at exactly 0.0, 1.0 and 2.0 -- but that is just a quirk of applying AD to OpenLibm's loggamma implementation).
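As a rough illustration (a sketch assuming Enzyme's autodiff API; the derivative of loggamma is the digamma function, so the result can be checked against digamma):

using Enzyme, SpecialFunctions

f(x) = logabsgamma(x)[1]
# reverse-mode derivative at 3.0; should match digamma(3.0) ≈ 0.9227843350984671
autodiff(Reverse, f, Active, Active(3.0))[1][1]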

I realize that many things depend on loggamma, and my intention is not to cause headaches for others. Fortunately, by directly porting the OpenLibm implementation to Julia, we achieve the same approximation, albeit with slightly different rounding due, presumably, to the difference between the implementations of log in Julia and OpenLibm.

@codecov

codecov bot commented Oct 28, 2022

Codecov Report

Base: 93.63% // Head: 93.93% // Increases project coverage by +0.29% 🎉

Coverage data is based on head (83071be) compared to base (36c547b).
Patch coverage: 100.00% of modified lines in pull request are covered.

Additional details and impacted files
@@            Coverage Diff             @@
##           master     #413      +/-   ##
==========================================
+ Coverage   93.63%   93.93%   +0.29%     
==========================================
  Files          12       14       +2     
  Lines        2767     2903     +136     
==========================================
+ Hits         2591     2727     +136     
  Misses        176      176              
Flag Coverage Δ
unittests 93.93% <100.00%> (+0.29%) ⬆️


Impacted Files Coverage Δ
src/SpecialFunctions.jl 100.00% <ø> (ø)
src/gamma.jl 95.04% <ø> (-0.11%) ⬇️
src/logabsgamma/e_lgamma_r.jl 100.00% <100.00%> (ø)
src/logabsgamma/e_lgammaf_r.jl 100.00% <100.00%> (ø)



@andrewjradcliffe
Contributor Author

To complete this port, I will need to replicate the tests from OpenLibm. Specifically, the relevant parts of

Addition of the tests will ensure the same consistency which OpenLibm has delivered; they will also address the code coverage issue.

@heltonmc
Member

heltonmc commented Nov 2, 2022

Thank you for making this port! This is a pretty important function, so it would be nice to have it in Julia. To get the best reviews, it will be important to have the ULP error characterized over the whole range. The most important section for loggamma is x < 2.0. The asymptotic expansions are straightforward to get to around 1 ULP, and for medium arguments the reduction down to the polynomial approximation is also numerically stable. The issue is that shifting the range forward is actually slightly unstable, which can increase the ULP beyond 1.5. It does look like this implementation takes care of that, but having the ULP characterized will be important. With that said, I'll give some thoughts on moving this forward, which mainly stem from the code being more C-style.

Major:

  1. All the polynomial evaluations should use the evalpoly function (see any function in https://github.com/JuliaMath/SpecialFunctions.jl/blob/master/src/gamma.jl). This has two main advantages: it removes all the global constants, and I've at least verified that the rational functions take advantage of SIMD instructions on different architectures. We really shouldn't have global constants here if it can be avoided (see the sketch after this list).
  2. Remove all semicolons (;) from the ends of lines.
  3. Is it really worth having two separate implementations for the sign? The code seems to be identical, and I'm not sure whether there is any performance hit to just having loggamma call the underlying function that also computes the sign. Along the same lines, the Float64 and Float32 code are also similar with different coefficients... This seems like a lot of code duplication that I wonder if we can use dispatch to help out with. The function naming should also be changed: the main functions should all be _loggamma; there is no reason to have different names for different precisions.
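A minimal sketch of point 1; the coefficients shown are just the leading terms of one of lgamma's minimax polynomials, for shape only:

# evalpoly with a literal tuple unrolls at compile time into a Horner chain
# of muladd/FMA instructions, so no global constant arrays are required:
p1(z) = evalpoly(z, (7.72156649015328655494e-2,
                     6.73523010531292681824e-2,
                     7.38555086081402883957e-3))
# equivalent to muladd(z, muladd(z, 7.3855...e-3, 6.7352...e-2), 7.7215...e-2)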

Personal preferences:

  1. There are a lot of branches here, so the microbenchmarks need to be set up with care. I think having a benchmark that makes sure the penalties of all these branches are observed will be helpful. For example, the x < 8 range uses a lot of branches and goto/label statements, which I think could be improved. Having this just be a while loop would get rid of all those branches and, I believe, would be faster.

@heltonmc
Member

heltonmc commented Nov 2, 2022

I implemented some of my suggestions which look like....

Code

function _loggamma(x::Float64)
    u = reinterpret(UInt64, x)
    hx = (u >>> 32) % Int32
    lx = u % Int32

    #= purge off +-inf, NaN, +-0, tiny and negative arguments =#
    signgamp = Int32(1)
    ix = signed(hx & 0x7fffffff)
    ix ≥ 0x7ff00000 && return x * x, signgamp
    ix | lx == 0 && return 1.0 / 0.0, signgamp
    if ix < 0x3b900000 #= |x|<2**-70, return -log(|x|) =#
        if hx < 0
            signgamp = Int32(-1)
            return -log(-x), signgamp
        else
            return -log(x), signgamp
        end
    end
    if hx < 0
        ix ≥ 0x43300000 && return 1.0 / 0.0, signgamp #= |x|>=2**52, must be -integer =#
        t = sinpi(x)
        t == 0.0 && return 1.0 / 0.0, signgamp #= -integer =#
        nadj = log(π / abs(t * x))
        if t < 0.0; signgamp = Int32(-1); end
        x = -x
    end
    
    if ix <= 0x40000000     #= for x < 2.0 =#
        fpart, ipart = modf(x)
        if iszero(fpart)
            return 0.0, signgamp
        elseif isone(ipart)
            r = 0.0
            _y = 1.0
        else
            r = -log(x)
            _y = 0.0
        end
        
        if fpart ≥ 0.7316
            y = (1.0 + _y) - x
            z = y*y
            p1 = evalpoly(z, (7.72156649015328655494e-02, 6.73523010531292681824e-02, 7.38555086081402883957e-03, 1.19270763183362067845e-03, 2.20862790713908385557e-04, 2.52144565451257326939e-05))
            p2 = z * evalpoly(z, (3.22467033424113591611e-01, 2.05808084325167332806e-02, 2.89051383673415629091e-03, 5.10069792153511336608e-04, 1.08011567247583939954e-04, 4.48640949618915160150e-05))
            p  = muladd(p1, y, p2)
            r  += muladd(y, -0.5, p)
        elseif fpart ≥ 0.23163999999
            y = x - (1.46163214496836224576 - (1.0 - _y))
            z = y*y
            w = z*y
            p1 = evalpoly(w, (4.83836122723810047042e-01, -3.27885410759859649565e-02, 6.10053870246291332635e-03, -1.40346469989232843813e-03, 3.15632070903625950361e-04))
            p2 = evalpoly(w, (-1.47587722994593911752e-01, 1.79706750811820387126e-02, -3.68452016781138256760e-03, 8.81081882437654011382e-04, -3.12754168375120860518e-04))
            p3 = evalpoly(w, (6.46249402391333854778e-02, -1.03142241298341437450e-02, 2.25964780900612472250e-03, -5.38595305356740546715e-04, 3.35529192635519073543e-04))
            p = muladd(w, -muladd(p3, y, p2), -3.63867699703950536541e-18)
            p = muladd(z, p1, -p)
            r += p - 1.21486290535849611461e-1
        else
            y = x - _y
            p1 = y * evalpoly(y, (-7.72156649015328655494e-02, 6.32827064025093366517e-01, 1.45492250137234768737, 9.77717527963372745603e-01, 2.28963728064692451092e-01, 1.33810918536787660377e-02))
            p2 = evalpoly(y, (1.0, 2.45597793713041134822, 2.12848976379893395361, 7.69285150456672783825e-01, 1.04222645593369134254e-01, 3.21709242282423911810e-03))
            r += muladd(y, -0.5, p1 / p2)
        end
    elseif ix < 0x40200000              #= x < 8.0 =#
        i = Base.unsafe_trunc(Int8, x)
        y = x - float(i)
        z = 1.0
        p = 0.0
        u = x
        while u >= 3.0
            p -= 1.0
            u = x + p
            z *= u
        end
        p = y * evalpoly(y, (-7.72156649015328655494e-2, 2.14982415960608852501e-1, 3.25778796408930981787e-1, 1.46350472652464452805e-1, 2.66422703033638609560e-2, 1.84028451407337715652e-3, 3.19475326584100867617e-5))
        q = evalpoly(y, (1.0, 1.39200533467621045958, 7.21935547567138069525e-1, 1.71933865632803078993e-1, 1.86459191715652901344e-2, 7.77942496381893596434e-4, 7.32668430744625636189e-6))
        r = log(z) + 0.5*y + p / q
    elseif ix < 0x43900000              #= 8.0 ≤ x < 2^58 =#
        r = (x - 0.5) * (log(x) - 1.0)
        z = inv(x)
        y = z * z
        w = 4.18938533204672725052e-1
        w += z * evalpoly(y, (8.33333333333329678849e-2, -2.77777777728775536470e-3, 7.93650558643019558500e-4, -5.95187557450339963135e-4, 8.36339918996282139126e-4, -1.63092934096575273989e-3))
        r += w
    else #= 2^58 ≤ x ≤ Inf =#
        r = x * (log(x) - 1.0)
    end
    if hx < 0
        r = nadj - r
    end
    return r, signgamp
end
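As a quick sanity check of the sketch above (Γ(5) = 24, so loggamma(5.0) should be log(24) with a positive sign):

julia> r, s = _loggamma(5.0); (r ≈ log(24), s)
(true, 1)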

I didn't look too much at the first part yet before the main computation routine so hopefully that could be improved a bit...

Accuracy

The error for this function looks like...

Screenshot from 2022-11-02 17-52-32

So both the routines peak at around 1.5 ULP.....

Benchmarks

function bench(f; N=50000000, Order=false)
    v = rand(N)*200
    if Order
        v = sort(v)
    end
    a = 0.0

    tstart = time()
    for _x in v
        a += f(_x)[1]
    end
    tend = time()
    t = (tend-tstart) / N * 1e9
    return t, a
end

Results (averaged nanoseconds per loggamma evaluation)

# This suggestion
julia> bench(_loggamma, Order=true)[1]
11.009101867675781

julia> bench(_loggamma, Order=false)[1]
11.52916431427002

# SpecialFunctions.jl
julia> bench(SpecialFunctions.loggamma, Order=true)[1]
16.55251979827881

julia> bench(SpecialFunctions.loggamma, Order=false)[1]
17.064623832702637

# implementation in PR
julia> bench(_lgamma_r, Order=true)[1]
12.965197563171387

julia> bench(_lgamma_r, Order=false)[1]
13.158679008483887

So, at least with this benchmark, I found the PR to be about 25% faster than the current implementation (which calls OpenLibm) and this suggestion to be about 35% faster.

@andrewjradcliffe
Contributor Author

My objective was to be as faithful to the C as possible but, I agree, if we can eliminate the gotos (which emulate fallthrough in the C switch), then all the better. Though, the reason I used goto is that it actually happens to be faster than a while loop, at least sometimes.

while loop vs. goto

julia> using BenchmarkTools


julia> function looped(x)
           i = Base.unsafe_trunc(Int8, x)
           y = x - float(i)
           z = 1.0
           p = 0.0
           u = x
           while u >= 3.0
               p -= 1.0
               u = x + p
               z *= u
           end
           z
       end;

julia> function goto(x)
           i = Base.unsafe_trunc(Int8, x)
           y = x - float(i)
           z = 1.0
           if i == Int8(7)
               z *= y + 6.0
               @goto case6
           elseif i == Int8(6)
               @label case6
               z *= y + 5.0
               @goto case5
           elseif i == Int8(5)
               @label case5
               z *= y + 4.0
               @goto case4
           elseif i == Int8(4)
               @label case4
               z *= y + 3.0
               @goto case3
           elseif i == Int8(3)
               @label case3
               z *= y + 2.0
           end
           z
       end;

julia> x7 = 7.9219183157870985;

julia> x6 = 6.9219183157870985;

julia> x5 = 5.9219183157870985;

julia> x4 = 4.9219183157870985;

julia> x3 = 3.9219183157870985;

julia> function bench(x)
           b1 = @benchmark looped($x)
           b2 = @benchmark goto($x)
           b1, b2
       end;

julia> foreach(println, map(bench, (x7, x6, x5, x4, x3)))
(Trial(5.294 ns), Trial(3.918 ns))
(Trial(3.908 ns), Trial(3.620 ns))
(Trial(3.629 ns), Trial(3.529 ns))
(Trial(2.524 ns), Trial(3.569 ns))
(Trial(2.524 ns), Trial(2.795 ns))

From a stylistic standpoint, the selected names for the functions and files were meant to make it as clear as possible that they are literally OpenLibm's (e.g. https://github.com/JuliaMath/openlibm/blob/master/src/e_lgamma_r.c). However, as this PR at least documents their original state, I am satisfied that due credit will be given.

With respect to a separate implementation for Float64 vs. Float32, the plots below demonstrate that Float32 is perhaps a marginal regression. With respect to two implementations (one carrying the sign, and the other throwing a DomainError instead), the benchmarks (see plots below) seem to indicate that it is worthwhile. Typically, I am against code duplication, but a ≈10% reduction seems worthwhile.

I agree that eliminating the consts is for the best. I had not looked at the assembly until now but, indeed, muladd (nested muladds being produced by evalpoly) for IEEE floats specifically calls Base.muladd_float, which, seemingly, is identical to mymuladd(a, b, c) = @fastmath a * b + c -- at least, they both produce the same instruction (vfmadd213sd for me).
Interesting to learn that every a * b + c should either be @fastmath'd or re-written as a muladd :/

assembly compare

julia> mymuladd(a, b, c) = @fastmath a * b + c
mymuladd (generic function with 1 method)

julia> @code_native mymuladd(1.5, 2.0, 2.5)
        .text
        .file   "mymuladd"
        .globl  julia_mymuladd_1406             # -- Begin function julia_mymuladd_1406
        .p2align        4, 0x90
        .type   julia_mymuladd_1406,@function
julia_mymuladd_1406:                    # @julia_mymuladd_1406
; ┌ @ REPL[40]:1 within `mymuladd`
        .cfi_startproc
# %bb.0:                                # %top
; │┌ @ fastmath.jl:165 within `add_fast`
        vfmadd213sd     %xmm2, %xmm1, %xmm0     # xmm0 = (xmm1 * xmm0) + xmm2
; │└
        retq
.Lfunc_end0:
        .size   julia_mymuladd_1406, .Lfunc_end0-julia_mymuladd_1406
        .cfi_endproc
; └
                                        # -- End function
        .section        ".note.GNU-stack","",@progbits

julia> @code_native muladd(1.5, 2.0, 2.5)
        .text
        .file   "muladd"
        .globl  julia_muladd_1412               # -- Begin function julia_muladd_1412
        .p2align        4, 0x90
        .type   julia_muladd_1412,@function
julia_muladd_1412:                      # @julia_muladd_1412
; ┌ @ float.jl:411 within `muladd`
        .cfi_startproc
# %bb.0:                                # %top
        vfmadd213sd     %xmm2, %xmm1, %xmm0     # xmm0 = (xmm1 * xmm0) + xmm2
        retq
.Lfunc_end0:
        .size   julia_muladd_1412, .Lfunc_end0-julia_muladd_1412
        .cfi_endproc
; └
                                        # -- End function
        .section        ".note.GNU-stack","",@progbits

I should have been clearer about the comprehensive performance benchmark -- see here for the code and here for the raw plot files.

benchplot_64
benchplot_32

@heltonmc
Member

heltonmc commented Nov 3, 2022

My objective was to be as faithful to the C as possible

I'll definitely defer to what others here think! I am in favor of having this ported to Julia, and it looks like the accuracy is similar, with noticeable performance gains! I would prefer that any port have a Julia style, but that's just a personal preference. I think that if we maintain a strict criterion that the max ULP always be less than 1.5, then any kind of code simplification or optimization should be favored.

Though, the reason I used goto was that it actually happens to be faster than a while loop, at least sometimes.

Benchmarking is always so fun 😅 \s. I am wondering how much we are just testing branch prediction, though, at a single value. The run time is so small that these branch misses/hits make a big difference. It makes sense that one value might perform best, considering each branch has all the computation unrolled, whereas the while loop builds it up iteratively. So for small values the while loop will win, and for larger values (considering that is also the first branch) the goto will win. I get similar results, though, if we assume that the value can be some random value in a larger range (but that is also devaluing that branch a bit)...

Benchmarks loop vs goto
# branched
julia> @benchmark _loggamma(x) setup=(x=rand()*200)
BenchmarkTools.Trial: 10000 samples with 999 evaluations.
 Range (min … max):   9.815 ns … 20.980 ns  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     12.908 ns              ┊ GC (median):    0.00%
 Time  (mean ± σ):   12.985 ns ±  0.595 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%

                           ▆█▅                                ▁
  ▆▃▆▃▁▁▁▁▁▁▁▄▆▅▄▁▁▅▃▁▁▁▁▁████▁▃▁▃▁▁▁▁▆▅▆█▃▁▅▆▄▅█▃▃▆█▆▃▁▃▃▁▃▇ █
  9.81 ns      Histogram: log(frequency) by time      16.7 ns <

 Memory estimate: 0 bytes, allocs estimate: 0.

julia> @benchmark _loggamma(x) setup=(x=rand()*10)
BenchmarkTools.Trial: 10000 samples with 999 evaluations.
 Range (min … max):   9.803 ns … 37.448 ns  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     15.178 ns              ┊ GC (median):    0.00%
 Time  (mean ± σ):   14.809 ns ±  2.168 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%

            ▁            ▇█         █      ▇  ▁█      ▂  ▇  █  
  ▆▅▂▁▁▁▁▁▁▁█▂▁▁▁▁▁▁▂▁▂▁▃██▂▁▁▁▁▁▁▅▄█▄▁▂▂▂▃█▁▁██▁▁▁▂▂▄█▄▄█▁▂█ ▃
  9.8 ns          Histogram: frequency by time        17.5 ns <

 Memory estimate: 0 bytes, allocs estimate: 0.


# while loop
julia> @benchmark _loggamma(x) setup=(x=rand()*200)
BenchmarkTools.Trial: 10000 samples with 999 evaluations.
 Range (min … max):   9.567 ns … 24.685 ns  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     12.910 ns              ┊ GC (median):    0.00%
 Time  (mean ± σ):   13.015 ns ±  0.721 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%

                          ▁ █▃                                ▁
  ▆▃▁▅▄▁▁▁▁▁▁▁▃▇▁▁▁▁▄▁▁▁▁▁████▇▃▁▁▁▃▅▅▁▁█▅▁▅▅▄▄▄█▆▇▃▁▁█▆▁▁▁▃▇ █
  9.57 ns      Histogram: log(frequency) by time        17 ns <

 Memory estimate: 0 bytes, allocs estimate: 0.

julia> @benchmark _loggamma(x) setup=(x=rand()*10)
BenchmarkTools.Trial: 10000 samples with 999 evaluations.
 Range (min … max):   9.570 ns … 33.758 ns  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     15.471 ns              ┊ GC (median):    0.00%
 Time  (mean ± σ):   15.032 ns ±  2.409 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%

                       █                 ▄   ▂     ▂        ▂  
  ▄▄▂▁▁▁▁▁▁▁▆▂▂▂▂▁▂▁▁▁▃█▆▂▂▁▁▁▃▃▄▇▆▁▂▂▁▁▃█▅▁▃█▂▂▂▂▂█▂▂▂▂▅▇▂▁█ ▃
  9.57 ns         Histogram: frequency by time        18.5 ns <

 Memory estimate: 0 bytes, allocs estimate: 0.

The while loop may hit a higher max speed but it appears on average they will be similar.

With respect to two implementations (one carrying the sign, and the other throwing a DomainError instead), the benchmark (see plots below) seem to indicate that it is worthwhile. Typically, I am against code duplication, but ≈10% reduction seems worthwhile.

Yes - hopefully we can improve that a little bit so we wouldn't see a difference. But 10% at these times is about 1 ns, right? I tried to reproduce this, and I did observe a difference in one benchmark (using BenchmarkTools) but did not observe a difference in the sum trial...

Benchmarks return sign
# throwing a DomainError
julia> @benchmark _loggamma(x) setup=(x=rand()*200)
BenchmarkTools.Trial: 10000 samples with 999 evaluations.
 Range (min … max):   8.652 ns … 23.051 ns  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     11.468 ns              ┊ GC (median):    0.00%
 Time  (mean ± σ):   11.911 ns ±  0.918 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%

                    ▁  █▅    ▁▂ ▄ ▅ ▄ ▃▁                      ▁
  ▆▄▃▃▁▆▁▄▁▅▁▁▁▁▁▁▁▁█▃▄███▆█▃██▁█▁█▃█▁██▄▇▄▅▄█▆▃▃█▃▆▆▁▁█▁▃▁▁▆ █
  8.65 ns      Histogram: log(frequency) by time      16.2 ns <

 Memory estimate: 0 bytes, allocs estimate: 0.

julia> bench(_loggamma, Order=true)[1]
11.165080070495605

julia> bench(_loggamma, Order=false)[1]
11.603279113769531

# returning the sign

julia> @benchmark _loggamma(x) setup=(x=rand()*200)
BenchmarkTools.Trial: 10000 samples with 999 evaluations.
 Range (min … max):   9.832 ns … 25.419 ns  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     12.909 ns              ┊ GC (median):    0.00%
 Time  (mean ± σ):   12.990 ns ±  0.718 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%

                         ▅▇█▃                                 ▂
  ▅▆▄▁▁▁▁▁▁▁▁▇▃▃▁▁▁▁▁▁▁▁▁████▆▄▃▁▁▁▁▆▄▆█▃▁▅▅▅█▇▁▁▇█▁▃▃▃▁▁▆▇▄▆ █
  9.83 ns      Histogram: log(frequency) by time        17 ns <

 Memory estimate: 0 bytes, allocs estimate: 0.


julia> bench(_loggamma, Order=true)[1]
11.080536842346191

julia> bench(_loggamma, Order=false)[1]
11.531438827514648

So I think it is going to depend on how you use the function. Hopefully, we can continue improving the code a little bit. I am sure there is some more bit hacking in the beginning that others can improve a bit.

@heltonmc heltonmc requested a review from oscardssmith November 3, 2022 15:01
@oscardssmith
Member

Thanks so much for doing this!

@andrewjradcliffe
Contributor Author

andrewjradcliffe commented Nov 3, 2022

I am wondering how much we are just testing branch prediction though at a single value.

I wondered this as well. See the code and attached plot.

benchplot_lpgt64

while loop vs. goto

using Distributed, Plots
addprocs(24)
@everywhere using BenchmarkTools
fname = joinpath(@__DIR__, "macrointerpolationworkaround.jl")
s = """
function bench_lpgt(x)
    b1 = @benchmark looped(\$x)
    b2 = @benchmark goto(\$x)
    b1, b2
end
mintimes_lpgt(x) = map(x -> minimum(x).time, bench_lpgt(x))
"""
open(fname, "w") do io
    write(io, s)
end
@everywhere function looped(x)
    i = Base.unsafe_trunc(Int8, x)
    y = x - float(i)
    z = 1.0
    p = 0.0
    u = x
    while u >= 3.0
        p -= 1.0
        u = x + p
        z *= u
    end
    z
end;
@everywhere function goto(x)
    i = Base.unsafe_trunc(Int8, x)
    y = x - float(i)
    z = 1.0
    if i == Int8(7)
        z *= y + 6.0
        @goto case6
    elseif i == Int8(6)
        @label case6
        z *= y + 5.0
        @goto case5
    elseif i == Int8(5)
        @label case5
        z *= y + 4.0
        @goto case4
    elseif i == Int8(4)
        @label case4
        z *= y + 3.0
        @goto case3
    elseif i == Int8(3)
        @label case3
        z *= y + 2.0
    end
    z
end;
@everywhere include($fname)
timescat(ts) = mapreduce(x -> [x...], hcat, ts)
pl_f(x, y; opts...) = plot(x, y, label=["looped, avg=$(mean(y[:, 1]))" "goto, avg=$(mean(y[:, 2]))"], legend=:topleft, ylabel="time (ns)", xlabel="x (function arg)"; opts...)

Δ = 1e-3
x = nextfloat(2.0):Δ:prevfloat(8.0)
ts = pmap(mintimes_lpgt, x);
y = timescat(ts);
p = pl_f(x, y', title="Float64, Δx=$(Δ)");
savefig(p, joinpath(@__DIR__, "benchplot_lpgt64.pdf"))
savefig(p, joinpath(@__DIR__, "benchplot_lpgt64.png"))
mean(y, dims=2)

But 10% at these times is about 1 ns, right?

1ns extra per call matters when the computation involves millions of calls per iteration of an algorithm (e.g. a large log-likelihood computation).
More generally, one must admit -- this type of argument cuts both ways. 1ns is about 4x the gain that one gets from replacing all a * b + c with muladd. For elementary functions, I am inclined to optimize as far as possible. The necessity of a separate function definition is poor justification to leave performance on the table. Why should a few bytes worth of characters stand in the way?
As for making the return of both value and sign as fast as value alone, it just won't be possible -- the return simply occupies more bytes. I had considered a method to pre-check the sign of the return in a separate function, so that the actual loggamma body could always return just a Float64. However, that would clearly be more convoluted and, moreover, inevitably wasteful (i.e. slower), since some of the intermediates (the bitcast and splitting of the u64) would need to be re-computed in the actual call (when early exit does not occur due to the sign of the return being negative). OpenLibm's code is constructed well -- there is little room to optimize.

Edit

Incidentally, the following changes make the while loop superior. See the updated plot below.

Change: use round(x, RoundToZero) to avoid casts.

function looped(x)
    i = round(x, RoundToZero)
    y = x - i
    z = 1.0
    p = 0.0
    u = x
    while u >= 3.0
        p -= 1.0
        u = x + p
        z *= u
    end
    z
end;
function goto(x)
    i = round(x, RoundToZero)
    y = x - i
    z = 1.0
    if i == 7.0
        z *= y + 6.0
        @goto case6
    elseif i == 6.0
        @label case6
        z *= y + 5.0
        @goto case5
    elseif i == 5.0
        @label case5
        z *= y + 4.0
        @goto case4
    elseif i == 4.0
        @label case4
        z *= y + 3.0
        @goto case3
    elseif i == 3.0
        @label case3
        z *= y + 2.0
    end
    z
end;

benchplot_lpgt64_new

@andrewjradcliffe
Contributor Author

andrewjradcliffe commented Nov 4, 2022

I will revise the Float32 implementation tomorrow. The benchmark of the revised Float64 implementation makes a separate implementation (throw instead of returning the sign) unnecessary.

Edit for clarity: In the plot below, logabsgamma's code corresponds to commit 606190d. loggamma is the same code, but with the throw substitution as in the original commit.

benchplot_64_new

2nd edit: Benchmark of revised Float32 implementation. I am not surprised to find that the two-arg return (Float32, Int32) is faster here, as 1) there is no error path and 2) the return fits into a single 8-byte word.
benchplot_32_new2

function _lgamma_r(x::Float64)
ux = reinterpret(UInt64, x)
hx = ux >>> 32 % Int32
lx = ux % UInt32
Member

@oscardssmith Nov 16, 2022


it's not obvious to me that the split into hx and lx is beneficial on modern CPUs, but undoing it is probably a massive pain.
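For what it's worth, a hypothetical sketch of what undoing the split could look like (the 64-bit constants are just the hx cutoffs extended to full-width bit patterns; untested against the PR):

function classify(x::Float64)
    u = reinterpret(UInt64, x)
    ax = u & 0x7fffffffffffffff            # |x| bit pattern, no hi/lo split
    ax ≥ 0x7ff0000000000000 && return :nan_or_inf  # replaces ix ≥ 0x7ff00000
    ax == 0 && return :zero                        # replaces ix | lx == 0
    ax < 0x3b90000000000000 && return :tiny        # |x| < 2^-70
    return :regular
end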

Contributor Author


Now that you mention it, between two weeks ago and now I have seen some other tricks in musl. Let me take a second pass at the bitmagic portions.

Contributor Author


Well, other than computing hx < Int32(0) a single time (already incorporated in the latest commits), musl takes a similar approach. The commentary below is not strictly on-topic, but may be helpful to someone in the future.

Close inspection of the musl implementation (which is based on fdlibm, same as openlibm) reveals only one major difference, other than our treatment of the switch statements:

  • Handling of 0: Initial inspection of lines 177-188 suggests that one might eliminate a branch around 0.0 by letting it fall through (resulting in a log call, which is slower than simply returning Inf -- most likely, keeping the branch is faster anyway). This has an undesirable side effect: -0.0 returns a sign of -1, rather than 1. This is non-standard (arguably a violation of IEEE rules around signed zeros, as it is a novel special treatment); oddly, it propagates into the Rust port.

Member


Although the log call is slower than returning Inf, you save a branch for all the nonzero values, so it's probably faster overall.

This propagates any signals which happen to be on the `NaN`, mimicking
the behavior of OpenLibm (though, I suspect no one depends on
signaling `NaN`s being propagated through `logabsgamma`).
@oscardssmith
Member

Have you done Float32 exhaustive tests yet? If not, they should probably be done before this gets merged.

@andrewjradcliffe
Contributor Author

Float32 exhaustive tests -- test function: logabsgamma_new(x)[1] ≈ logabsgamma_old(x)[1]

  • Every positive Float32 tests true
  • 2139092692 out of 2139095041 negative Float32s test true. The 2349 which do not return true fall within the range [-7.0001984f0, -2.4566386f0], thus all within the same branch; the (minimum, maximum) absolute difference of new - old is (2.9802322f-8, 7.1525574f-7). A sketch of the sweep follows below.
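For reference, a sketch of how such an exhaustive sweep can be run, iterating the bit patterns directly; logabsgamma_new/logabsgamma_old stand for the two implementations compared above:

function sweep(lo::UInt32, hi::UInt32)
    bad = Float32[]
    for u in lo:hi
        x = reinterpret(Float32, u)
        logabsgamma_new(x)[1] ≈ logabsgamma_old(x)[1] || push!(bad, x)
    end
    return bad
end

sweep(0x00000001, 0x7f7fffff)  # every positive finite Float32
sweep(0x80000001, 0xff7fffff)  # every negative finite Float32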

@heltonmc
Member

heltonmc commented Nov 18, 2022

Along the same lines, is there a reason the tests use absolute tolerances instead of relative tolerances? The openlibm implementation has max errors around ~1.5-1.6 ULP, which this implementation matches, which is great. But it would be nice if we had automated checks with tolerances like rtol=2*eps() across the whole domain. Ideally, these tests would use the strictest possible error criteria to prevent any backsliding (if someone modifies the code in the future) and be focused around the cutoffs of the different branches.

That does bring up an issue, though: if we are testing against openlibm, which has max ULP around 1.6, and this implementation has similar errors but at different values, we might not be able to have the strictest possible test unless we test against a higher-precision implementation. It would be nice if there were a few checks that would guarantee relative tolerances that CI could run on different operating systems and Julia versions.
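A minimal sketch of such a CI check; the spot values are illustrative, and the rtol is set slightly above the known ~1.6 ULP worst case:

using Test, SpecialFunctions

@testset "logabsgamma vs MPFR" begin
    for x in (1.0e-8, 0.5, 0.9, 1.5, 2.5, 7.99, 8.0, 1.0e10, -0.5, -3.5)
        ref = Float64(logabsgamma(big(x))[1])  # BigFloat method dispatches to MPFR
        @test logabsgamma(x)[1] ≈ ref rtol=4*eps()
    end
end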

@oscardssmith
Member

can't we just test against the MPFR gamma?

@andrewjradcliffe
Contributor Author

can't we just test against the MPFR gamma?

This is most logical, as we no longer need to use openlibm as a reference. When the Float32 sweep finishes, I'll update with results.

Construction of unit tests needs some consideration. @heltonmc seems to be suggesting that we sweep the entire domain, imposing a condition such as:

function ulp(x::Union{Float32, Float64})
    z′ = logabsgamma(x)[1]
    z = logabsgamma(big(x))[1] # Dispatches to MPFR
    isinf(z′) && isinf(z) && return big(0.0)
    iszero(z′) && iszero(z) && return big(0.0)
    e = exponent(z′)
    abs(z′ - z) * 2.0^(precision(x) - 1 - e)
end

# Interval 1: 0.0 < x ≤ 2.0
x = nextfloat(0.0):1e-8:2.0
@test maximum(ulp, x) < thresh1
# Interval 2: 2.0 < x < 8.0
x = nextfloat(2.0):1e-8:prevfloat(8.0)
@test maximum(ulp, x) < thresh2
# Interval 3: 8.0 ≤ x < 2^58
x = 8.0:1e-4:2.0^7
@test maximum(ulp, x) < thresh3_1
x = 2.0^7:1.0:2.0^20
@test maximum(ulp, x) < thresh3_2
x = 2.0^20:2.0^20:2.0^40
@test maximum(ulp, x) < thresh3_3
x = 2.0^40:2.0^40:prevfloat(2.0^58)
@test maximum(ulp, x) < thresh3_4
# Interval 4: 2^58 ≤ x ≤ Inf
x = (2.0^58):(2.0^1000):prevfloat(Inf)
@test maximum(ulp, x) < thresh4

It makes sense to sweep the domain when proposing changes, but for regular integration tests, it would be unnecessary (and introduce considerable latency).

Incidentally, here is ulp for Float64. Stepsize was chosen largely for convenience of plotting, so don't read anything more into it.

ulp_square

@oscardssmith
Member

Looks great! I like full sweeps for Float32 and smaller functions since they're relatively easy. But that's way overkill for CI. For CI, I'd probably just test a few values in different ranges to make sure everything is working as expected.

@andrewjradcliffe
Contributor Author

I like full sweeps for Float32 and smaller functions since it's relatively easy.

The Float32 sweep is still ongoing; I launched it ~72 hours ago on a single process. I underestimated the CPU time of emulated floating point (MPFR). In hindsight, I should have split the range and pmap'd it (which I just did using 96 cores, so hopefully I can wrap this up tomorrow).

@oscardssmith
Member

oscardssmith commented Nov 22, 2022

Oh, you are doing this the very slow way. You can compare Float32 to the C Float64 implementation, which is more than accurate enough and should be at least 10x faster. You only need MPFR to test the 64-bit version.

@andrewjradcliffe
Contributor Author

oh, you are doing this the very slow way.

heh, yeah. ._.; brain is slow at the end of a long day.

The Float32 sweep across 0 <= x <= Inf32 revealed a problem in the Float64 method, fixed with the above PR. The maximum ulp was 2.5, and only 0.0175% (373390 out of 2139095041) have ulp >= 1.5.

The sweep across -Inf32 <= x < 0 revealed a different problem: 0.0825% (1763178 out of 2139095041) of Float32 values have ulp >= 3.5. The maximum ulp was 3.5877014609375e6; thus, negative Float32s sometimes return completely wrong values. However, the values which exhibit ulp >= 3.5 are consistent with those produced by OpenLibm, so there is no net change. Quantitatively: 99.8667% of the time (1760829 out of 1763178), this PR and OpenLibm give an equally wrong value.

@oscardssmith
Member

Sounds mostly good! It would be good to see where the very wrong answers are. My guess is overflow, but it would be good to make sure it's not a problem.

@andrewjradcliffe
Contributor Author

The following sets of plots should provide a comprehensive view of my previous comment's statements. The code used for this can be found here; note that the implementation of logabsgamma in this PR is identical to logabsgamma in OpenLibmPorts.jl.

anomaly_Float32_thispr0
anomaly_Float32_openlibm

@oscardssmith
Member

IMO, this is ready to merge.

@heltonmc
Member

heltonmc commented Dec 5, 2022

Looks good to me 😄! We could also use this function at Bessels.jl if you would like to contribute it there as well, since that package is restricted to native Julia code.

There is some momentum to separate these gamma functions into a smaller package (#409 (comment)), but this would complement the existing gamma function at Bessels.jl for the time being.

@oscardssmith
Member

The test failures here appear to be real.

@andrewjradcliffe
Contributor Author

The test failures here appear to be real.

Strange, but true. I don't have access to a Mac, so my options are limited to raising the threshold in the hope that (Julia v1.3, macOS) will magically pass. Sadly, this means that regular CI will not be as stringent, but I added a comment in the tests to document this.

@andrewjradcliffe
Contributor Author

Perhaps I should have mentioned it in the previous comment, but addressing the single source of test failures was quite minor (see 83071be). Consequently, this is ready to merge.

@oscardssmith oscardssmith merged commit 2cbb4ae into JuliaMath:master Dec 20, 2022