2 changes: 1 addition & 1 deletion Project.toml
@@ -1,6 +1,6 @@
name = "DistributionsAD"
uuid = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
-version = "0.6.43"
+version = "0.7"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
40 changes: 1 addition & 39 deletions src/DistributionsAD.jl
@@ -80,45 +80,7 @@ include("zygote.jl")
end

@require LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" begin
-    using .LazyArrays: BroadcastArray, BroadcastVector, LazyArray
-
-    const LazyVectorOfUnivariate{
-        S<:ValueSupport,
-        T<:UnivariateDistribution{S},
-        Tdists<:BroadcastVector{T},
-    } = VectorOfUnivariate{S,T,Tdists}
-
-    function Distributions._logpdf(
-        dist::LazyVectorOfUnivariate,
-        x::AbstractVector{<:Real},
-    )
-        return sum(copy(logpdf.(dist.v, x)))
-    end
-
-    function Distributions.logpdf(
-        dist::LazyVectorOfUnivariate,
-        x::AbstractMatrix{<:Real},
-    )
-        size(x, 1) == length(dist) ||
-            throw(DimensionMismatch("Inconsistent array dimensions."))
-        return vec(sum(copy(logpdf.(dists, x)), dims = 1))
Member Author

This even has a bug in it: dists isn't defined..

-    end
-
-    const LazyMatrixOfUnivariate{
-        S<:ValueSupport,
-        T<:UnivariateDistribution{S},
-        Tdists<:BroadcastArray{T,2},
-    } = MatrixOfUnivariate{S,T,Tdists}
-
-    function Distributions._logpdf(
-        dist::LazyMatrixOfUnivariate,
-        x::AbstractMatrix{<:Real},
-    )
-        return sum(copy(logpdf.(dist.dists, x)))
-    end
-
-    lazyarray(f, x...) = LazyArray(Base.broadcasted(f, x...))
-    export lazyarray
+    include("lazyarrays.jl")
end
end

109 changes: 109 additions & 0 deletions src/common.jl
@@ -48,3 +48,112 @@ parameterless_type(x) = parameterless_type(typeof(x))
parameterless_type(x::Type) = __parameterless_type(x)

@non_differentiable adapt_randn(::Any...)

"""
make_closure(f, g)

Return a closure of the form `(x, args...) -> f(g(args...), x)`.

# Examples

This is particularly useful when one wants to avoid broadcasting over constructors
which can sometimes cause issues with type-inference, in particular when combined
with reverse-mode AD frameworks.

```juliarepl
julia> using DistributionsAD, Distributions, ReverseDiff, BenchmarkTools

julia> const data = randn(1000);

julia> x = randn(length(data));

julia> f(x) = sum(logpdf.(Normal.(x), data))
f (generic function with 2 methods)

julia> @btime ReverseDiff.gradient(\$f, \$x);
848.759 μs (14605 allocations: 521.84 KiB)

julia> # Much faster with ReverseDiff.jl.
g(x) = let g_inner = DistributionsAD.make_closure(logpdf, Normal)
sum(g_inner.(data, x))
end
g (generic function with 1 method)

julia> @btime ReverseDiff.gradient(\$g, \$x);
17.460 μs (17 allocations: 71.52 KiB)
```

See https://github.com/TuringLang/Turing.jl/issues/1934 for further discussion.

# Notes
To really go "vrooom!\" one needs to specialize on the arguments, e.g. if one
has a function `myfunc` then we need to define

```julia
make_closure(::typeof(myfunc), ::Type{D}) where {D} = (x, args...) -> myfunc(D(args...), x)
```

This can also be done using `DistributionsAD.@specialize_make_closure`:
Member

Maybe just missing type parameters in the definition below?


```julia
julia> mylogpdf(d, x) = logpdf(d, x)
mylogpdf (generic function with 1 method)

julia> h(x) = let inner = DistributionsAD.make_closure(mylogpdf, Normal)
sum(inner.(data, x))
end
h (generic function with 1 method)

julia> @btime ReverseDiff.gradient(\$h, \$x);
1.220 ms (37011 allocations: 1.42 MiB)

julia> DistributionsAD.@specialize_make_closure mylogpdf

julia> @btime ReverseDiff.gradient(\$h, \$x);
17.038 μs (17 allocations: 71.52 KiB)
```
"""
make_closure(f, g) = (x, args...) -> f(g(args...), x)
make_closure(f, ::Type{D}) where {D} = (x, args...) -> f(D(args...), x)
Member

I wonder if there are any performance or compiler benefits to using (more Julian) callable structs that capture f and g instead of closing over the variables. In any case, I think you want

Suggested change
-make_closure(f, g) = (x, args...) -> f(g(args...), x)
-make_closure(f, ::Type{D}) where {D} = (x, args...) -> f(D(args...), x)
+make_closure(f::F, g::G) where {F,G} = (x, args...) -> f(g(args...), x)
+make_closure(f::F, ::Type{D}) where {F,D} = (x, args...) -> f(D(args...), x)

Member Author

Unfortunately I tried that, but AFAIK it ends up closing over a UnionAll type again, which is exactly what we're trying to avoid (because of the issues it causes with some AD backends) 😕

I might not have done it correctly, though.

But I have tried your suggestion, and unfortunately it doesn't have an effect. If you just look at the returned closures, they're all the same one 😕

Member Author

Maybe something like

struct Closure{F,G} end

Closure(::F, ::G) where {F,G} = Closure{F,G}()
Closure(::F, ::Type{G}) where {F,G} = Closure{F,G}()
Closure(::Type{F}, ::G) where {F,G} = Closure{F,G}()
Closure(::Type{F}, ::Type{G}) where {F,G} = Closure{F,G}()

for f in [pdf, logpdf, cdf, logcdf]
    @eval (::$(Closure){typeof($f),G})(x, args...) where {G} = $f(G(args...), x)
end

?

Member Author

Closure(::F, ::Type{G}) where {F,G} = Closure{F,G}()

and others avoid the UnionAll issue.
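
A rough, dependency-free sketch of the callable-struct idea discussed in this thread. `Shift` and `apply` are hypothetical stand-ins for a distribution type and `logpdf`; they are not part of DistributionsAD, and the call overload via `F.instance` is just one possible way to write it, assuming `f` is an ordinary (singleton) function:

```julia
# Zero-size callable struct: `F` is the function's type, `G` the constructor type.
struct Closure{F,G} end

# Store the function type and the type `G` itself as type parameters,
# so nothing (in particular no UnionAll) ends up in a field.
Closure(::F, ::Type{G}) where {F,G} = Closure{F,G}()

# `F.instance` recovers the singleton function from its type
# (assumption: `f` is a plain function, not a closure with captured state).
(::Closure{F,G})(x, args...) where {F,G} = F.instance(G(args...), x)

# Hypothetical stand-ins for a distribution type and `logpdf`.
struct Shift{T}
    offset::T
end
apply(s::Shift, x) = x + s.offset

cl = Closure(apply, Shift)               # `Shift` is a UnionAll here
result = cl.([1.0, 2.0], [10.0, 20.0])   # broadcasts apply(Shift(arg), x)
```

Because `cl` carries everything in its type parameters, it has no fields at all, which is the property the constructors above are meant to guarantee.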

Member Author

Have a look at what I've done now 👍



"""
has_specialized_make_closure(f, g)

Return `true` if there exists a specialized `make_closure(f, g)` implementation.
"""
has_specialized_make_closure(f, g) = false

# To go vroooom we need to specialize on the first argument, thus ensuring that
# a different closure is constructed for each method.
"""
@specialize_make_closure(f)

Define `make_closure` and `has_specialized_make_closure` for the first argument being `f`
and the second argument being a type.
"""
macro specialize_make_closure(f)
return quote
$(DistributionsAD).make_closure(::typeof($(esc(f))), ::Type{D}) where {D} = (x, args...) -> $(esc(f))(D(args...), x)
$(DistributionsAD).has_specialized_make_closure(::typeof($(esc(f))), ::Type{D}) where {D} = true
end
end

"""
@specialize_make_closure(f, g)

Define `make_closure` and `has_specialized_make_closure` for the first argument being `f`
and the second argument being `g`.
"""
macro specialize_make_closure(f, g)
return quote
$(DistributionsAD).make_closure(::typeof($(esc(f))), ::typeof($(esc(g)))) = (x, args...) -> $(esc(f))($(esc(g))(args...), x)
$(DistributionsAD).has_specialized_make_closure(::typeof($(esc(f))), ::typeof($(esc(g)))) = true
end
end

@specialize_make_closure Distributions.pdf
@specialize_make_closure Distributions.logpdf
@specialize_make_closure Distributions.loglikelihood
@specialize_make_closure Distributions.cdf
@specialize_make_closure Distributions.logcdf
Member

I think it should be possible to remove all of this code. Maybe type parameters are already sufficient. Or using a callable struct might help.

67 changes: 67 additions & 0 deletions src/lazyarrays.jl
@@ -0,0 +1,67 @@
using .LazyArrays: BroadcastArray, BroadcastVector, LazyArray

# Necessary to make `BroadcastArray` work nicely with Zygote.
function ChainRulesCore.rrule(config::ChainRulesCore.RuleConfig{>:ChainRulesCore.HasReverseMode}, ::Type{BroadcastArray}, f, args...)
return ChainRulesCore.rrule_via_ad(config, Broadcast.broadcasted, f, args...)
end
Member

ChainRules definitions for LazyArrays.jl should really not be part of DistributionsAD. The other type piracy is already bad but at least Distributions-specific. But these lines seem really inappropriate here.

Member Author

Agreed! See JuliaArrays/LazyArrays.jl#232

Think the plan is to make a glue-package for now.

Btw, this rrule is also not good since, for example, logpdf(arraydist(BroadcastArray(Normal, x)), data) will then be separated into two broadcast statements again, which is the opposite of what we want 😕

We could of course define adjoint rules for logpdf, etc. that specialize on the BroadcastArray scenario, but that is all non-ideal 😕
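
To illustrate the "two broadcast statements" concern with Base broadcasting alone, here is a minimal sketch. `Gauss` and `toy_logpdf` are hypothetical stand-ins for `Normal` and `logpdf`; no AD or LazyArrays is involved, so this only shows the difference in evaluation strategy, not the rrule behaviour itself:

```julia
# Hypothetical stand-ins for a distribution constructor and `logpdf`.
struct Gauss{T}
    μ::T
end
toy_logpdf(d::Gauss, x) = -abs2(x - d.μ) / 2

x = [0.0, 1.0, 2.0]
data = [0.5, 1.5, 2.5]

# Two separate broadcasts: a Vector of `Gauss` objects is materialised first.
dists = Gauss.(x)
two_step = sum(toy_logpdf.(dists, data))

# One fused, lazy broadcast: the constructor call is nested inside the
# outer broadcast, so no intermediate array of `Gauss` objects is built.
bc = Broadcast.broadcasted(toy_logpdf, Broadcast.broadcasted(Gauss, x), data)
fused = sum(Broadcast.instantiate(bc))
```

Both give the same value; the fused form is the one that the BroadcastArray-based distributions are trying to preserve under AD.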

Member

I don't think a glue package is a good long term solution. If one does not want to load ChainRulesCore for all users (even though I think it's loaded anyways in almost all realistic scenarios), a weak dependency seems the best way IMO.

Member Author

Btw got any good idea of how to define the constructor for BroadcastArray properly?

Member Author

So far I have:

function ChainRulesCore.rrule(::Type{BroadcastArray}, f, args...)
    function BroadcastArray_pullback(Δ::ChainRulesCore.Tangent)
        return (ChainRulesCore.NoTangent(), Δ.f, Δ.args...)
    end
    return BroadcastArray(f, args...), BroadcastArray_pullback
end

function ChainRulesCore.rrule(
    config::ChainRulesCore.RuleConfig{>:ChainRulesCore.HasReverseMode},
    ::typeof(Distributions.logpdf),
    dist::Distributions.Product{V,D,A}, x::AbstractVector{<:Real}
) where {V,D<:Distribution,A<:BroadcastArray}
    cl = DistributionsAD.Closure(logpdf, DistributionsAD._inner_constructor(typeof(dist.v)))
    y, dy = ChainRulesCore.rrule_via_ad(config, broadcast, cl, x, dist.v.args...)
    z, dz = ChainRulesCore.rrule_via_ad(config, sum, y)

    f = dist.v.f
    function logpdf_adjoint(Δ...)
        # 1st argument is `sum` -> nothing.
        (_, sum_Δ...) = dz(Δ...)
        # 1st argument is `broadcast` -> nothing.
        # 2nd argument is `cl` -> `nothing`.
        # 3rd argument is `x` -> something.
        # Rest is `dist` arguments -> something
        (_, _, x_Δ, args_Δ...) = dy(sum_Δ...)
        # Construct the structural tangents.
        ba_tangent = ChainRulesCore.Tangent{A}(f=f, args=args_Δ)
        dist_tangent = ChainRulesCore.Tangent{typeof(dist)}(v=ba_tangent)

        return (ChainRulesCore.NoTangent(), dist_tangent, x_Δ)
    end

    return z, logpdf_adjoint
end

which does indeed do the trick but does not look great 😕

  1. I don't like how I have to call back into AD, but I don't atm see a way around that.
  2. I'm not too familiar with structural Tangent, and so I don't know if nesting them is a bad idea. For example I noticed that when we hit the pullback for BroadcastArray, I'm looking at a Tangent{Any} despite this not being the case in logpdf_adjoint (though this might just be Zygote doing something with it in the mean time?).
  3. This isn't full support for BroadcastArray since in most cases it won't receive a Tangent.

Any more?

Member

I'll try to look into it a bit this evening (currently a bit busy since we are working on a paper for ICML but it's in such an early stage I'm not sure we will make the deadline 😄). ProjectTo reminds me of this PR which I just happened to comment on this morning: JuliaArrays/FillArrays.jl#153 Maybe the code there could be helpful if you want to define ProjectTo for BroadcastArray (not sure if that was what you were asking about).

Member Author

Awesome:) Btw, really appreciate your responses over these past days; been incredibly helpful:) I'll be doing some workshop talks at a winter school on Turing next week and I would have loved to pin this down before that but we'll see 👍

Member Author

Btw I just pushed the resulting impl so you can look at it properly when you have time. It does seem to work awfully well but it's very hacky 😕

Member

A PhD student from Gothenburg actually told me about the winter school and thought it might be interesting for me based on my research interests. I would have liked to go to Norway since I have only been to Bergen so far but unfortunately I don't have time (and I assume I'm not the intended audience for a workshop about Turing 😛).

Member Author

Ooooh that would have been awesome!:( Well, hope you make it at some other point. Haha, maybe not 😅


const LazyVectorOfUnivariate{
S<:ValueSupport,
T<:UnivariateDistribution{S},
Tdists<:BroadcastVector{T},
} = VectorOfUnivariate{S,T,Tdists}

_inner_constructor(::Type{<:BroadcastVector{<:Any,Type{D}}}) where {D} = D

function Distributions._logpdf(
dist::LazyVectorOfUnivariate,
x::AbstractVector{<:Real},
)
# TODO: Make use of `sum(Broadcast.instantiate(Broadcast.broadcasted(f, x, args...)))` once
# we've addressed performance issues in ReverseDiff.jl.
constructor = _inner_constructor(typeof(dist.v))
Member

I'm sure this will be problematic in some cases and break. It's not guaranteed that _inner_constructor returns a proper constructor. Something safer would be https://github.com/JuliaObjects/ConstructionBase.jl I assume.

Member

A simple example:

julia> struct A{X,Y}
           x::X
           y::Y
           A(x::X, y::Y) where {X,Y} = new{X,Y}(x, y)
       end

julia> _constructor(::Type{D}) where {D} = D
_constructor (generic function with 1 method)

julia> x, y = 1, 2.0
(1, 2.0)

julia> a = A(x, y)
A{Int64, Float64}(1, 2.0)

julia> _constructor(typeof(a))(x, y)
ERROR: MethodError: no method matching A{Int64, Float64}(::Int64, ::Float64)
Stacktrace:
 [1] top-level scope
   @ REPL[31]:1
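
For comparison, a minimal sketch of what a safer lookup could look like without extra dependencies: strip the type parameters before calling the constructor. The helper name `unparameterized` is hypothetical; as far as I understand, `ConstructionBase.constructorof` serves a similar purpose by default, but that is not shown here:

```julia
struct A{X,Y}
    x::X
    y::Y
    # Only an untyped inner constructor, as in the example above.
    A(x::X, y::Y) where {X,Y} = new{X,Y}(x, y)
end

# Strip the type parameters so the generic constructor `A` is used
# instead of the (non-existent) `A{Int64, Float64}(...)` method.
unparameterized(::Type{T}) where {T} = Base.typename(T).wrapper

a = A(1, 2.0)
ctor = unparameterized(typeof(a))   # `A`, not `A{Int64, Float64}`
b = ctor(3, 4.0)
```

This still assumes the unparameterized type accepts the field values positionally, which is exactly the kind of assumption a dedicated package can let types override.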

Member Author

So I was actually using ConstructionBase locally for this before:) But I removed it because I figured this will only be used for a very simple subset of constructors, so uncertain if it's worth it. But I'll add it back again:)

return if has_specialized_make_closure(logpdf, constructor)
f = make_closure(logpdf, constructor)
sum(f.(x, dist.v.args...))
else
sum(copy(logpdf.(dist, x)))
end
end

function Distributions.logpdf(
dist::LazyVectorOfUnivariate,
x::AbstractMatrix{<:Real},
)
size(x, 1) == length(dist) ||
throw(DimensionMismatch("Inconsistent array dimensions."))
constructor = _inner_constructor(typeof(dist.v))
return if has_specialized_make_closure(logpdf, constructor)
f = make_closure(logpdf, constructor)
vec(sum(f.(x, dist.v.args...), dims = 1))
else
vec(sum(copy(logpdf.(dist, x)); dims = 1))
end
end

const LazyMatrixOfUnivariate{
S<:ValueSupport,
T<:UnivariateDistribution{S},
Tdists<:BroadcastArray{T,2},
} = MatrixOfUnivariate{S,T,Tdists}

function Distributions._logpdf(
dist::LazyMatrixOfUnivariate,
x::AbstractMatrix{<:Real},
)

constructor = _inner_constructor(typeof(dist.dists))
return if has_specialized_make_closure(logpdf, constructor)
f = make_closure(logpdf, constructor)
sum(f.(x, dist.dists.args...))
else
sum(copy(logpdf.(dist.dists, x)))
end
end

lazyarray(f, x...) = BroadcastArray(f, x...)
export lazyarray
Comment on lines +45 to +47
Member

It's not clear to me why this is needed. It doesn't seem much shorter and it makes it less clear that everything is based on LazyArrays.

Member Author

Oh no, this was already in DistributionsAD.jl 🤷 Not something I put in here. I was also unaware of this method's existence.

Member

Should we deprecate it?

Member Author

Happy to!