From d4eb7d02595006667157f6020b08a51d1ed1dd7d Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Fri, 1 Jan 2021 13:20:47 -0600 Subject: [PATCH 01/13] Initial arbitrary one hot implementation passing tests. Need to benchmark on machine w/ GPU. --- src/onehot.jl | 102 +++++++++++++++++++++++++--------------------- test/cuda/cuda.jl | 2 +- test/onehot.jl | 4 +- 3 files changed, 58 insertions(+), 50 deletions(-) diff --git a/src/onehot.jl b/src/onehot.jl index b295f05afa..9203c11018 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -1,52 +1,57 @@ -import Base: * +import Adapt +import .CUDA -struct OneHotVector <: AbstractVector{Bool} - ix::UInt32 - of::UInt32 +const OneHotIndex{T, N} = Union{T, AbstractArray{T, N}} + +struct OneHotArray{T<:Integer, L, N, var"N+1", I<:OneHotIndex{T, N}} <: AbstractArray{Bool, var"N+1"} + indices::I end +OneHotArray{T, L, N, I}(indices) where {T, L, N, I} = OneHotArray{T, L, N, N+1, I}(indices) +OneHotArray(L::Integer, indices::T) where {T<:Integer} = OneHotArray{T, L, 0, T}(indices) +OneHotArray(L::Integer, indices::AbstractArray{T, N}) where {T, N} = OneHotArray{T, L, N, typeof(indices)}(indices) -Base.size(xs::OneHotVector) = (Int64(xs.of),) +_indices(x::OneHotArray) = x.indices -Base.getindex(xs::OneHotVector, i::Integer) = i == xs.ix +const OneHotVector{T, L} = OneHotArray{T, L, 0, 1, T} +const OneHotMatrix{T, L, I} = OneHotArray{T, L, 1, 2, I} -Base.getindex(xs::OneHotVector, ::Colon) = OneHotVector(xs.ix, xs.of) +Base.size(x::OneHotArray{<:Any, L}) where L = (Int(L), size(x.indices)...) -function Base.:*(A::AbstractMatrix, b::OneHotVector) - if size(A, 2) != b.of - throw(DimensionMismatch("Matrix column must correspond with OneHotVector size")) - end - return A[:, b.ix] -end +_onehotindex(x, i) = (x == i) -struct OneHotMatrix{A<:AbstractVector{OneHotVector}} <: AbstractMatrix{Bool} - height::Int - data::A -end +Base.getindex(x::OneHotVector, i::Integer) = _onehotindex(x.indices, i) +Base.getindex(x::OneHotVector{T, L}, ::Colon) where {T, L} = OneHotVector{T, L}(x.indices) -Base.size(xs::OneHotMatrix) = (Int64(xs.height),length(xs.data)) +Base.getindex(x::OneHotArray, i::Integer, I...) = _onehotindex.(x.indices[I...], i) +Base.getindex(x::OneHotArray{<:Any, L}, ::Colon, I...) where L = OneHotArray(L, x.indices[I...]) +Base.getindex(x::OneHotArray, I::CartesianIndex{N}) where N = x[I[1], Tuple(I)[2:N]...] -Base.getindex(xs::OneHotMatrix, i::Union{Integer, AbstractVector}, j::Integer) = xs.data[j][i] -Base.getindex(xs::OneHotMatrix, ::Colon, i::Integer) = xs.data[i] -Base.getindex(xs::OneHotMatrix, ::Colon, i::AbstractArray) = OneHotMatrix(xs.height, xs.data[i]) -Base.getindex(xs::OneHotMatrix, ::Colon, ::Colon) = OneHotMatrix(xs.height, copy(xs.data)) +_onehot_bool_type(x::OneHotArray{<:Any, <:Any, <:Any, N, <:OneHotIndex}) where N = Array{Bool, N} +_onehot_bool_type(x::OneHotArray{<:Any, <:Any, <:Any, N, <:CuArray}) where N = CuArray{Bool, N} -Base.getindex(xs::OneHotMatrix, i::Integer, ::Colon) = map(x -> x[i], xs.data) +function Base.cat(xs::OneHotArray{<:Any, L}...; dims::Int) where L + if isone(dims) + return cat(map(x -> convert(_onehot_bool_type(x), x), xs)...; dims = 1) + else + return OneHotArray(L, cat(_indices.(xs)...; dims = dims = dims - 1)) + end +end -# remove workaround when https://github.com/JuliaGPU/CuArrays.jl/issues/676 is fixed -A::AbstractMatrix * B::OneHotMatrix = A[:, cpu(map(x->x.ix, B.data))] +Base.hcat(xs::OneHotArray...) = cat(xs...; dims = 2) +Base.vcat(xs::OneHotArray...) = cat(xs...; dims = 1) -Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix(length(x), [x, xs...]) +Base.reshape(x::OneHotArray{<:Any, L}, dims...) where L = + (first(dims) == L) ? OneHotArray(L, reshape(x.indices, dims[2:end]...)) : reshape(x, dims...) -batch(xs::AbstractArray{<:OneHotVector}) = OneHotMatrix(length(first(xs)), xs) +batch(xs::AbstractArray{<:OneHotVector{<:Any, L}}) where L = OneHotArray(L, _indices.(xs)) -import Adapt: adapt, adapt_structure +Adapt.adapt_structure(T, x::OneHotArray{<:Any, L}) where L = OneHotArray(L, adapt(T, x.indices)) -adapt_structure(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data)) +Base.BroadcastStyle(::Type{<:OneHotArray{<:Any, <:Any, <:Any, N, <:CuArray}}) where N = CuArrayStyle{N}() -import .CUDA: CuArray, CuArrayStyle, cudaconvert -import Base.Broadcast: BroadcastStyle, ArrayStyle -BroadcastStyle(::Type{<:OneHotMatrix{<:CuArray}}) = CuArrayStyle{2}() -cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.data)) +Base.argmax(x::OneHotArray; dims = Colon()) = + (dims == 1) ? reshape(CartesianIndex.(x.indices, CartesianIndices(x.indices)), 1, size(x.indices)...) : + argmax(convert(_onehot_bool_type(x), x); dims = dims) """ onehot(l, labels[, unk]) @@ -60,13 +65,13 @@ If `l` is not found in labels and `unk` is present, the function returns # Examples ```jldoctest julia> Flux.onehot(:b, [:a, :b, :c]) -3-element Flux.OneHotVector: +3-element Flux.OneHotArray{UInt32,3,0,1,UInt32}: 0 1 0 julia> Flux.onehot(:c, [:a, :b, :c]) -3-element Flux.OneHotVector: +3-element Flux.OneHotArray{UInt32,3,0,1,UInt32}: 0 0 1 @@ -75,13 +80,13 @@ julia> Flux.onehot(:c, [:a, :b, :c]) function onehot(l, labels) i = something(findfirst(isequal(l), labels), 0) i > 0 || error("Value $l is not in labels") - OneHotVector(i, length(labels)) + OneHotVector{UInt32, length(labels)}(i) end function onehot(l, labels, unk) i = something(findfirst(isequal(l), labels), 0) i > 0 || return onehot(unk, labels) - OneHotVector(i, length(labels)) + OneHotVector{UInt32, length(labels)}(i) end """ @@ -95,16 +100,13 @@ return [`onehot(unk, labels)`](@ref) ; otherwise the function will raise an erro # Examples ```jldoctest julia> Flux.onehotbatch([:b, :a, :b], [:a, :b, :c]) -3×3 Flux.OneHotMatrix{Array{Flux.OneHotVector,1}}: +3×3 Flux.OneHotArray{UInt32,3,1,2,Array{UInt32,1}}: 0 1 0 1 0 1 0 0 0 ``` """ -onehotbatch(ls, labels, unk...) = - OneHotMatrix(length(labels), [onehot(l, labels, unk...) for l in ls]) - -Base.argmax(xs::OneHotVector) = xs.ix +onehotbatch(ls, labels, unk...) = batch([onehot(l, labels, unk...) for l in ls]) """ onecold(y[, labels = 1:length(y)]) @@ -120,11 +122,17 @@ julia> Flux.onecold([0.3, 0.2, 0.5], [:a, :b, :c]) :c ``` """ -onecold(y::AbstractVector, labels = 1:length(y)) = labels[Base.argmax(y)] +onecold(y::AbstractVector, labels = 1:length(y)) = labels[argmax(y)] +function onecold(y::AbstractArray, labels = 1:size(y, 1)) + indices = dropdims(argmax(y; dims = 1); dims = 1) + xs = isbits(labels) ? indices : collect(indices) # non-bit type cannot be handled by CUDA -onecold(y::AbstractMatrix, labels...) = - dropdims(mapslices(y -> onecold(y, labels...), y, dims=1), dims=1) + return map(xi -> labels[xi[1]], xs) +end -onecold(y::OneHotMatrix, labels...) = map(x -> Flux.onecold(x, labels...), y.data) +@nograd OneHotArray, onecold, onehot, onehotbatch -@nograd onecold, onehot, onehotbatch +function Base.:(*)(A::AbstractMatrix, B::OneHotArray{<:Any, L}) where L + size(A, 2) == L || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $L")) + return A[:, onecold(B)] +end \ No newline at end of file diff --git a/test/cuda/cuda.jl b/test/cuda/cuda.jl index 5aaec710b6..17851bc2cb 100644 --- a/test/cuda/cuda.jl +++ b/test/cuda/cuda.jl @@ -13,7 +13,7 @@ using LinearAlgebra: I, cholesky, Cholesky x = Flux.onehotbatch([1, 2, 3], 1:3) cx = gpu(x) - @test cx isa Flux.OneHotMatrix && cx.data isa CuArray + @test cx isa Flux.OneHotMatrix && cx.indices isa CuArray @test (cx .+ 1) isa CuArray m = Chain(Dense(10, 5, tanh), Dense(5, 2), softmax) diff --git a/test/onehot.jl b/test/onehot.jl index 591e899920..d325733a2d 100644 --- a/test/onehot.jl +++ b/test/onehot.jl @@ -27,8 +27,8 @@ end @testset "abstractmatrix onehotvector multiplication" begin A = [1 3 5; 2 4 6; 3 6 9] - b1 = Flux.OneHotVector(1,3) - b2 = Flux.OneHotVector(3,5) + b1 = Flux.OneHotVector{eltype(A), 3}(1) + b2 = Flux.OneHotVector{eltype(A), 5}(3) @test A*b1 == A[:,1] @test_throws DimensionMismatch A*b2 From 47ab6224ff47e2ee5ca4d5e6b1b6edc3c6e1d898 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Fri, 1 Jan 2021 13:54:51 -0600 Subject: [PATCH 02/13] Added fast argmax for onecold --- src/onehot.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/onehot.jl b/src/onehot.jl index 9203c11018..d1c30eaba4 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -124,12 +124,15 @@ julia> Flux.onecold([0.3, 0.2, 0.5], [:a, :b, :c]) """ onecold(y::AbstractVector, labels = 1:length(y)) = labels[argmax(y)] function onecold(y::AbstractArray, labels = 1:size(y, 1)) - indices = dropdims(argmax(y; dims = 1); dims = 1) + indices = convert(Array, _fast_argmax(y)) xs = isbits(labels) ? indices : collect(indices) # non-bit type cannot be handled by CUDA return map(xi -> labels[xi[1]], xs) end +_fast_argmax(x::AbstractArray) = dropdims(argmax(x; dims = 1); dims = 1) +_fast_argmax(x::OneHotArray) = x.indices + @nograd OneHotArray, onecold, onehot, onehotbatch function Base.:(*)(A::AbstractMatrix, B::OneHotArray{<:Any, L}) where L From a98727fa22507e58ebf1428a08925a8ecbe8da9b Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Fri, 1 Jan 2021 14:29:59 -0600 Subject: [PATCH 03/13] Don't convert non-one-hot array argmax in onecold --- src/onehot.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/onehot.jl b/src/onehot.jl index d1c30eaba4..5e39fda20f 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -124,14 +124,14 @@ julia> Flux.onecold([0.3, 0.2, 0.5], [:a, :b, :c]) """ onecold(y::AbstractVector, labels = 1:length(y)) = labels[argmax(y)] function onecold(y::AbstractArray, labels = 1:size(y, 1)) - indices = convert(Array, _fast_argmax(y)) + indices = _fast_argmax(y) xs = isbits(labels) ? indices : collect(indices) # non-bit type cannot be handled by CUDA return map(xi -> labels[xi[1]], xs) end _fast_argmax(x::AbstractArray) = dropdims(argmax(x; dims = 1); dims = 1) -_fast_argmax(x::OneHotArray) = x.indices +_fast_argmax(x::OneHotArray) = convert(Array, x.indices) @nograd OneHotArray, onecold, onehot, onehotbatch From 4355ff17b415c05b9b851df3071871e18bbaf2b9 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Fri, 1 Jan 2021 15:01:13 -0600 Subject: [PATCH 04/13] Make _fast_argmax even faster --- src/onehot.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/onehot.jl b/src/onehot.jl index 5e39fda20f..32cb40b219 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -131,7 +131,7 @@ function onecold(y::AbstractArray, labels = 1:size(y, 1)) end _fast_argmax(x::AbstractArray) = dropdims(argmax(x; dims = 1); dims = 1) -_fast_argmax(x::OneHotArray) = convert(Array, x.indices) +_fast_argmax(x::OneHotArray) = convert(AbstractArray, x.indices) @nograd OneHotArray, onecold, onehot, onehotbatch From 630e5f0e7f8f76857d15e85ffd8d7719b4549e52 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Sat, 2 Jan 2021 09:20:11 -0600 Subject: [PATCH 05/13] Fix typo in one hot cat Co-authored-by: Carlo Lucibello --- src/onehot.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/onehot.jl b/src/onehot.jl index 32cb40b219..58c345b383 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -33,7 +33,7 @@ function Base.cat(xs::OneHotArray{<:Any, L}...; dims::Int) where L if isone(dims) return cat(map(x -> convert(_onehot_bool_type(x), x), xs)...; dims = 1) else - return OneHotArray(L, cat(_indices.(xs)...; dims = dims = dims - 1)) + return OneHotArray(L, cat(_indices.(xs)...; dims = dims - 1)) end end @@ -138,4 +138,4 @@ _fast_argmax(x::OneHotArray) = convert(AbstractArray, x.indices) function Base.:(*)(A::AbstractMatrix, B::OneHotArray{<:Any, L}) where L size(A, 2) == L || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $L")) return A[:, onecold(B)] -end \ No newline at end of file +end From 36183f0b8ee1da8a3330a0d4aaa0bd5332f9e274 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Sat, 2 Jan 2021 09:54:14 -0600 Subject: [PATCH 06/13] Remove conversion to AbstractArray --- src/onehot.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/onehot.jl b/src/onehot.jl index 58c345b383..6dd9210e3f 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -131,7 +131,7 @@ function onecold(y::AbstractArray, labels = 1:size(y, 1)) end _fast_argmax(x::AbstractArray) = dropdims(argmax(x; dims = 1); dims = 1) -_fast_argmax(x::OneHotArray) = convert(AbstractArray, x.indices) +_fast_argmax(x::OneHotArray) = x.indices @nograd OneHotArray, onecold, onehot, onehotbatch From 49befe018e4a23023eb291c33b8b9a6635fffe79 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Sat, 2 Jan 2021 10:14:49 -0600 Subject: [PATCH 07/13] Break infinite recursion of reshape(::OneHotArray) --- src/onehot.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/onehot.jl b/src/onehot.jl index 6dd9210e3f..01fda03313 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -41,7 +41,8 @@ Base.hcat(xs::OneHotArray...) = cat(xs...; dims = 2) Base.vcat(xs::OneHotArray...) = cat(xs...; dims = 1) Base.reshape(x::OneHotArray{<:Any, L}, dims...) where L = - (first(dims) == L) ? OneHotArray(L, reshape(x.indices, dims[2:end]...)) : reshape(x, dims...) + (first(dims) == L) ? OneHotArray(L, reshape(x.indices, dims[2:end]...)) : + reshape(convert(_onehot_bool_type(x), x), dims...) batch(xs::AbstractArray{<:OneHotVector{<:Any, L}}) where L = OneHotArray(L, _indices.(xs)) From 2bad6b92d9c96f828ac8066667c87841ce16e800 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Sat, 2 Jan 2021 10:38:46 -0600 Subject: [PATCH 08/13] Throw error on reshape(::OneHotArray) if first dim doesn't match --- src/onehot.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/onehot.jl b/src/onehot.jl index 01fda03313..ca8d4cb565 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -40,9 +40,9 @@ end Base.hcat(xs::OneHotArray...) = cat(xs...; dims = 2) Base.vcat(xs::OneHotArray...) = cat(xs...; dims = 1) -Base.reshape(x::OneHotArray{<:Any, L}, dims...) where L = +Base.reshape(x::OneHotArray{<:Any, L}, dims::Dims) where L = (first(dims) == L) ? OneHotArray(L, reshape(x.indices, dims[2:end]...)) : - reshape(convert(_onehot_bool_type(x), x), dims...) + throw(ArgumentError("Cannot reshape OneHotArray if first(dims) != size(x, 1)")) batch(xs::AbstractArray{<:OneHotVector{<:Any, L}}) where L = OneHotArray(L, _indices.(xs)) From 3b89d8567d7aa2b039440bdc00bf12f7fec55788 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Tue, 5 Jan 2021 14:03:20 -0600 Subject: [PATCH 09/13] Added tests and addressed comments --- src/onehot.jl | 14 +++++---- test/cuda/cuda.jl | 2 ++ test/onehot.jl | 74 +++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 84 insertions(+), 6 deletions(-) diff --git a/src/onehot.jl b/src/onehot.jl index ca8d4cb565..5a51db994f 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -1,9 +1,7 @@ import Adapt import .CUDA -const OneHotIndex{T, N} = Union{T, AbstractArray{T, N}} - -struct OneHotArray{T<:Integer, L, N, var"N+1", I<:OneHotIndex{T, N}} <: AbstractArray{Bool, var"N+1"} +struct OneHotArray{T<:Integer, L, N, var"N+1", I<:Union{T, AbstractArray{T, N}}} <: AbstractArray{Bool, var"N+1"} indices::I end OneHotArray{T, L, N, I}(indices) where {T, L, N, I} = OneHotArray{T, L, N, N+1, I}(indices) @@ -15,18 +13,22 @@ _indices(x::OneHotArray) = x.indices const OneHotVector{T, L} = OneHotArray{T, L, 0, 1, T} const OneHotMatrix{T, L, I} = OneHotArray{T, L, 1, 2, I} +OneHotVector(L, idx) = OneHotArray(L, idx) +OneHotMatrix(L, indices) = OneHotArray(L, indices) + Base.size(x::OneHotArray{<:Any, L}) where L = (Int(L), size(x.indices)...) _onehotindex(x, i) = (x == i) Base.getindex(x::OneHotVector, i::Integer) = _onehotindex(x.indices, i) -Base.getindex(x::OneHotVector{T, L}, ::Colon) where {T, L} = OneHotVector{T, L}(x.indices) +Base.getindex(x::OneHotVector{T, L}, ::Colon) where {T, L} = x Base.getindex(x::OneHotArray, i::Integer, I...) = _onehotindex.(x.indices[I...], i) Base.getindex(x::OneHotArray{<:Any, L}, ::Colon, I...) where L = OneHotArray(L, x.indices[I...]) +Base.getindex(x::OneHotArray{<:Any, <:Any, <:Any, N}, ::Vararg{Colon, N}) where N = x Base.getindex(x::OneHotArray, I::CartesianIndex{N}) where N = x[I[1], Tuple(I)[2:N]...] -_onehot_bool_type(x::OneHotArray{<:Any, <:Any, <:Any, N, <:OneHotIndex}) where N = Array{Bool, N} +_onehot_bool_type(x::OneHotArray{<:Any, <:Any, <:Any, N, <:Union{Integer, AbstractArray}}) where N = Array{Bool, N} _onehot_bool_type(x::OneHotArray{<:Any, <:Any, <:Any, N, <:CuArray}) where N = CuArray{Bool, N} function Base.cat(xs::OneHotArray{<:Any, L}...; dims::Int) where L @@ -48,7 +50,7 @@ batch(xs::AbstractArray{<:OneHotVector{<:Any, L}}) where L = OneHotArray(L, _ind Adapt.adapt_structure(T, x::OneHotArray{<:Any, L}) where L = OneHotArray(L, adapt(T, x.indices)) -Base.BroadcastStyle(::Type{<:OneHotArray{<:Any, <:Any, <:Any, N, <:CuArray}}) where N = CuArrayStyle{N}() +Base.BroadcastStyle(::Type{<:OneHotArray{<:Any, <:Any, <:Any, N, <:CuArray}}) where N = CUDA.CuArrayStyle{N}() Base.argmax(x::OneHotArray; dims = Colon()) = (dims == 1) ? reshape(CartesianIndex.(x.indices, CartesianIndices(x.indices)), 1, size(x.indices)...) : diff --git a/test/cuda/cuda.jl b/test/cuda/cuda.jl index 17851bc2cb..5ede5beeb9 100644 --- a/test/cuda/cuda.jl +++ b/test/cuda/cuda.jl @@ -40,8 +40,10 @@ end @testset "onecold gpu" begin y = Flux.onehotbatch(ones(3), 1:10) |> gpu; + l = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j'] @test Flux.onecold(y) isa CuArray @test y[3,:] isa CuArray + @test Flux.onecold(y, l) isa CuArray end @testset "restructure gpu" begin diff --git a/test/onehot.jl b/test/onehot.jl index d325733a2d..2d238b7e42 100644 --- a/test/onehot.jl +++ b/test/onehot.jl @@ -32,4 +32,78 @@ end @test A*b1 == A[:,1] @test_throws DimensionMismatch A*b2 +end + +@testset "OneHotArray" begin + using Flux: OneHotArray, OneHotVector, OneHotMatrix + + ov = OneHotVector(10, rand(1:10)) + om = OneHotMatrix(10, rand(1:10, 5)) + oa = OneHotArray(10, rand(1:10, 5, 5)) + + # sizes + @testset "Base.size" begin + @test size(ov) == (10,) + @test size(om) == (10, 5) + @test size(oa) == (10, 5, 5) + end + + @testset "Indexing" begin + # vector indexing + @test ov[3] == (ov.indices == 3) + @test ov[:] == ov + + # matrix indexing + @test om[3, 3] == (om.indices[3] == 3) + @test om[:, 3] == OneHotVector(10, om.indices[3]) + @test om[3, :] == (om.indices .== 3) + @test om[:, :] == om + + # array indexing + @test oa[3, 3, 3] == (oa.indices[3, 3] == 3) + @test oa[:, 3, 3] == OneHotVector(10, oa.indices[3, 3]) + @test oa[3, :, 3] == (oa.indices[:, 3] .== 3) + @test oa[3, :, :] == (oa.indices .== 3) + @test oa[:, 3, :] == OneHotMatrix(10, oa.indices[3, :]) + @test oa[:, :, :] == oa + + # cartesian indexing + @test oa[CartesianIndex(3, 3, 3)] == oa[3, 3, 3] + end + + @testset "Concatenating" begin + # vector cat + @test hcat(ov, ov) == OneHotMatrix(10, vcat(ov.indices, ov.indices)) + @test vcat(ov, ov) == vcat(convert(Array{Bool}, ov), convert(Array{Bool}, ov)) + @test cat(ov, ov; dims = 3) == OneHotArray(10, cat(ov.indices, ov.indices; dims = 2)) + + # matrix cat + @test hcat(om, om) == OneHotMatrix(10, vcat(om.indices, om.indices)) + @test vcat(om, om) == vcat(convert(Array{Bool}, om), convert(Array{Bool}, om)) + @test cat(om, om; dims = 3) == OneHotArray(10, cat(om.indices, om.indices; dims = 2)) + + # array cat + @test cat(oa, oa; dims = 3) == OneHotArray(10, cat(oa.indices, oa.indices; dims = 2)) + @test cat(oa, oa; dims = 1) == cat(convert(Array{Bool}, oa), convert(Array{Bool}, oa); dims = 1) + end + + @testset "Base.reshape" begin + # reshape test + @test reshape(oa, 10, 25) isa OneHotArray + @test reshape(oa, 10, :) isa OneHotArray + @test reshape(oa, :, 25) isa OneHotArray + @test_throws ArgumentError reshape(oa, 50, :) + @test_throws ArgumentError reshape(oa, 5, 10, 5) + @test reshape(oa, (10, 25)) isa OneHotArray + end + + @testset "Base.argmax" begin + # argmax test + @test argmax(ov) == argmax(convert(Array{Bool}, ov)) + @test argmax(om) == argmax(convert(Array{Bool}, om)) + @test argmax(om; dims = 1) == argmax(convert(Array{Bool}, om); dims = 1) + @test argmax(om; dims = 2) == argmax(convert(Array{Bool}, om); dims = 2) + @test argmax(oa; dims = 1) == argmax(convert(Array{Bool}, oa); dims = 1) + @test argmax(oa; dims = 3) == argmax(convert(Array{Bool}, oa); dims = 3) + end end \ No newline at end of file From 5a6058041af4ab782d77efb5c32707206ec07f52 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Thu, 7 Jan 2021 08:18:03 -0600 Subject: [PATCH 10/13] Fix for reshape(::OneHotArray, ::Tuple) --- src/onehot.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/onehot.jl b/src/onehot.jl index 5a51db994f..3a6ffa1eb7 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -45,6 +45,7 @@ Base.vcat(xs::OneHotArray...) = cat(xs...; dims = 1) Base.reshape(x::OneHotArray{<:Any, L}, dims::Dims) where L = (first(dims) == L) ? OneHotArray(L, reshape(x.indices, dims[2:end]...)) : throw(ArgumentError("Cannot reshape OneHotArray if first(dims) != size(x, 1)")) +Base._reshape(x::OneHotArray, dims::Tuple{Vararg{Int}}) = reshape(x, dims) batch(xs::AbstractArray{<:OneHotVector{<:Any, L}}) where L = OneHotArray(L, _indices.(xs)) From 28ece66454b2e8ce38069e5750a014a180d7c1d1 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Thu, 7 Jan 2021 16:54:57 -0600 Subject: [PATCH 11/13] Fix doctest errors --- docs/src/data/onehot.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/src/data/onehot.md b/docs/src/data/onehot.md index bb97e437cd..3aff242bee 100644 --- a/docs/src/data/onehot.md +++ b/docs/src/data/onehot.md @@ -6,13 +6,13 @@ It's common to encode categorical variables (like `true`, `false` or `cat`, `dog julia> using Flux: onehot, onecold julia> onehot(:b, [:a, :b, :c]) -3-element Flux.OneHotVector: +3-element Flux.OneHotArray{UInt32,3,0,1,UInt32}: 0 1 0 julia> onehot(:c, [:a, :b, :c]) -3-element Flux.OneHotVector: +3-element Flux.OneHotArray{UInt32,3,0,1,UInt32}: 0 0 1 @@ -44,7 +44,7 @@ Flux.onecold julia> using Flux: onehotbatch julia> onehotbatch([:b, :a, :b], [:a, :b, :c]) -3×3 Flux.OneHotMatrix{Array{Flux.OneHotVector,1}}: +3×3 Flux.OneHotArray{UInt32,3,1,2,Array{UInt32,1}}: 0 1 0 1 0 1 0 0 0 From ef1ecb373775df3844d442af8576f93f501df983 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Fri, 8 Jan 2021 07:41:43 -0600 Subject: [PATCH 12/13] Fix onehot cuda test --- test/cuda/cuda.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/cuda/cuda.jl b/test/cuda/cuda.jl index 5ede5beeb9..92c04404f8 100644 --- a/test/cuda/cuda.jl +++ b/test/cuda/cuda.jl @@ -43,7 +43,7 @@ end l = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j'] @test Flux.onecold(y) isa CuArray @test y[3,:] isa CuArray - @test Flux.onecold(y, l) isa CuArray + @test Flux.onecold(y, l) == ['a', 'a', 'a'] end @testset "restructure gpu" begin From aa63a5a45731712e2ace325113d02ef2224721c4 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Fri, 8 Jan 2021 10:42:43 -0600 Subject: [PATCH 13/13] Make constructors backwards compatible and throw error on vcat. --- src/onehot.jl | 20 ++++++++++---------- test/onehot.jl | 32 ++++++++++++++++---------------- 2 files changed, 26 insertions(+), 26 deletions(-) diff --git a/src/onehot.jl b/src/onehot.jl index 3a6ffa1eb7..c9f0b145a0 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -5,16 +5,16 @@ struct OneHotArray{T<:Integer, L, N, var"N+1", I<:Union{T, AbstractArray{T, N}}} indices::I end OneHotArray{T, L, N, I}(indices) where {T, L, N, I} = OneHotArray{T, L, N, N+1, I}(indices) -OneHotArray(L::Integer, indices::T) where {T<:Integer} = OneHotArray{T, L, 0, T}(indices) -OneHotArray(L::Integer, indices::AbstractArray{T, N}) where {T, N} = OneHotArray{T, L, N, typeof(indices)}(indices) +OneHotArray(indices::T, L::Integer) where {T<:Integer} = OneHotArray{T, L, 0, T}(indices) +OneHotArray(indices::AbstractArray{T, N}, L::Integer) where {T, N} = OneHotArray{T, L, N, typeof(indices)}(indices) _indices(x::OneHotArray) = x.indices const OneHotVector{T, L} = OneHotArray{T, L, 0, 1, T} const OneHotMatrix{T, L, I} = OneHotArray{T, L, 1, 2, I} -OneHotVector(L, idx) = OneHotArray(L, idx) -OneHotMatrix(L, indices) = OneHotArray(L, indices) +OneHotVector(idx, L) = OneHotArray(idx, L) +OneHotMatrix(indices, L) = OneHotArray(indices, L) Base.size(x::OneHotArray{<:Any, L}) where L = (Int(L), size(x.indices)...) @@ -24,7 +24,7 @@ Base.getindex(x::OneHotVector, i::Integer) = _onehotindex(x.indices, i) Base.getindex(x::OneHotVector{T, L}, ::Colon) where {T, L} = x Base.getindex(x::OneHotArray, i::Integer, I...) = _onehotindex.(x.indices[I...], i) -Base.getindex(x::OneHotArray{<:Any, L}, ::Colon, I...) where L = OneHotArray(L, x.indices[I...]) +Base.getindex(x::OneHotArray{<:Any, L}, ::Colon, I...) where L = OneHotArray(x.indices[I...], L) Base.getindex(x::OneHotArray{<:Any, <:Any, <:Any, N}, ::Vararg{Colon, N}) where N = x Base.getindex(x::OneHotArray, I::CartesianIndex{N}) where N = x[I[1], Tuple(I)[2:N]...] @@ -33,9 +33,9 @@ _onehot_bool_type(x::OneHotArray{<:Any, <:Any, <:Any, N, <:CuArray}) where N = C function Base.cat(xs::OneHotArray{<:Any, L}...; dims::Int) where L if isone(dims) - return cat(map(x -> convert(_onehot_bool_type(x), x), xs)...; dims = 1) + return throw(ArgumentError("Cannot concat OneHotArray along first dimension. Use collect to convert to Bool array first.")) else - return OneHotArray(L, cat(_indices.(xs)...; dims = dims - 1)) + return OneHotArray(cat(_indices.(xs)...; dims = dims - 1), L) end end @@ -43,13 +43,13 @@ Base.hcat(xs::OneHotArray...) = cat(xs...; dims = 2) Base.vcat(xs::OneHotArray...) = cat(xs...; dims = 1) Base.reshape(x::OneHotArray{<:Any, L}, dims::Dims) where L = - (first(dims) == L) ? OneHotArray(L, reshape(x.indices, dims[2:end]...)) : + (first(dims) == L) ? OneHotArray(reshape(x.indices, dims[2:end]...), L) : throw(ArgumentError("Cannot reshape OneHotArray if first(dims) != size(x, 1)")) Base._reshape(x::OneHotArray, dims::Tuple{Vararg{Int}}) = reshape(x, dims) -batch(xs::AbstractArray{<:OneHotVector{<:Any, L}}) where L = OneHotArray(L, _indices.(xs)) +batch(xs::AbstractArray{<:OneHotVector{<:Any, L}}) where L = OneHotArray(_indices.(xs), L) -Adapt.adapt_structure(T, x::OneHotArray{<:Any, L}) where L = OneHotArray(L, adapt(T, x.indices)) +Adapt.adapt_structure(T, x::OneHotArray{<:Any, L}) where L = OneHotArray(adapt(T, x.indices), L) Base.BroadcastStyle(::Type{<:OneHotArray{<:Any, <:Any, <:Any, N, <:CuArray}}) where N = CUDA.CuArrayStyle{N}() diff --git a/test/onehot.jl b/test/onehot.jl index 2d238b7e42..9461bc816d 100644 --- a/test/onehot.jl +++ b/test/onehot.jl @@ -27,8 +27,8 @@ end @testset "abstractmatrix onehotvector multiplication" begin A = [1 3 5; 2 4 6; 3 6 9] - b1 = Flux.OneHotVector{eltype(A), 3}(1) - b2 = Flux.OneHotVector{eltype(A), 5}(3) + b1 = Flux.OneHotVector(1, 3) + b2 = Flux.OneHotVector(3, 5) @test A*b1 == A[:,1] @test_throws DimensionMismatch A*b2 @@ -37,9 +37,9 @@ end @testset "OneHotArray" begin using Flux: OneHotArray, OneHotVector, OneHotMatrix - ov = OneHotVector(10, rand(1:10)) - om = OneHotMatrix(10, rand(1:10, 5)) - oa = OneHotArray(10, rand(1:10, 5, 5)) + ov = OneHotVector(rand(1:10), 10) + om = OneHotMatrix(rand(1:10, 5), 10) + oa = OneHotArray(rand(1:10, 5, 5), 10) # sizes @testset "Base.size" begin @@ -55,16 +55,16 @@ end # matrix indexing @test om[3, 3] == (om.indices[3] == 3) - @test om[:, 3] == OneHotVector(10, om.indices[3]) + @test om[:, 3] == OneHotVector(om.indices[3], 10) @test om[3, :] == (om.indices .== 3) @test om[:, :] == om # array indexing @test oa[3, 3, 3] == (oa.indices[3, 3] == 3) - @test oa[:, 3, 3] == OneHotVector(10, oa.indices[3, 3]) + @test oa[:, 3, 3] == OneHotVector(oa.indices[3, 3], 10) @test oa[3, :, 3] == (oa.indices[:, 3] .== 3) @test oa[3, :, :] == (oa.indices .== 3) - @test oa[:, 3, :] == OneHotMatrix(10, oa.indices[3, :]) + @test oa[:, 3, :] == OneHotMatrix(oa.indices[3, :], 10) @test oa[:, :, :] == oa # cartesian indexing @@ -73,18 +73,18 @@ end @testset "Concatenating" begin # vector cat - @test hcat(ov, ov) == OneHotMatrix(10, vcat(ov.indices, ov.indices)) - @test vcat(ov, ov) == vcat(convert(Array{Bool}, ov), convert(Array{Bool}, ov)) - @test cat(ov, ov; dims = 3) == OneHotArray(10, cat(ov.indices, ov.indices; dims = 2)) + @test hcat(ov, ov) == OneHotMatrix(vcat(ov.indices, ov.indices), 10) + @test_throws ArgumentError vcat(ov, ov) + @test cat(ov, ov; dims = 3) == OneHotArray(cat(ov.indices, ov.indices; dims = 2), 10) # matrix cat - @test hcat(om, om) == OneHotMatrix(10, vcat(om.indices, om.indices)) - @test vcat(om, om) == vcat(convert(Array{Bool}, om), convert(Array{Bool}, om)) - @test cat(om, om; dims = 3) == OneHotArray(10, cat(om.indices, om.indices; dims = 2)) + @test hcat(om, om) == OneHotMatrix(vcat(om.indices, om.indices), 10) + @test_throws ArgumentError vcat(om, om) + @test cat(om, om; dims = 3) == OneHotArray(cat(om.indices, om.indices; dims = 2), 10) # array cat - @test cat(oa, oa; dims = 3) == OneHotArray(10, cat(oa.indices, oa.indices; dims = 2)) - @test cat(oa, oa; dims = 1) == cat(convert(Array{Bool}, oa), convert(Array{Bool}, oa); dims = 1) + @test cat(oa, oa; dims = 3) == OneHotArray(cat(oa.indices, oa.indices; dims = 2), 10) + @test_throws ArgumentError cat(oa, oa; dims = 1) end @testset "Base.reshape" begin