diff --git a/Project.toml b/Project.toml index 07a7098b01..d9df0e012e 100644 --- a/Project.toml +++ b/Project.toml @@ -13,6 +13,7 @@ MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d" +OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -34,6 +35,7 @@ MLUtils = "0.2" MacroTools = "0.5" NNlib = "0.8.9" NNlibCUDA = "0.2.4" +OneHotArrays = "0.1" Optimisers = "0.2.1" ProgressLogging = "0.1" Reexport = "0.2, 1.0" diff --git a/src/Flux.jl b/src/Flux.jl index 0cacbd419a..b7d27406b0 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -37,8 +37,8 @@ export Descent, Adam, Momentum, Nesterov, RMSProp, using CUDA const use_cuda = Ref{Union{Nothing,Bool}}(nothing) +using Adapt, Functors, OneHotArrays include("utils.jl") -include("onehot.jl") include("functor.jl") # Pirate error to catch a common mistake. diff --git a/src/onehot.jl b/src/onehot.jl deleted file mode 100644 index 42d263aa44..0000000000 --- a/src/onehot.jl +++ /dev/null @@ -1,290 +0,0 @@ -import Adapt -import .CUDA -using LinearAlgebra, NNlib - -""" - OneHotArray{T,L,N,M,I} <: AbstractArray{Bool,M} - -These are constructed by [`onehot`](@ref) and [`onehotbatch`](@ref). -Parameter `I` is the type of the underlying storage, and `T` its eltype. -""" -struct OneHotArray{T<:Integer, L, N, var"N+1", I<:Union{T, AbstractArray{T, N}}} <: AbstractArray{Bool, var"N+1"} - indices::I -end -OneHotArray{T, L, N, I}(indices) where {T, L, N, I} = OneHotArray{T, L, N, N+1, I}(indices) -OneHotArray(indices::T, L::Integer) where {T<:Integer} = OneHotArray{T, L, 0, 1, T}(indices) -OneHotArray(indices::I, L::Integer) where {T, N, I<:AbstractArray{T, N}} = OneHotArray{T, L, N, N+1, I}(indices) - -_indices(x::OneHotArray) = x.indices -_indices(x::Base.ReshapedArray{<: Any, <: Any, <: OneHotArray}) = - reshape(parent(x).indices, x.dims[2:end]) - -const OneHotVector{T, L} = OneHotArray{T, L, 0, 1, T} -const OneHotMatrix{T, L, I} = OneHotArray{T, L, 1, 2, I} - -@doc @doc(OneHotArray) -OneHotVector(idx, L) = OneHotArray(idx, L) -@doc @doc(OneHotArray) -OneHotMatrix(indices, L) = OneHotArray(indices, L) - -# use this type so reshaped arrays hit fast paths -# e.g. argmax -const OneHotLike{T, L, N, var"N+1", I} = - Union{OneHotArray{T, L, N, var"N+1", I}, - Base.ReshapedArray{Bool, var"N+1", <:OneHotArray{T, L, <:Any, <:Any, I}}} - -_isonehot(x::OneHotArray) = true -_isonehot(x::Base.ReshapedArray{<:Any, <:Any, <:OneHotArray{<:Any, L}}) where L = (size(x, 1) == L) - -Base.size(x::OneHotArray{<:Any, L}) where L = (Int(L), size(x.indices)...) - -_onehotindex(x, i) = (x == i) - -Base.getindex(x::OneHotVector, i::Integer) = _onehotindex(x.indices, i) -Base.getindex(x::OneHotVector{T, L}, ::Colon) where {T, L} = x - -Base.getindex(x::OneHotArray, i::Integer, I...) = _onehotindex.(x.indices[I...], i) -Base.getindex(x::OneHotArray{<:Any, L}, ::Colon, I...) where L = OneHotArray(x.indices[I...], L) -Base.getindex(x::OneHotArray{<:Any, <:Any, <:Any, N}, ::Vararg{Colon, N}) where N = x -Base.getindex(x::OneHotArray, I::CartesianIndex{N}) where N = x[I[1], Tuple(I)[2:N]...] - -function Base.showarg(io::IO, x::OneHotArray, toplevel) - print(io, ndims(x) == 1 ? "OneHotVector(" : ndims(x) == 2 ? "OneHotMatrix(" : "OneHotArray(") - Base.showarg(io, x.indices, false) - print(io, ')') - toplevel && print(io, " with eltype Bool") - return nothing -end - -# this is from /LinearAlgebra/src/diagonal.jl, official way to print the dots: -function Base.replace_in_print_matrix(x::OneHotLike, i::Integer, j::Integer, s::AbstractString) - x[i,j] ? s : _isonehot(x) ? Base.replace_with_centered_mark(s) : s -end - -# copy CuArray versions back before trying to print them: -Base.print_array(io::IO, X::OneHotLike{T, L, N, var"N+1", <:CuArray}) where {T, L, N, var"N+1"} = - Base.print_array(io, cpu(X)) -Base.print_array(io::IO, X::LinearAlgebra.AdjOrTrans{Bool, <:OneHotLike{T, L, N, var"N+1", <:CuArray}}) where {T, L, N, var"N+1"} = - Base.print_array(io, cpu(X)) - -_onehot_bool_type(x::OneHotLike{<:Any, <:Any, <:Any, N, <:Union{Integer, AbstractArray}}) where N = Array{Bool, N} -_onehot_bool_type(x::OneHotLike{<:Any, <:Any, <:Any, N, <:CuArray}) where N = CuArray{Bool, N} - -function Base.cat(x::OneHotLike{<:Any, L}, xs::OneHotLike{<:Any, L}...; dims::Int) where L - if isone(dims) || any(x -> !_isonehot(x), (x, xs...)) - return cat(map(x -> convert(_onehot_bool_type(x), x), (x, xs...))...; dims = dims) - else - return OneHotArray(cat(_indices(x), _indices.(xs)...; dims = dims - 1), L) - end -end - -Base.hcat(x::OneHotLike, xs::OneHotLike...) = cat(x, xs...; dims = 2) -Base.vcat(x::OneHotLike, xs::OneHotLike...) = cat(x, xs...; dims = 1) - -# optimized concatenation for matrices and vectors of same parameters -Base.hcat(x::T, xs::T...) where {L, T <: OneHotLike{<:Any, L, <:Any, 2}} = - OneHotMatrix(reduce(vcat, _indices.(xs); init = _indices(x)), L) -Base.hcat(x::T, xs::T...) where {L, T <: OneHotLike{<:Any, L, <:Any, 1}} = - OneHotMatrix(reduce(vcat, _indices.(xs); init = _indices(x)), L) - -MLUtils.batch(xs::AbstractArray{<:OneHotVector{<:Any, L}}) where L = OneHotMatrix(_indices.(xs), L) - -Adapt.adapt_structure(T, x::OneHotArray{<:Any, L}) where L = OneHotArray(adapt(T, _indices(x)), L) - -Base.BroadcastStyle(::Type{<:OneHotArray{<: Any, <: Any, <: Any, N, <: CuArray}}) where N = CUDA.CuArrayStyle{N}() - -Base.map(f, x::OneHotLike) = Base.broadcast(f, x) - -Base.argmax(x::OneHotLike; dims = Colon()) = - (_isonehot(x) && dims == 1) ? - reshape(CartesianIndex.(_indices(x), CartesianIndices(_indices(x))), 1, size(_indices(x))...) : - invoke(argmax, Tuple{AbstractArray}, x; dims = dims) - -""" - onehot(x, labels, [default]) - -Return a `OneHotVector` which is roughly a sparse representation of `x .== labels`. - -Instead of storing say `Vector{Bool}`, it stores the index of the first occurrence -of `x` in `labels`. If `x` is not found in labels, then it either returns `onehot(default, labels)`, -or gives an error if no default is given. - -See also [`onehotbatch`](@ref) to apply this to many `x`s, -and [`onecold`](@ref) to reverse either of these, as well as to generalise `argmax`. - -# Examples -```jldoctest -julia> β = Flux.onehot(:b, (:a, :b, :c)) -3-element OneHotVector(::UInt32) with eltype Bool: - ⋅ - 1 - ⋅ - -julia> αβγ = (Flux.onehot(0, 0:2), β, Flux.onehot(:z, [:a, :b, :c], :c)) # uses default -(Bool[1, 0, 0], Bool[0, 1, 0], Bool[0, 0, 1]) - -julia> hcat(αβγ...) # preserves sparsity -3×3 OneHotMatrix(::Vector{UInt32}) with eltype Bool: - 1 ⋅ ⋅ - ⋅ 1 ⋅ - ⋅ ⋅ 1 -``` -""" -function onehot(x, labels) - i = _findval(x, labels) - isnothing(i) && error("Value $x is not in labels") - OneHotVector{UInt32, length(labels)}(i) -end - -function onehot(x, labels, default) - i = _findval(x, labels) - isnothing(i) && return onehot(default, labels) - OneHotVector{UInt32, length(labels)}(i) -end - -_findval(val, labels) = findfirst(isequal(val), labels) -# Fast unrolled method for tuples: -function _findval(val, labels::Tuple, i::Integer=1) - ifelse(isequal(val, first(labels)), i, _findval(val, Base.tail(labels), i+1)) -end -_findval(val, labels::Tuple{}, i::Integer) = nothing - -""" - onehotbatch(xs, labels, [default]) - -Returns a `OneHotMatrix` where `k`th column of the matrix is [`onehot(xs[k], labels)`](@ref onehot). -This is a sparse matrix, which stores just a `Vector{UInt32}` containing the indices of the -nonzero elements. - -If one of the inputs in `xs` is not found in `labels`, that column is `onehot(default, labels)` -if `default` is given, else an error. - -If `xs` has more dimensions, `M = ndims(xs) > 1`, then the result is an -`AbstractArray{Bool, M+1}` which is one-hot along the first dimension, -i.e. `result[:, k...] == onehot(xs[k...], labels)`. - -Note that `xs` can be any iterable, such as a string. And that using a tuple -for `labels` will often speed up construction, certainly for less than 32 classes. - -# Examples -```jldoctest -julia> oh = Flux.onehotbatch("abracadabra", 'a':'e', 'e') -5×11 OneHotMatrix(::Vector{UInt32}) with eltype Bool: - 1 ⋅ ⋅ 1 ⋅ 1 ⋅ 1 ⋅ ⋅ 1 - ⋅ 1 ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ 1 ⋅ ⋅ - ⋅ ⋅ ⋅ ⋅ 1 ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ - ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ 1 ⋅ ⋅ ⋅ ⋅ - ⋅ ⋅ 1 ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ 1 ⋅ - -julia> reshape(1:15, 3, 5) * oh # this matrix multiplication is done efficiently -3×11 Matrix{Int64}: - 1 4 13 1 7 1 10 1 4 13 1 - 2 5 14 2 8 2 11 2 5 14 2 - 3 6 15 3 9 3 12 3 6 15 3 -``` -""" -onehotbatch(data, labels, default...) = _onehotbatch(data, length(labels) < 32 ? Tuple(labels) : labels, default...) - -_onehotbatch(data::CuArray, labels) = _onehotbatch(data |> cpu, labels) |> gpu - -function _onehotbatch(data, labels) - indices = UInt32[something(_findval(i, labels), 0) for i in data] - if 0 in indices - for x in data - isnothing(_findval(x, labels)) && error("Value $x not found in labels") - end - end - return OneHotArray(indices, length(labels)) -end - -function _onehotbatch(data, labels, default) - default_index = _findval(default, labels) - isnothing(default_index) && error("Default value $default is not in labels") - indices = UInt32[something(_findval(i, labels), default_index) for i in data] - return OneHotArray(indices, length(labels)) -end - -""" - onecold(y::AbstractArray, labels = 1:size(y,1)) - -Roughly the inverse operation of [`onehot`](@ref) or [`onehotbatch`](@ref): -This finds the index of the largest element of `y`, or each column of `y`, -and looks them up in `labels`. - -If `labels` are not specified, the default is integers `1:size(y,1)` -- -the same operation as `argmax(y, dims=1)` but sometimes a different return type. - -# Examples -```jldoctest -julia> Flux.onecold([false, true, false]) -2 - -julia> Flux.onecold([0.3, 0.2, 0.5], (:a, :b, :c)) -:c - -julia> Flux.onecold([ 1 0 0 1 0 1 0 1 0 0 1 - 0 1 0 0 0 0 0 0 1 0 0 - 0 0 0 0 1 0 0 0 0 0 0 - 0 0 0 0 0 0 1 0 0 0 0 - 0 0 1 0 0 0 0 0 0 1 0 ], 'a':'e') |> String -"abeacadabea" -``` -""" -onecold(y::AbstractVector, labels = 1:length(y)) = labels[argmax(y)] -function onecold(y::AbstractArray, labels = 1:size(y, 1)) - indices = _fast_argmax(y) - xs = isbits(labels) ? indices : collect(indices) # non-bit type cannot be handled by CUDA - - return map(xi -> labels[xi[1]], xs) -end - -_fast_argmax(x::AbstractArray) = dropdims(argmax(x; dims = 1); dims = 1) -function _fast_argmax(x::OneHotLike) - if _isonehot(x) - return _indices(x) - else - return _fast_argmax(convert(_onehot_bool_type(x), x)) - end -end - -ChainRulesCore.@non_differentiable onehot(::Any...) -ChainRulesCore.@non_differentiable onehotbatch(::Any...) -ChainRulesCore.@non_differentiable onecold(::Any...) - -ChainRulesCore.@non_differentiable (::Type{<:OneHotArray})(indices::Any, L::Integer) - -function Base.:(*)(A::AbstractMatrix, B::OneHotLike{<:Any, L}) where L - _isonehot(B) || return invoke(*, Tuple{AbstractMatrix, AbstractMatrix}, A, B) - size(A, 2) == L || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $L")) - return A[:, onecold(B)] -end - -function Base.:(*)(A::AbstractMatrix, B::OneHotLike{<:Any, L, 1}) where L - _isonehot(B) || return invoke(*, Tuple{AbstractMatrix, AbstractMatrix}, A, B) - size(A, 2) == L || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $L")) - return NNlib.gather(A, _indices(B)) -end - -function Base.:(*)(A::AbstractMatrix, B::Adjoint{Bool, <:OneHotMatrix}) - B_dim = length(_indices(parent(B))) - size(A, 2) == B_dim || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $B_dim")) - return NNlib.scatter(+, A, _indices(parent(B)), dstsize=(size(A,1), size(B,2))) -end - -for wrapper in [:Adjoint, :Transpose] - @eval begin - function Base.:*(A::$wrapper{<:Any, <:AbstractMatrix{T}}, b::OneHotVector{<:Any, L}) where {L, T} - size(A, 2) == L || - throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $L")) - - return A[:, onecold(b)] - end - - function Base.:*(A::$wrapper{<:Number, <:AbstractVector{T}}, b::OneHotVector{<:Any, L}) where {L, T} - size(A, 2) == L || - throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $L")) - - return A[onecold(b)] - end - end -end diff --git a/test/cuda/cuda.jl b/test/cuda/cuda.jl index 5caeb18d11..2b4fec6e4c 100644 --- a/test/cuda/cuda.jl +++ b/test/cuda/cuda.jl @@ -49,9 +49,11 @@ end # construct from CuArray x = [1, 3, 2] y = Flux.onehotbatch(x, 0:3) + @test_skip begin # https://github.com/FluxML/OneHotArrays.jl/issues/16 y2 = Flux.onehotbatch(x |> gpu, 0:3) @test y2.indices isa CuArray @test y2 |> cpu == y + end end @testset "onecold gpu" begin diff --git a/test/onehot.jl b/test/onehot.jl deleted file mode 100644 index 91f64b763d..0000000000 --- a/test/onehot.jl +++ /dev/null @@ -1,203 +0,0 @@ -using Flux: onehot, onehotbatch, onecold -using Test - -@testset "onehot constructors" begin - @test onehot(20, 10:10:30) == [false, true, false] - @test onehot(20, (10,20,30)) == [false, true, false] - @test onehot(40, (10,20,30), 20) == [false, true, false] - - @test_throws Exception onehot('d', 'a':'c') - @test_throws Exception onehot(:d, (:a, :b, :c)) - @test_throws Exception onehot('d', 'a':'c', 'e') - @test_throws Exception onehot(:d, (:a, :b, :c), :e) - - @test onehotbatch([20, 10], 10:10:30) == Bool[0 1; 1 0; 0 0] - @test onehotbatch([20, 10], (10,20,30)) == Bool[0 1; 1 0; 0 0] - @test onehotbatch([40, 10], (10,20,30), 20) == Bool[0 1; 1 0; 0 0] - - @test onehotbatch("abc", 'a':'c') == Bool[1 0 0; 0 1 0; 0 0 1] - @test onehotbatch("zbc", ('a', 'b', 'c'), 'a') == Bool[1 0 0; 0 1 0; 0 0 1] - - @test onehotbatch([10, 20], [30, 40, 50], 30) == Bool[1 1; 0 0; 0 0] - - @test_throws Exception onehotbatch([:a, :d], [:a, :b, :c]) - @test_throws Exception onehotbatch([:a, :d], (:a, :b, :c)) - @test_throws Exception onehotbatch([:a, :d], [:a, :b, :c], :e) - @test_throws Exception onehotbatch([:a, :d], (:a, :b, :c), :e) - - floats = (0.0, -0.0, NaN, -NaN, Inf, -Inf) - @test onecold(onehot(0.0, floats)) == 1 - @test onecold(onehot(-0.0, floats)) == 2 # as it uses isequal - @test onecold(onehot(Inf, floats)) == 5 -end - -@testset "onecold" begin - a = [1, 2, 5, 3.] - A = [1 20 5; 2 7 6; 3 9 10; 2 1 14] - labels = ['A', 'B', 'C', 'D'] - - @test onecold(a) == 3 - @test onecold(A) == [3, 1, 4] - @test onecold(a, labels) == 'C' - @test onecold(a, Tuple(labels)) == 'C' - @test onecold(A, labels) == ['C', 'A', 'D'] - @test onecold(A, Tuple(labels)) == ['C', 'A', 'D'] - - data = [:b, :a, :c] - labels = [:a, :b, :c] - hot = Flux.onehotbatch(data, labels) - cold = onecold(hot, labels) - - @test cold == data -end - -@testset "onehotbatch indexing" begin - y = Flux.onehotbatch(ones(3), 1:10) - @test y[:,1] isa Flux.OneHotVector - @test y[:,:] isa Flux.OneHotMatrix -end - -@testset "abstractmatrix onehotvector multiplication" begin - A = [1 3 5; 2 4 6; 3 6 9] - v = [1, 2, 3, 4, 5] - X = reshape(v, (5, 1)) - b1 = Flux.OneHotVector(1, 3) - b2 = Flux.OneHotVector(3, 5) - - @test A * b1 == A[:,1] - @test b1' * A == Array(b1') * A - @test A' * b1 == A' * Array(b1) - @test v' * b2 == v' * Array(b2) - @test transpose(X) * b2 == transpose(X) * Array(b2) - @test transpose(v) * b2 == transpose(v) * Array(b2) - @test_throws DimensionMismatch A*b2 -end - -@testset "AbstractMatrix-OneHotMatrix multiplication" begin - A = [1 3 5; 2 4 6; 3 6 9] - v = [1, 2, 3, 4, 5] - X = reshape(v, (5, 1)) - b1 = Flux.OneHotMatrix([1, 1, 2, 2], 3) - b2 = Flux.OneHotMatrix([2, 4, 1, 3], 5) - b3 = Flux.OneHotMatrix([1, 1, 2], 4) - b4 = reshape(Flux.OneHotMatrix([1 2 3; 2 2 1], 3), 3, :) - b5 = reshape(b4, 6, :) - b6 = reshape(Flux.OneHotMatrix([1 2 2; 2 2 1], 2), 3, :) - b7 = reshape(Flux.OneHotMatrix([1 2 3; 1 2 3], 3), 6, :) - - @test A * b1 == A[:,[1, 1, 2, 2]] - @test b1' * A == Array(b1') * A - @test A' * b1 == A' * Array(b1) - @test A * b3' == A * Array(b3') - @test transpose(X) * b2 == transpose(X) * Array(b2) - @test A * b4 == A[:,[1, 2, 2, 2, 3, 1]] - @test A * b5' == hcat(A[:,[1, 2, 3, 3]], A[:,1]+A[:,2], zeros(Int64, 3)) - @test A * b6 == hcat(A[:,1], 2*A[:,2], A[:,2], A[:,1]+A[:,2]) - @test A * b7' == A[:,[1, 2, 3, 1, 2, 3]] - - @test_throws DimensionMismatch A*b1' - @test_throws DimensionMismatch A*b2 - @test_throws DimensionMismatch A*b2' - @test_throws DimensionMismatch A*b6' - @test_throws DimensionMismatch A*b7 -end - -@testset "OneHotArray" begin - using Flux: OneHotArray, OneHotVector, OneHotMatrix, OneHotLike - - ov = OneHotVector(rand(1:10), 10) - ov2 = OneHotVector(rand(1:11), 11) - om = OneHotMatrix(rand(1:10, 5), 10) - om2 = OneHotMatrix(rand(1:11, 5), 11) - oa = OneHotArray(rand(1:10, 5, 5), 10) - - # sizes - @testset "Base.size" begin - @test size(ov) == (10,) - @test size(om) == (10, 5) - @test size(oa) == (10, 5, 5) - end - - @testset "Indexing" begin - # vector indexing - @test ov[3] == (ov.indices == 3) - @test ov[:] == ov - - # matrix indexing - @test om[3, 3] == (om.indices[3] == 3) - @test om[:, 3] == OneHotVector(om.indices[3], 10) - @test om[3, :] == (om.indices .== 3) - @test om[:, :] == om - - # array indexing - @test oa[3, 3, 3] == (oa.indices[3, 3] == 3) - @test oa[:, 3, 3] == OneHotVector(oa.indices[3, 3], 10) - @test oa[3, :, 3] == (oa.indices[:, 3] .== 3) - @test oa[3, :, :] == (oa.indices .== 3) - @test oa[:, 3, :] == OneHotMatrix(oa.indices[3, :], 10) - @test oa[:, :, :] == oa - - # cartesian indexing - @test oa[CartesianIndex(3, 3, 3)] == oa[3, 3, 3] - end - - @testset "Concatenating" begin - # vector cat - @test hcat(ov, ov) == OneHotMatrix(vcat(ov.indices, ov.indices), 10) - @test hcat(ov, ov) isa OneHotMatrix - @test vcat(ov, ov) == vcat(collect(ov), collect(ov)) - @test cat(ov, ov; dims = 3) == OneHotArray(cat(ov.indices, ov.indices; dims = 2), 10) - - # matrix cat - @test hcat(om, om) == OneHotMatrix(vcat(om.indices, om.indices), 10) - @test hcat(om, om) isa OneHotMatrix - @test vcat(om, om) == vcat(collect(om), collect(om)) - @test cat(om, om; dims = 3) == OneHotArray(cat(om.indices, om.indices; dims = 2), 10) - - # array cat - @test cat(oa, oa; dims = 3) == OneHotArray(cat(oa.indices, oa.indices; dims = 2), 10) - @test cat(oa, oa; dims = 3) isa OneHotArray - @test cat(oa, oa; dims = 1) == cat(collect(oa), collect(oa); dims = 1) - - # proper error handling of inconsistent sizes - @test_throws DimensionMismatch hcat(ov, ov2) - @test_throws DimensionMismatch hcat(om, om2) - end - - @testset "Base.reshape" begin - # reshape test - @test reshape(oa, 10, 25) isa OneHotLike - @test reshape(oa, 10, :) isa OneHotLike - @test reshape(oa, :, 25) isa OneHotLike - @test reshape(oa, 50, :) isa OneHotLike - @test reshape(oa, 5, 10, 5) isa OneHotLike - @test reshape(oa, (10, 25)) isa OneHotLike - - @testset "w/ cat" begin - r = reshape(oa, 10, :) - @test hcat(r, r) isa OneHotArray - @test vcat(r, r) isa Array{Bool} - end - - @testset "w/ argmax" begin - r = reshape(oa, 10, :) - @test argmax(r) == argmax(OneHotMatrix(reshape(oa.indices, :), 10)) - @test Flux._fast_argmax(r) == collect(reshape(oa.indices, :)) - end - end - - @testset "Base.argmax" begin - # argmax test - @test argmax(ov) == argmax(convert(Array{Bool}, ov)) - @test argmax(om) == argmax(convert(Array{Bool}, om)) - @test argmax(om; dims = 1) == argmax(convert(Array{Bool}, om); dims = 1) - @test argmax(om; dims = 2) == argmax(convert(Array{Bool}, om); dims = 2) - @test argmax(oa; dims = 1) == argmax(convert(Array{Bool}, oa); dims = 1) - @test argmax(oa; dims = 3) == argmax(convert(Array{Bool}, oa); dims = 3) - end - - @testset "Forward map to broadcast" begin - @test map(identity, oa) == oa - @test map(x -> 2 * x, oa) == 2 .* oa - end -end diff --git a/test/runtests.jl b/test/runtests.jl index 706f126451..9027b114fc 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -16,10 +16,6 @@ Random.seed!(0) include("utils.jl") end - @testset "Onehot" begin - include("onehot.jl") - end - @testset "Optimise" begin include("optimise.jl") end