-
-
Couldn't load subscription status.
- Fork 615
Arbitrary dimension one-hot arrays #1448
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
d4eb7d0
47ab622
a98727f
4355ff1
630e5f0
36183f0
49befe0
2bad6b9
3b89d85
5a60580
28ece66
ef1ecb3
aa63a5a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,52 +1,57 @@ | ||
| import Base: * | ||
| import Adapt | ||
| import .CUDA | ||
|
|
||
| struct OneHotVector <: AbstractVector{Bool} | ||
| ix::UInt32 | ||
| of::UInt32 | ||
| const OneHotIndex{T, N} = Union{T, AbstractArray{T, N}} | ||
|
|
||
| struct OneHotArray{T<:Integer, L, N, var"N+1", I<:OneHotIndex{T, N}} <: AbstractArray{Bool, var"N+1"} | ||
darsnack marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| indices::I | ||
| end | ||
| OneHotArray{T, L, N, I}(indices) where {T, L, N, I} = OneHotArray{T, L, N, N+1, I}(indices) | ||
| OneHotArray(L::Integer, indices::T) where {T<:Integer} = OneHotArray{T, L, 0, T}(indices) | ||
| OneHotArray(L::Integer, indices::AbstractArray{T, N}) where {T, N} = OneHotArray{T, L, N, typeof(indices)}(indices) | ||
|
|
||
| Base.size(xs::OneHotVector) = (Int64(xs.of),) | ||
| _indices(x::OneHotArray) = x.indices | ||
|
|
||
| Base.getindex(xs::OneHotVector, i::Integer) = i == xs.ix | ||
| const OneHotVector{T, L} = OneHotArray{T, L, 0, 1, T} | ||
| const OneHotMatrix{T, L, I} = OneHotArray{T, L, 1, 2, I} | ||
|
|
||
| Base.getindex(xs::OneHotVector, ::Colon) = OneHotVector(xs.ix, xs.of) | ||
| Base.size(x::OneHotArray{<:Any, L}) where L = (Int(L), size(x.indices)...) | ||
|
|
||
| function Base.:*(A::AbstractMatrix, b::OneHotVector) | ||
| if size(A, 2) != b.of | ||
| throw(DimensionMismatch("Matrix column must correspond with OneHotVector size")) | ||
| end | ||
| return A[:, b.ix] | ||
| end | ||
| _onehotindex(x, i) = (x == i) | ||
|
|
||
| struct OneHotMatrix{A<:AbstractVector{OneHotVector}} <: AbstractMatrix{Bool} | ||
| height::Int | ||
| data::A | ||
| end | ||
| Base.getindex(x::OneHotVector, i::Integer) = _onehotindex(x.indices, i) | ||
| Base.getindex(x::OneHotVector{T, L}, ::Colon) where {T, L} = OneHotVector{T, L}(x.indices) | ||
darsnack marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| Base.size(xs::OneHotMatrix) = (Int64(xs.height),length(xs.data)) | ||
| Base.getindex(x::OneHotArray, i::Integer, I...) = _onehotindex.(x.indices[I...], i) | ||
| Base.getindex(x::OneHotArray{<:Any, L}, ::Colon, I...) where L = OneHotArray(L, x.indices[I...]) | ||
| Base.getindex(x::OneHotArray, I::CartesianIndex{N}) where N = x[I[1], Tuple(I)[2:N]...] | ||
|
|
||
| Base.getindex(xs::OneHotMatrix, i::Union{Integer, AbstractVector}, j::Integer) = xs.data[j][i] | ||
| Base.getindex(xs::OneHotMatrix, ::Colon, i::Integer) = xs.data[i] | ||
| Base.getindex(xs::OneHotMatrix, ::Colon, i::AbstractArray) = OneHotMatrix(xs.height, xs.data[i]) | ||
| Base.getindex(xs::OneHotMatrix, ::Colon, ::Colon) = OneHotMatrix(xs.height, copy(xs.data)) | ||
| _onehot_bool_type(x::OneHotArray{<:Any, <:Any, <:Any, N, <:OneHotIndex}) where N = Array{Bool, N} | ||
| _onehot_bool_type(x::OneHotArray{<:Any, <:Any, <:Any, N, <:CuArray}) where N = CuArray{Bool, N} | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We might want to simply return the type of the underlying array iiic There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The underlying array is an integer array. This is just an internal convenience function I use when I want to convert the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. But why do we need to have this? The implementation for CuArray should fall straight out of assuming regular array |
||
|
|
||
| Base.getindex(xs::OneHotMatrix, i::Integer, ::Colon) = map(x -> x[i], xs.data) | ||
| function Base.cat(xs::OneHotArray{<:Any, L}...; dims::Int) where L | ||
darsnack marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| if isone(dims) | ||
| return cat(map(x -> convert(_onehot_bool_type(x), x), xs)...; dims = 1) | ||
| else | ||
| return OneHotArray(L, cat(_indices.(xs)...; dims = dims - 1)) | ||
| end | ||
| end | ||
|
|
||
| # remove workaround when https://github.com/JuliaGPU/CuArrays.jl/issues/676 is fixed | ||
| A::AbstractMatrix * B::OneHotMatrix = A[:, cpu(map(x->x.ix, B.data))] | ||
| Base.hcat(xs::OneHotArray...) = cat(xs...; dims = 2) | ||
| Base.vcat(xs::OneHotArray...) = cat(xs...; dims = 1) | ||
|
|
||
| Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix(length(x), [x, xs...]) | ||
| Base.reshape(x::OneHotArray{<:Any, L}, dims...) where L = | ||
darsnack marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| (first(dims) == L) ? OneHotArray(L, reshape(x.indices, dims[2:end]...)) : reshape(x, dims...) | ||
|
|
||
| batch(xs::AbstractArray{<:OneHotVector}) = OneHotMatrix(length(first(xs)), xs) | ||
| batch(xs::AbstractArray{<:OneHotVector{<:Any, L}}) where L = OneHotArray(L, _indices.(xs)) | ||
|
|
||
| import Adapt: adapt, adapt_structure | ||
| Adapt.adapt_structure(T, x::OneHotArray{<:Any, L}) where L = OneHotArray(L, adapt(T, x.indices)) | ||
|
|
||
| adapt_structure(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data)) | ||
| Base.BroadcastStyle(::Type{<:OneHotArray{<:Any, <:Any, <:Any, N, <:CuArray}}) where N = CuArrayStyle{N}() | ||
darsnack marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| import .CUDA: CuArray, CuArrayStyle, cudaconvert | ||
| import Base.Broadcast: BroadcastStyle, ArrayStyle | ||
| BroadcastStyle(::Type{<:OneHotMatrix{<:CuArray}}) = CuArrayStyle{2}() | ||
| cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.data)) | ||
| Base.argmax(x::OneHotArray; dims = Colon()) = | ||
| (dims == 1) ? reshape(CartesianIndex.(x.indices, CartesianIndices(x.indices)), 1, size(x.indices)...) : | ||
| argmax(convert(_onehot_bool_type(x), x); dims = dims) | ||
|
|
||
| """ | ||
| onehot(l, labels[, unk]) | ||
|
|
@@ -60,13 +65,13 @@ If `l` is not found in labels and `unk` is present, the function returns | |
| # Examples | ||
| ```jldoctest | ||
| julia> Flux.onehot(:b, [:a, :b, :c]) | ||
| 3-element Flux.OneHotVector: | ||
| 3-element Flux.OneHotArray{UInt32,3,0,1,UInt32}: | ||
| 0 | ||
| 1 | ||
| 0 | ||
|
|
||
| julia> Flux.onehot(:c, [:a, :b, :c]) | ||
| 3-element Flux.OneHotVector: | ||
| 3-element Flux.OneHotArray{UInt32,3,0,1,UInt32}: | ||
| 0 | ||
| 0 | ||
| 1 | ||
|
|
@@ -75,13 +80,13 @@ julia> Flux.onehot(:c, [:a, :b, :c]) | |
| function onehot(l, labels) | ||
| i = something(findfirst(isequal(l), labels), 0) | ||
| i > 0 || error("Value $l is not in labels") | ||
| OneHotVector(i, length(labels)) | ||
| OneHotVector{UInt32, length(labels)}(i) | ||
| end | ||
|
|
||
| function onehot(l, labels, unk) | ||
| i = something(findfirst(isequal(l), labels), 0) | ||
| i > 0 || return onehot(unk, labels) | ||
| OneHotVector(i, length(labels)) | ||
| OneHotVector{UInt32, length(labels)}(i) | ||
| end | ||
|
|
||
| """ | ||
|
|
@@ -95,16 +100,13 @@ return [`onehot(unk, labels)`](@ref) ; otherwise the function will raise an erro | |
| # Examples | ||
| ```jldoctest | ||
| julia> Flux.onehotbatch([:b, :a, :b], [:a, :b, :c]) | ||
| 3×3 Flux.OneHotMatrix{Array{Flux.OneHotVector,1}}: | ||
| 3×3 Flux.OneHotArray{UInt32,3,1,2,Array{UInt32,1}}: | ||
| 0 1 0 | ||
| 1 0 1 | ||
| 0 0 0 | ||
| ``` | ||
| """ | ||
| onehotbatch(ls, labels, unk...) = | ||
| OneHotMatrix(length(labels), [onehot(l, labels, unk...) for l in ls]) | ||
|
|
||
| Base.argmax(xs::OneHotVector) = xs.ix | ||
| onehotbatch(ls, labels, unk...) = batch([onehot(l, labels, unk...) for l in ls]) | ||
|
|
||
| """ | ||
| onecold(y[, labels = 1:length(y)]) | ||
|
|
@@ -120,11 +122,20 @@ julia> Flux.onecold([0.3, 0.2, 0.5], [:a, :b, :c]) | |
| :c | ||
| ``` | ||
| """ | ||
| onecold(y::AbstractVector, labels = 1:length(y)) = labels[Base.argmax(y)] | ||
| onecold(y::AbstractVector, labels = 1:length(y)) = labels[argmax(y)] | ||
| function onecold(y::AbstractArray, labels = 1:size(y, 1)) | ||
| indices = _fast_argmax(y) | ||
| xs = isbits(labels) ? indices : collect(indices) # non-bit type cannot be handled by CUDA | ||
darsnack marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| onecold(y::AbstractMatrix, labels...) = | ||
| dropdims(mapslices(y -> onecold(y, labels...), y, dims=1), dims=1) | ||
| return map(xi -> labels[xi[1]], xs) | ||
| end | ||
|
|
||
| onecold(y::OneHotMatrix, labels...) = map(x -> Flux.onecold(x, labels...), y.data) | ||
| _fast_argmax(x::AbstractArray) = dropdims(argmax(x; dims = 1); dims = 1) | ||
| _fast_argmax(x::OneHotArray) = convert(AbstractArray, x.indices) | ||
darsnack marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| @nograd onecold, onehot, onehotbatch | ||
| @nograd OneHotArray, onecold, onehot, onehotbatch | ||
|
|
||
| function Base.:(*)(A::AbstractMatrix, B::OneHotArray{<:Any, L}) where L | ||
| size(A, 2) == L || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $L")) | ||
| return A[:, onecold(B)] | ||
| end | ||
Uh oh!
There was an error while loading. Please reload this page.