Conversation

@maartenvd
Contributor

fixes #1301

@mcabbott
Member

mcabbott commented Sep 1, 2022

Glad existing tests pass with this.

Is there an easy way to test the new behaviour? Ideally hitting both paths (i.e. broadcasting a fairly simple function and a nasty one). Is there a simple package with an array type for which the old code didn't work & the new does?

@maartenvd
Contributor Author

I guess LinearAlgebra.Diagonal could serve as a test (it fails on master and works here)? I'm not sure what the cleanest way of hitting the slow fallback is, though.

@mcabbott
Member

mcabbott commented Sep 1, 2022

One way to be sure of hitting the slow path is to broadcast a closure, like

gradient((x,y) -> sum(broadcast((z -> (z*x)^2), y)), 1.0, [2,3,4.0])
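A minimal sketch of how such a test might look, assuming Zygote is installed. The expected values are derived by hand from the function itself: the derivative with respect to x is 2x·sum(y.^2), and with respect to each y[i] it is 2·y[i]·x^2. The variable names here are illustrative only.

```julia
using Zygote

# Broadcasting a closure that captures x; per the comment above,
# a closure is one way to be sure of hitting the slow path.
f(x, y) = sum(broadcast(z -> (z * x)^2, y))

x0 = 1.0
y0 = [2, 3, 4.0]
dx, dy = gradient(f, x0, y0)

# Hand-derived reference values for comparison
dx_expected = 2 * x0 * sum(abs2, y0)   # 2 * 1 * (4 + 9 + 16)
dy_expected = 2 .* y0 .* x0^2
```

Comparing with `≈` rather than `==` keeps the check robust to floating-point rounding.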

@maartenvd
Contributor Author

I'll probably return to this tomorrow, adding some simple tests

@mcabbott
Member

mcabbott commented Sep 3, 2022

I'm not sure Diagonal will work as a test: the one in ea21486 seems to pass on tagged Zygote.

In general, map and broadcast seem to agree for Diagonal:

julia> map(sqrt, Diagonal([2,3]))
2×2 Diagonal{Float64, Vector{Float64}}:
 1.41421   ⋅ 
  ⋅       1.73205

julia> broadcast(sqrt, Diagonal([2,3]))
2×2 Diagonal{Float64, Vector{Float64}}:
 1.41421   ⋅ 
  ⋅       1.73205

julia> map(cos, Diagonal([2,3]))
2×2 Matrix{Float64}:
 -0.416147   1.0
  1.0       -0.989992

julia> broadcast(cos, Diagonal([2,3]))
2×2 Matrix{Float64}:
 -0.416147   1.0
  1.0       -0.989992

I thought perhaps there were some disagreements for Adjoint, but now I don't see them. Something like [1,2]' .+ 3 makes a Matrix, but that's un-map-like. For types from packages, many overload map to preserve themselves:

julia> using NamedDims

julia> sqrt.(NamedDimsArray([1,2], :a))
2-element NamedDimsArray(::Vector{Float64}, :a):
↓ a  1.0
     1.4142135623730951

julia> map(sqrt, NamedDimsArray([1,2], :a))
2-element NamedDimsArray(::Vector{Float64}, :a):
↓ a  1.0
     1.4142135623730951

Is it possible that the issue here is just that the original array type lacks such an overload? Or maybe there's a reason it cannot?

@maartenvd
Contributor Author

I am also not entirely sure anymore. The documentation on implementing your own array types does not mention overloading map, but it does mention specifying the BroadcastStyle, which determines the output type of broadcasts (see https://docs.julialang.org/en/v1/manual/interfaces/).
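To illustrate the mismatch, here is a rough sketch with a hypothetical wrapper type (`Wrapped` is invented for this example and is not the array type from the issue). It implements only the broadcast customization described in the interfaces manual, and no `similar(a, T, dims)` method:

```julia
# Hypothetical wrapper type, used only for illustration.
struct Wrapped{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N}
    data::A
end

Base.size(w::Wrapped) = size(w.data)
Base.getindex(w::Wrapped{T,N}, I::Vararg{Int,N}) where {T,N} = w.data[I...]
Base.setindex!(w::Wrapped{T,N}, v, I::Vararg{Int,N}) where {T,N} = (w.data[I...] = v)

# Broadcast customization, following the pattern in the interfaces manual
Base.BroadcastStyle(::Type{<:Wrapped}) = Broadcast.ArrayStyle{Wrapped}()
Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{Wrapped}}, ::Type{T}) where {T} =
    Wrapped(similar(Array{T}, axes(bc)))

w = Wrapped([1.0, 4.0, 9.0])
b = sqrt.(w)       # allocates through the broadcast style, so the wrapper is kept
m = map(sqrt, w)   # allocates through similar(w, T, dims), which falls back to Array
```

With only these definitions, `b` comes back as a `Wrapped` while `m` is a plain `Vector`, so code that swaps one call for the other changes the output type.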

This map vs broadcast thing could cause other problems, but I feel that in general Julia really should make both broadcast and map use the same output type automatically. Using only Base types I could come up with this example:

julia> a = Set((1,2,3))
Set{Int64} with 3 elements:
  2
  3
  1

julia> broadcast(x->x+3,a)
3-element Vector{Int64}:
 5
 6
 4

julia> map(x->x+3,a)
ERROR: map is not defined on sets
Stacktrace:
 [1] error(s::String)
   @ Base .\error.jl:33
 [2] map(f::Function, #unused#::Set{Int64})
   @ Base .\abstractarray.jl:2326
 [3] top-level scope
   @ REPL[26]:1

@mcabbott
Member

mcabbott commented Sep 8, 2022

Note that many functions like sortslices(PeriodicArray(rand(3,3)); dims=1) do work, via similar. In fact I think map will work without being explicitly overloaded, if you define this method:

Base.similar(a::PeriodicArray, ::Type{T}, dims::Tuple{Vararg{Int64, N}}) where {T, N} = PeriodicArray(similar(a.data, T, dims))

I think the one with a splat calls this. The docs mention Base.Dims, which is Tuple{Vararg{Int64, N}} where N, but maybe this could be much more explicit? Possibly these methods are necessary too, maybe only if you want to wrap things with offset indexing?

 [80] similar(a::AbstractArray, ::Type{T}, dims::Tuple{Vararg{Int64, N}}) where {T, N}
     @ abstractarray.jl:827
 [81] similar(a::AbstractArray, ::Type{T}, dims::Tuple{Integer, Vararg{Integer}}) where T
     @ abstractarray.jl:825
 [82] similar(a::AbstractArray, ::Type{T}, dims::Tuple{Union{Integer, Base.OneTo}, Vararg{Union{Integer, Base.OneTo}}}) where T
     @ abstractarray.jl:824
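As a rough sketch of the `similar` suggestion, the following uses a toy stand-in type (`ToyPeriodicArray` is hypothetical and much simpler than any real periodic-array package; only the `data` field name is taken from the method above). It assumes that `map` on an `AbstractArray` allocates its output via `similar(a, T, axes)`, which is then forwarded to the `Dims` method:

```julia
# Toy stand-in for a periodic array wrapper; hypothetical, for illustration only.
struct ToyPeriodicArray{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N}
    data::A
end

Base.size(a::ToyPeriodicArray) = size(a.data)

# Wrap every index around periodically before touching the parent array
Base.getindex(a::ToyPeriodicArray{T,N}, I::Vararg{Int,N}) where {T,N} =
    a.data[map(mod1, I, size(a.data))...]
Base.setindex!(a::ToyPeriodicArray{T,N}, v, I::Vararg{Int,N}) where {T,N} =
    (a.data[map(mod1, I, size(a.data))...] = v)

# The method suggested above: allocate a similar parent array and re-wrap it
Base.similar(a::ToyPeriodicArray, ::Type{T}, dims::Dims) where {T} =
    ToyPeriodicArray(similar(a.data, T, dims))

p = ToyPeriodicArray([1.0 2.0; 3.0 4.0])
m = map(sqrt, p)   # no explicit map overload is defined for the wrapper
```

Here `m` keeps the wrapper type because the output array is created by the overloaded `similar`. Whether this is enough for a wrapper with offset axes is a separate question, as noted above.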

@maartenvd
Contributor Author

But wouldn't you rather change Zygote to call broadcast where it should? Consider the example of Set, where map errors while broadcast works just fine. I really don't like having to define methods outside the officially recommended interface just to work around something that Zygote arguably does wrong.

@mcabbott
Member

mcabbott commented Sep 9, 2022

I do think this is the documented interface, although I agree many implementations don't quite follow it (including some which I wrote). Making Zygote more forgiving is OK.

I can't see a way to push a Set into this bit of Zygote. Supposedly it does support Dict, for which neither broadcast nor map works. I think even Tuple doesn't reach this code, only arrays.

One reason people write map when there's one argument or the shapes match is that broadcasting is pretty complex. E.g. for tuples, map is a very simple unrolled thing, instead of a giant tower of functions. But I think there's no chance this actually matters for performance here.

@mcabbott mcabbott merged commit 4e10cea into FluxML:master Sep 19, 2022

Development

Successfully merging this pull request may close these issues.

zygote broadcast type stability

2 participants