diff --git a/docs/src/data/onehot.md b/docs/src/data/onehot.md index bb97e437cd..3aff242bee 100644 --- a/docs/src/data/onehot.md +++ b/docs/src/data/onehot.md @@ -6,13 +6,13 @@ It's common to encode categorical variables (like `true`, `false` or `cat`, `dog julia> using Flux: onehot, onecold julia> onehot(:b, [:a, :b, :c]) -3-element Flux.OneHotVector: +3-element Flux.OneHotArray{UInt32,3,0,1,UInt32}: 0 1 0 julia> onehot(:c, [:a, :b, :c]) -3-element Flux.OneHotVector: +3-element Flux.OneHotArray{UInt32,3,0,1,UInt32}: 0 0 1 @@ -44,7 +44,7 @@ Flux.onecold julia> using Flux: onehotbatch julia> onehotbatch([:b, :a, :b], [:a, :b, :c]) -3×3 Flux.OneHotMatrix{Array{Flux.OneHotVector,1}}: +3×3 Flux.OneHotArray{UInt32,3,1,2,Array{UInt32,1}}: 0 1 0 1 0 1 0 0 0 diff --git a/src/onehot.jl b/src/onehot.jl index b295f05afa..c9f0b145a0 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -1,52 +1,61 @@ -import Base: * +import Adapt +import .CUDA -struct OneHotVector <: AbstractVector{Bool} - ix::UInt32 - of::UInt32 +struct OneHotArray{T<:Integer, L, N, var"N+1", I<:Union{T, AbstractArray{T, N}}} <: AbstractArray{Bool, var"N+1"} + indices::I end +OneHotArray{T, L, N, I}(indices) where {T, L, N, I} = OneHotArray{T, L, N, N+1, I}(indices) +OneHotArray(indices::T, L::Integer) where {T<:Integer} = OneHotArray{T, L, 0, T}(indices) +OneHotArray(indices::AbstractArray{T, N}, L::Integer) where {T, N} = OneHotArray{T, L, N, typeof(indices)}(indices) -Base.size(xs::OneHotVector) = (Int64(xs.of),) +_indices(x::OneHotArray) = x.indices -Base.getindex(xs::OneHotVector, i::Integer) = i == xs.ix +const OneHotVector{T, L} = OneHotArray{T, L, 0, 1, T} +const OneHotMatrix{T, L, I} = OneHotArray{T, L, 1, 2, I} -Base.getindex(xs::OneHotVector, ::Colon) = OneHotVector(xs.ix, xs.of) +OneHotVector(idx, L) = OneHotArray(idx, L) +OneHotMatrix(indices, L) = OneHotArray(indices, L) -function Base.:*(A::AbstractMatrix, b::OneHotVector) - if size(A, 2) != b.of - throw(DimensionMismatch("Matrix column must correspond with OneHotVector size")) - end - return A[:, b.ix] -end +Base.size(x::OneHotArray{<:Any, L}) where L = (Int(L), size(x.indices)...) -struct OneHotMatrix{A<:AbstractVector{OneHotVector}} <: AbstractMatrix{Bool} - height::Int - data::A -end +_onehotindex(x, i) = (x == i) -Base.size(xs::OneHotMatrix) = (Int64(xs.height),length(xs.data)) +Base.getindex(x::OneHotVector, i::Integer) = _onehotindex(x.indices, i) +Base.getindex(x::OneHotVector{T, L}, ::Colon) where {T, L} = x -Base.getindex(xs::OneHotMatrix, i::Union{Integer, AbstractVector}, j::Integer) = xs.data[j][i] -Base.getindex(xs::OneHotMatrix, ::Colon, i::Integer) = xs.data[i] -Base.getindex(xs::OneHotMatrix, ::Colon, i::AbstractArray) = OneHotMatrix(xs.height, xs.data[i]) -Base.getindex(xs::OneHotMatrix, ::Colon, ::Colon) = OneHotMatrix(xs.height, copy(xs.data)) +Base.getindex(x::OneHotArray, i::Integer, I...) = _onehotindex.(x.indices[I...], i) +Base.getindex(x::OneHotArray{<:Any, L}, ::Colon, I...) where L = OneHotArray(x.indices[I...], L) +Base.getindex(x::OneHotArray{<:Any, <:Any, <:Any, N}, ::Vararg{Colon, N}) where N = x +Base.getindex(x::OneHotArray, I::CartesianIndex{N}) where N = x[I[1], Tuple(I)[2:N]...] -Base.getindex(xs::OneHotMatrix, i::Integer, ::Colon) = map(x -> x[i], xs.data) +_onehot_bool_type(x::OneHotArray{<:Any, <:Any, <:Any, N, <:Union{Integer, AbstractArray}}) where N = Array{Bool, N} +_onehot_bool_type(x::OneHotArray{<:Any, <:Any, <:Any, N, <:CuArray}) where N = CuArray{Bool, N} + +function Base.cat(xs::OneHotArray{<:Any, L}...; dims::Int) where L + if isone(dims) + return throw(ArgumentError("Cannot concat OneHotArray along first dimension. Use collect to convert to Bool array first.")) + else + return OneHotArray(cat(_indices.(xs)...; dims = dims - 1), L) + end +end -# remove workaround when https://github.com/JuliaGPU/CuArrays.jl/issues/676 is fixed -A::AbstractMatrix * B::OneHotMatrix = A[:, cpu(map(x->x.ix, B.data))] +Base.hcat(xs::OneHotArray...) = cat(xs...; dims = 2) +Base.vcat(xs::OneHotArray...) = cat(xs...; dims = 1) -Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix(length(x), [x, xs...]) +Base.reshape(x::OneHotArray{<:Any, L}, dims::Dims) where L = + (first(dims) == L) ? OneHotArray(reshape(x.indices, dims[2:end]...), L) : + throw(ArgumentError("Cannot reshape OneHotArray if first(dims) != size(x, 1)")) +Base._reshape(x::OneHotArray, dims::Tuple{Vararg{Int}}) = reshape(x, dims) -batch(xs::AbstractArray{<:OneHotVector}) = OneHotMatrix(length(first(xs)), xs) +batch(xs::AbstractArray{<:OneHotVector{<:Any, L}}) where L = OneHotArray(_indices.(xs), L) -import Adapt: adapt, adapt_structure +Adapt.adapt_structure(T, x::OneHotArray{<:Any, L}) where L = OneHotArray(adapt(T, x.indices), L) -adapt_structure(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data)) +Base.BroadcastStyle(::Type{<:OneHotArray{<:Any, <:Any, <:Any, N, <:CuArray}}) where N = CUDA.CuArrayStyle{N}() -import .CUDA: CuArray, CuArrayStyle, cudaconvert -import Base.Broadcast: BroadcastStyle, ArrayStyle -BroadcastStyle(::Type{<:OneHotMatrix{<:CuArray}}) = CuArrayStyle{2}() -cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.data)) +Base.argmax(x::OneHotArray; dims = Colon()) = + (dims == 1) ? reshape(CartesianIndex.(x.indices, CartesianIndices(x.indices)), 1, size(x.indices)...) : + argmax(convert(_onehot_bool_type(x), x); dims = dims) """ onehot(l, labels[, unk]) @@ -60,13 +69,13 @@ If `l` is not found in labels and `unk` is present, the function returns # Examples ```jldoctest julia> Flux.onehot(:b, [:a, :b, :c]) -3-element Flux.OneHotVector: +3-element Flux.OneHotArray{UInt32,3,0,1,UInt32}: 0 1 0 julia> Flux.onehot(:c, [:a, :b, :c]) -3-element Flux.OneHotVector: +3-element Flux.OneHotArray{UInt32,3,0,1,UInt32}: 0 0 1 @@ -75,13 +84,13 @@ julia> Flux.onehot(:c, [:a, :b, :c]) function onehot(l, labels) i = something(findfirst(isequal(l), labels), 0) i > 0 || error("Value $l is not in labels") - OneHotVector(i, length(labels)) + OneHotVector{UInt32, length(labels)}(i) end function onehot(l, labels, unk) i = something(findfirst(isequal(l), labels), 0) i > 0 || return onehot(unk, labels) - OneHotVector(i, length(labels)) + OneHotVector{UInt32, length(labels)}(i) end """ @@ -95,16 +104,13 @@ return [`onehot(unk, labels)`](@ref) ; otherwise the function will raise an erro # Examples ```jldoctest julia> Flux.onehotbatch([:b, :a, :b], [:a, :b, :c]) -3×3 Flux.OneHotMatrix{Array{Flux.OneHotVector,1}}: +3×3 Flux.OneHotArray{UInt32,3,1,2,Array{UInt32,1}}: 0 1 0 1 0 1 0 0 0 ``` """ -onehotbatch(ls, labels, unk...) = - OneHotMatrix(length(labels), [onehot(l, labels, unk...) for l in ls]) - -Base.argmax(xs::OneHotVector) = xs.ix +onehotbatch(ls, labels, unk...) = batch([onehot(l, labels, unk...) for l in ls]) """ onecold(y[, labels = 1:length(y)]) @@ -120,11 +126,20 @@ julia> Flux.onecold([0.3, 0.2, 0.5], [:a, :b, :c]) :c ``` """ -onecold(y::AbstractVector, labels = 1:length(y)) = labels[Base.argmax(y)] +onecold(y::AbstractVector, labels = 1:length(y)) = labels[argmax(y)] +function onecold(y::AbstractArray, labels = 1:size(y, 1)) + indices = _fast_argmax(y) + xs = isbits(labels) ? indices : collect(indices) # non-bit type cannot be handled by CUDA -onecold(y::AbstractMatrix, labels...) = - dropdims(mapslices(y -> onecold(y, labels...), y, dims=1), dims=1) + return map(xi -> labels[xi[1]], xs) +end -onecold(y::OneHotMatrix, labels...) = map(x -> Flux.onecold(x, labels...), y.data) +_fast_argmax(x::AbstractArray) = dropdims(argmax(x; dims = 1); dims = 1) +_fast_argmax(x::OneHotArray) = x.indices -@nograd onecold, onehot, onehotbatch +@nograd OneHotArray, onecold, onehot, onehotbatch + +function Base.:(*)(A::AbstractMatrix, B::OneHotArray{<:Any, L}) where L + size(A, 2) == L || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $L")) + return A[:, onecold(B)] +end diff --git a/test/cuda/cuda.jl b/test/cuda/cuda.jl index 5aaec710b6..92c04404f8 100644 --- a/test/cuda/cuda.jl +++ b/test/cuda/cuda.jl @@ -13,7 +13,7 @@ using LinearAlgebra: I, cholesky, Cholesky x = Flux.onehotbatch([1, 2, 3], 1:3) cx = gpu(x) - @test cx isa Flux.OneHotMatrix && cx.data isa CuArray + @test cx isa Flux.OneHotMatrix && cx.indices isa CuArray @test (cx .+ 1) isa CuArray m = Chain(Dense(10, 5, tanh), Dense(5, 2), softmax) @@ -40,8 +40,10 @@ end @testset "onecold gpu" begin y = Flux.onehotbatch(ones(3), 1:10) |> gpu; + l = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j'] @test Flux.onecold(y) isa CuArray @test y[3,:] isa CuArray + @test Flux.onecold(y, l) == ['a', 'a', 'a'] end @testset "restructure gpu" begin diff --git a/test/onehot.jl b/test/onehot.jl index 591e899920..9461bc816d 100644 --- a/test/onehot.jl +++ b/test/onehot.jl @@ -27,9 +27,83 @@ end @testset "abstractmatrix onehotvector multiplication" begin A = [1 3 5; 2 4 6; 3 6 9] - b1 = Flux.OneHotVector(1,3) - b2 = Flux.OneHotVector(3,5) + b1 = Flux.OneHotVector(1, 3) + b2 = Flux.OneHotVector(3, 5) @test A*b1 == A[:,1] @test_throws DimensionMismatch A*b2 +end + +@testset "OneHotArray" begin + using Flux: OneHotArray, OneHotVector, OneHotMatrix + + ov = OneHotVector(rand(1:10), 10) + om = OneHotMatrix(rand(1:10, 5), 10) + oa = OneHotArray(rand(1:10, 5, 5), 10) + + # sizes + @testset "Base.size" begin + @test size(ov) == (10,) + @test size(om) == (10, 5) + @test size(oa) == (10, 5, 5) + end + + @testset "Indexing" begin + # vector indexing + @test ov[3] == (ov.indices == 3) + @test ov[:] == ov + + # matrix indexing + @test om[3, 3] == (om.indices[3] == 3) + @test om[:, 3] == OneHotVector(om.indices[3], 10) + @test om[3, :] == (om.indices .== 3) + @test om[:, :] == om + + # array indexing + @test oa[3, 3, 3] == (oa.indices[3, 3] == 3) + @test oa[:, 3, 3] == OneHotVector(oa.indices[3, 3], 10) + @test oa[3, :, 3] == (oa.indices[:, 3] .== 3) + @test oa[3, :, :] == (oa.indices .== 3) + @test oa[:, 3, :] == OneHotMatrix(oa.indices[3, :], 10) + @test oa[:, :, :] == oa + + # cartesian indexing + @test oa[CartesianIndex(3, 3, 3)] == oa[3, 3, 3] + end + + @testset "Concatenating" begin + # vector cat + @test hcat(ov, ov) == OneHotMatrix(vcat(ov.indices, ov.indices), 10) + @test_throws ArgumentError vcat(ov, ov) + @test cat(ov, ov; dims = 3) == OneHotArray(cat(ov.indices, ov.indices; dims = 2), 10) + + # matrix cat + @test hcat(om, om) == OneHotMatrix(vcat(om.indices, om.indices), 10) + @test_throws ArgumentError vcat(om, om) + @test cat(om, om; dims = 3) == OneHotArray(cat(om.indices, om.indices; dims = 2), 10) + + # array cat + @test cat(oa, oa; dims = 3) == OneHotArray(cat(oa.indices, oa.indices; dims = 2), 10) + @test_throws ArgumentError cat(oa, oa; dims = 1) + end + + @testset "Base.reshape" begin + # reshape test + @test reshape(oa, 10, 25) isa OneHotArray + @test reshape(oa, 10, :) isa OneHotArray + @test reshape(oa, :, 25) isa OneHotArray + @test_throws ArgumentError reshape(oa, 50, :) + @test_throws ArgumentError reshape(oa, 5, 10, 5) + @test reshape(oa, (10, 25)) isa OneHotArray + end + + @testset "Base.argmax" begin + # argmax test + @test argmax(ov) == argmax(convert(Array{Bool}, ov)) + @test argmax(om) == argmax(convert(Array{Bool}, om)) + @test argmax(om; dims = 1) == argmax(convert(Array{Bool}, om); dims = 1) + @test argmax(om; dims = 2) == argmax(convert(Array{Bool}, om); dims = 2) + @test argmax(oa; dims = 1) == argmax(convert(Array{Bool}, oa); dims = 1) + @test argmax(oa; dims = 3) == argmax(convert(Array{Bool}, oa); dims = 3) + end end \ No newline at end of file