Faster arraydist with LazyArrays.jl #231
@@ -48,3 +48,112 @@ parameterless_type(x) = parameterless_type(typeof(x))
parameterless_type(x::Type) = __parameterless_type(x)

@non_differentiable adapt_randn(::Any...)

"""
    make_closure(f, g)

Return a closure of the form `(x, args...) -> f(g(args...), x)`.

# Examples

This is particularly useful when one wants to avoid broadcasting over constructors,
which can sometimes cause issues with type inference, in particular when combined
with reverse-mode AD frameworks.

```juliarepl
julia> using DistributionsAD, Distributions, ReverseDiff, BenchmarkTools

julia> const data = randn(1000);

julia> x = randn(length(data));

julia> f(x) = sum(logpdf.(Normal.(x), data))
f (generic function with 2 methods)

julia> @btime ReverseDiff.gradient(\$f, \$x);
  848.759 μs (14605 allocations: 521.84 KiB)

julia> # Much faster with ReverseDiff.jl.
       g(x) = let g_inner = DistributionsAD.make_closure(logpdf, Normal)
           sum(g_inner.(data, x))
       end
g (generic function with 1 method)

julia> @btime ReverseDiff.gradient(\$g, \$x);
  17.460 μs (17 allocations: 71.52 KiB)
```

See https://github.com/TuringLang/Turing.jl/issues/1934 for further discussion.

# Notes
To really go "vrooom!\" one needs to specialize on the arguments, e.g. if one
has a function `myfunc` then we need to define

```julia
make_closure(::typeof(myfunc), ::Type{D}) where {D} = (x, args...) -> myfunc(D(args...), x)
```

This can also be done using `DistributionsAD.@specialize_make_closure`:

```julia
julia> mylogpdf(d, x) = logpdf(d, x)
mylogpdf (generic function with 1 method)

julia> h(x) = let inner = DistributionsAD.make_closure(mylogpdf, Normal)
           sum(inner.(data, x))
       end
h (generic function with 1 method)

julia> @btime ReverseDiff.gradient(\$h, \$x);
  1.220 ms (37011 allocations: 1.42 MiB)

julia> DistributionsAD.@specialize_make_closure mylogpdf

julia> @btime ReverseDiff.gradient(\$h, \$x);
  17.038 μs (17 allocations: 71.52 KiB)
```
"""
make_closure(f, g) = (x, args...) -> f(g(args...), x)
make_closure(f, ::Type{D}) where {D} = (x, args...) -> f(D(args...), x)
Suggested change
- make_closure(f, g) = (x, args...) -> f(g(args...), x)
- make_closure(f, ::Type{D}) where {D} = (x, args...) -> f(D(args...), x)
+ make_closure(f::F, g::G) where {F,G} = (x, args...) -> f(g(args...), x)
+ make_closure(f::F, ::Type{D}) where {F,D} = (x, args...) -> f(D(args...), x)
I unfortunately tried that, but AFAIK it ends up with this issue of closing over a `UnionAll` type again, which is exactly what we're trying to avoid (because of the issues it's causing with some AD backends) 😕
I might not have done it correctly though.
But I have tried your suggestion, and it unfortunately doesn't have an effect. If you just look at the returned closures, they're all the same one 😕
Maybe something like
struct Closure{F,G} end
Closure(::F, ::G) where {F,G} = Closure{F,G}()
Closure(::F, ::Type{G}) where {F,G} = Closure{F,G}()
Closure(::Type{F}, ::G) where {F,G} = Closure{F,G}()
Closure(::Type{F}, ::Type{G}) where {F,G} = Closure{F,G}()
for f in [pdf, logpdf, cdf, logcdf]
@eval (::$(Closure){typeof($f),G})(x, args...) where {G} = $f(G(args...), x)
end
?
`Closure(::F, ::Type{G}) where {F,G} = Closure{F,G}()` and the others avoid the `UnionAll` issue.
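A minimal sketch of the callable-struct idea from the comments above, written so that it runs without Distributions.jl. `ToyNormal` and `toy_logpdf` are hypothetical stand-ins for `Normal` and `logpdf`; only the `Closure` pattern itself comes from the thread, and this is not necessarily what the PR ended up doing.

```julia
# Hypothetical stand-in for a univariate distribution (not from Distributions.jl).
struct ToyNormal{T<:Real}
    μ::T
    σ::T
end

# Hypothetical stand-in for `logpdf`: log-density of the toy normal at `x`.
toy_logpdf(d::ToyNormal, x::Real) = -log(d.σ) - log(2π) / 2 - ((x - d.μ) / d.σ)^2 / 2

# Callable struct that carries the function and the constructor only as type
# parameters, so there is no captured field holding a `UnionAll` value.
struct Closure{F,G} end
Closure(::F, ::Type{G}) where {F,G} = Closure{F,G}()

# Specialise the call for one function, mirroring the `@eval` loop in the comment.
(::Closure{typeof(toy_logpdf),G})(x, args...) where {G} = toy_logpdf(G(args...), x)

c = Closure(toy_logpdf, ToyNormal)
xs = [0.0, 1.0, 2.0]
μs = [0.0, 0.5, 1.0]

# Broadcast the closure instead of broadcasting the constructor.
total = sum(c.(xs, μs, 1.0))
```

Because `Closure{F,G}` has no fields, every instance is a zero-size singleton and all the information needed for dispatch lives in its type.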
Have a look at what I've done now 👍
I think it should be possible to remove all of this code. Maybe type parameters are already sufficient. Or using a callable struct might help.
@@ -0,0 +1,68 @@
using .LazyArrays: BroadcastArray, BroadcastVector, LazyArray

using ChainRulesCore: ChainRulesCore

# Necessary to make `BroadcastArray` work nicely with Zygote.
function ChainRulesCore.rrule(config::ChainRulesCore.RuleConfig{>:ChainRulesCore.HasReverseMode}, ::Type{BroadcastArray}, f, args...)
    return ChainRulesCore.rrule_via_ad(config, Broadcast.broadcasted, f, args...)
end

const LazyVectorOfUnivariate{
    S<:ValueSupport,
    T<:UnivariateDistribution{S},
    Tdists<:BroadcastVector{T},
} = VectorOfUnivariate{S,T,Tdists}

_inner_constructor(::Type{<:BroadcastVector{<:Any,Type{D}}}) where {D} = D

function Distributions._logpdf(
    dist::LazyVectorOfUnivariate,
    x::AbstractVector{<:Real},
)
    # TODO: Make use of `sum(Broadcast.instantiate(Broadcast.broadcasted(f, x, args...)))` once
    # we've addressed performance issues in ReverseDiff.jl.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm sure this will be problematic in some cases and break. It's not guaranteed that There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. A simple example: julia> struct A{X,Y}
x::X
y::Y
A(x::X, y::Y) where {X,Y} = new{X,Y}(x, y)
end
julia> _constructor(::Type{D}) where {D} = D
_constructor (generic function with 1 method)
julia> x, y = 1, 2.0
(1, 2.0)
julia> a = A(x, y)
A{Int64, Float64}(1, 2.0)
julia> _constructor(typeof(a))(x, y)
ERROR: MethodError: no method matching A{Int64, Float64}(::Int64, ::Float64)
Stacktrace:
[1] top-level scope
@ REPL[31]:1 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So I was actually using ConstructionBase locally for this before:) But I removed it because I figured this will only be used for a very simple subset of constructors, so uncertain if it's worth it. But I'll add it back again:) |
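The failure in the example above comes from calling the fully parameterised type `A{Int64, Float64}`, which has no constructor of its own because only the inner constructor on `A` is defined. A rough sketch of one way around it follows. `_unparameterized` is a hypothetical helper, and it relies on the internal `wrapper` field of `Base.typename(T)`, so treat it as a dependency-free illustration of the job that `ConstructionBase.constructorof` is meant for, not as the code the PR uses.

```julia
struct A{X,Y}
    x::X
    y::Y
    # Only an inner constructor on the unparameterised `A` is available.
    A(x::X, y::Y) where {X,Y} = new{X,Y}(x, y)
end

# Hypothetical helper: drop the type parameters to recover the `UnionAll`
# wrapper `A`, which is the thing that actually has a callable constructor.
# NOTE: `.wrapper` is an internal field of `Core.TypeName` (an assumption that
# holds on current Julia versions but is not a documented API).
_unparameterized(::Type{T}) where {T} = Base.typename(T).wrapper

a = A(1, 2.0)

# Rebuild an equivalent value from the concrete type of `a`.
b = _unparameterized(typeof(a))(1, 2.0)
```

The parameters are re-inferred from the arguments, which is fine for simple constructors like this one but would lose information for types whose parameters cannot be recovered from the field values.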

    return if has_specialized_make_closure(logpdf, constructor)
        f = make_closure(logpdf, constructor)
        sum(f.(x, dist.v.args...))
    else
        sum(copy(logpdf.(dist, x)))
    end
end

function Distributions.logpdf(
    dist::LazyVectorOfUnivariate,
    x::AbstractMatrix{<:Real},
)
    size(x, 1) == length(dist) ||
        throw(DimensionMismatch("Inconsistent array dimensions."))
    constructor = _inner_constructor(typeof(dist.v))
    return if has_specialized_make_closure(logpdf, constructor)
        f = make_closure(logpdf, constructor)
        vec(sum(f.(x, dist.v.args...), dims = 1))
    else
        vec(sum(copy(logpdf.(dist, x)); dims = 1))
    end
end

const LazyMatrixOfUnivariate{
    S<:ValueSupport,
    T<:UnivariateDistribution{S},
    Tdists<:BroadcastArray{T,2},
} = MatrixOfUnivariate{S,T,Tdists}

function Distributions._logpdf(
    dist::LazyMatrixOfUnivariate,
    x::AbstractMatrix{<:Real},
)

    constructor = _inner_constructor(typeof(dist.v))
    return if has_specialized_make_closure(logpdf, constructor)
        f = make_closure(logpdf, constructor)
        sum(f.(x, dist.v.args))
    else
        sum(copy(logpdf.(dist.dists, x)))
    end
end

lazyarray(f, x...) = BroadcastArray(f, x...)
export lazyarray
Comment on lines +45 to +47

It's not clear to me why this is needed. It doesn't seem much shorter and it makes it less clear that everything is based on LazyArrays.

Oh no, this was already in DistributionsAD.jl 🤷 Not something I put in here. I was also unaware of this method's existence.

Should we deprecate it?

Happy to!
This even has a bug in it: `dists` isn't defined.