@@ -40,6 +40,17 @@ Base.propertynames(p::ProjectTo) = propertynames(backing(p))
4040backing (project:: ProjectTo ) = getfield (project, :info )
4141
4242project_type (p:: ProjectTo{T} ) where {T} = T
43+ project_eltype (p:: ProjectTo{T} ) where {T} = eltype (T)
44+
45+ function project_promote_type (projectors)
46+ T = mapreduce (project_type, promote_type, projectors)
47+ if T <: Number
48+ # The point of this function is to make p.element for arrays. Not in use yet!
49+ return ProjectTo (zero (T))
50+ else
51+ return ProjectTo {Any} ()
52+ end
53+ end
4354
4455function Base. show (io:: IO , project:: ProjectTo{T} ) where {T}
4556 print (io, " ProjectTo{" )
178189# no structure worth re-imposing. Then any array is acceptable as a gradient.
179190
180191# For arrays of numbers, just store one projector:
181- function ProjectTo (x:: AbstractArray{T} ) where {T<: Number }
182- return ProjectTo {AbstractArray} (; element= _eltype_projectto (T), axes= axes (x))
192+ function ProjectTo (x:: AbstractArray{T,N} ) where {T<: Number ,N}
193+ element = _eltype_projectto (T)
194+ S = project_type (element) # new idea -- for any number, S is enough.
195+ # Store .element for now too, although it's redundant? Reconstruct from eltype?
196+ if axes (x) isa NTuple{N,Base. OneTo{Int}}
197+ return ProjectTo {AbstractArray{S,N}} (; element= element, axes= axes (x))
198+ else
199+ # Omitting N prohibits the fast path, and thus won't skip reshape for OffsetArrays, SArrays, etc.
200+ return ProjectTo {AbstractArray{S}} (; element= element, axes= axes (x))
201+ end
183202end
184203ProjectTo (x:: AbstractArray{Bool} ) = ProjectTo {NoTangent} ()
185204
@@ -197,7 +216,7 @@ function ProjectTo(xs::AbstractArray)
197216 end
198217end
199218
200- function (project:: ProjectTo{AbstractArray} )(dx:: AbstractArray{S,M} ) where {S,M}
219+ function (project:: ProjectTo{<: AbstractArray} )(dx:: AbstractArray{S,M} ) where {S,M}
201220 # First deal with shape. The rule is that we reshape to add or remove trivial dimensions
202221 # like dx = ones(4,1), where x = ones(4), but throw an error on dx = ones(1,4) etc.
203222 dy = if axes (dx) == project. axes
@@ -221,24 +240,33 @@ function (project::ProjectTo{AbstractArray})(dx::AbstractArray{S,M}) where {S,M}
221240 return dz
222241end
223242
243+ # Fast path, for arrays of numbers:
244+ # (::ProjectTo{AbstractArray{T,N}})(dx::AbstractArray{T,N}) where {T,N} = (@info "fast 1"; dx)
245+ (:: ProjectTo{AbstractArray{T,N}} )(dx:: AbstractArray{S,N} ) where {S<: T } where {T,N} = dx # (@info "fast 2"; dx)
246+ (project:: ProjectTo{AbstractArray{T,N}} )(dx:: AbstractArray{S,N} ) where {S,T,N} = (@info " fast 3" ; map (project. element, dx))
247+
224248# Trivial case, this won't collapse Any[NoTangent(), NoTangent()] but that's OK.
225- (project:: ProjectTo{AbstractArray} )(dx:: AbstractArray{<:AbstractZero} ) = NoTangent ()
249+ (project:: ProjectTo{AbstractArray{T,N}} )(dx:: AbstractArray{<:AbstractZero} ) where {T,N} = NoTangent ()
226250
227251# Row vectors aren't acceptable as gradients for 1-row matrices:
228- function (project:: ProjectTo{AbstractArray} )(dx:: LinearAlgebra.AdjOrTransAbsVec )
252+ # function (project::ProjectTo{<:AbstractArray})(dx::LinearAlgebra.AdjOrTransAbsVec)
253+ # return project(reshape(vec(dx), 1, :))
254+ # end
255+ function (project:: ProjectTo{AbstractArray{T,N}} )(dx:: LinearAlgebra.AdjOrTransAbsVec ) where {T,N}
229256 return project (reshape (vec (dx), 1 , :))
230257end
231258
232259# Zero-dimensional arrays -- these have a habit of going missing,
233260# although really Ref() is probably a better structure.
234- function (project:: ProjectTo{AbstractArray} )(dx:: Number ) # ... so we restore from numbers
235- if ! (project. axes isa Tuple{})
236- throw (DimensionMismatch (
237- " array with ndims(x) == $(length (project. axes)) > 0 cannot have dx::Number" ,
238- ))
239- end
240- return fill (project. element (dx))
241- end
261+ # function (project::ProjectTo{<:AbstractArray})(dx::Number) # ... so we restore from numbers
262+ # if !(project.axes isa Tuple{})
263+ # throw(DimensionMismatch(
264+ # "array with ndims(x) == $(length(project.axes)) > 0 cannot have dx::Number",
265+ # ))
266+ # end
267+ # return fill(project.element(dx))
268+ # end
269+ (project:: ProjectTo{AbstractArray{<:Number,0}} )(dx:: Number ) = fill (project. element (dx))
242270
243271# Ref -- works like a zero-array, also allows restoration from a number:
244272ProjectTo (x:: Ref ) = ProjectTo {Ref} (; x= ProjectTo (x[]))
0 commit comments