Merged
Changes from 5 commits
105 changes: 58 additions & 47 deletions src/onehot.jl
@@ -1,52 +1,57 @@
import Base: *
import Adapt
import .CUDA

struct OneHotVector <: AbstractVector{Bool}
ix::UInt32
of::UInt32
const OneHotIndex{T, N} = Union{T, AbstractArray{T, N}}

struct OneHotArray{T<:Integer, L, N, var"N+1", I<:OneHotIndex{T, N}} <: AbstractArray{Bool, var"N+1"}
indices::I
end
OneHotArray{T, L, N, I}(indices) where {T, L, N, I} = OneHotArray{T, L, N, N+1, I}(indices)
OneHotArray(L::Integer, indices::T) where {T<:Integer} = OneHotArray{T, L, 0, T}(indices)
OneHotArray(L::Integer, indices::AbstractArray{T, N}) where {T, N} = OneHotArray{T, L, N, typeof(indices)}(indices)

Base.size(xs::OneHotVector) = (Int64(xs.of),)
_indices(x::OneHotArray) = x.indices

Base.getindex(xs::OneHotVector, i::Integer) = i == xs.ix
const OneHotVector{T, L} = OneHotArray{T, L, 0, 1, T}
const OneHotMatrix{T, L, I} = OneHotArray{T, L, 1, 2, I}

Base.getindex(xs::OneHotVector, ::Colon) = OneHotVector(xs.ix, xs.of)
Base.size(x::OneHotArray{<:Any, L}) where L = (Int(L), size(x.indices)...)

function Base.:*(A::AbstractMatrix, b::OneHotVector)
if size(A, 2) != b.of
throw(DimensionMismatch("Matrix column must correspond with OneHotVector size"))
end
return A[:, b.ix]
end
_onehotindex(x, i) = (x == i)

struct OneHotMatrix{A<:AbstractVector{OneHotVector}} <: AbstractMatrix{Bool}
height::Int
data::A
end
Base.getindex(x::OneHotVector, i::Integer) = _onehotindex(x.indices, i)
Base.getindex(x::OneHotVector{T, L}, ::Colon) where {T, L} = OneHotVector{T, L}(x.indices)

Base.size(xs::OneHotMatrix) = (Int64(xs.height),length(xs.data))
Base.getindex(x::OneHotArray, i::Integer, I...) = _onehotindex.(x.indices[I...], i)
Base.getindex(x::OneHotArray{<:Any, L}, ::Colon, I...) where L = OneHotArray(L, x.indices[I...])
Base.getindex(x::OneHotArray, I::CartesianIndex{N}) where N = x[I[1], Tuple(I)[2:N]...]

Base.getindex(xs::OneHotMatrix, i::Union{Integer, AbstractVector}, j::Integer) = xs.data[j][i]
Base.getindex(xs::OneHotMatrix, ::Colon, i::Integer) = xs.data[i]
Base.getindex(xs::OneHotMatrix, ::Colon, i::AbstractArray) = OneHotMatrix(xs.height, xs.data[i])
Base.getindex(xs::OneHotMatrix, ::Colon, ::Colon) = OneHotMatrix(xs.height, copy(xs.data))
_onehot_bool_type(x::OneHotArray{<:Any, <:Any, <:Any, N, <:OneHotIndex}) where N = Array{Bool, N}
_onehot_bool_type(x::OneHotArray{<:Any, <:Any, <:Any, N, <:CuArray}) where N = CuArray{Bool, N}
Member:
We might want to simply return the type of the underlying array, IIUC.

Member (Author):
The underlying array is an integer array. This is just an internal convenience function I use when I want to convert the OneHotArray to a Bool array. I use it to decide whether to convert to an Array{Bool} or a CuArray{Bool} depending on the underlying storage location.

Member:
But why do we need to have this? The implementation for CuArray should fall straight out of assuming a regular array.

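To make the discussion above concrete, here is a minimal sketch (not part of the diff) of how `_onehot_bool_type` picks the destination container when a `OneHotArray` has to be materialized as `Bool`. It assumes Flux with this patch applied and, for the GPU half, a functional CUDA device; `onehotbatch` and `gpu` are the existing Flux helpers, and `_onehot_bool_type` is the internal (unexported) helper defined in this diff.

```julia
using Flux, CUDA

x_cpu = Flux.onehotbatch([1, 2, 3], 1:3)   # indices stored in a Vector{UInt32}
Flux._onehot_bool_type(x_cpu)              # Array{Bool, 2}

x_gpu = gpu(x_cpu)                         # indices stored in a CuArray{UInt32}
Flux._onehot_bool_type(x_gpu)              # CuArray{Bool, 2}

# `vcat` cannot stay lazy along the one-hot dimension, so `cat(...; dims = 1)`
# materializes each argument through the type chosen above: CPU data stays on
# the CPU, GPU data stays on the GPU.
vcat(x_cpu, x_cpu)                         # 6×3 Array{Bool, 2}
```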

Base.getindex(xs::OneHotMatrix, i::Integer, ::Colon) = map(x -> x[i], xs.data)
function Base.cat(xs::OneHotArray{<:Any, L}...; dims::Int) where L
if isone(dims)
return cat(map(x -> convert(_onehot_bool_type(x), x), xs)...; dims = 1)
else
return OneHotArray(L, cat(_indices.(xs)...; dims = dims - 1))
end
end

# remove workaround when https://github.com/JuliaGPU/CuArrays.jl/issues/676 is fixed
A::AbstractMatrix * B::OneHotMatrix = A[:, cpu(map(x->x.ix, B.data))]
Base.hcat(xs::OneHotArray...) = cat(xs...; dims = 2)
Base.vcat(xs::OneHotArray...) = cat(xs...; dims = 1)

Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix(length(x), [x, xs...])
Base.reshape(x::OneHotArray{<:Any, L}, dims...) where L =
(first(dims) == L) ? OneHotArray(L, reshape(x.indices, dims[2:end]...)) : reshape(x, dims...)

batch(xs::AbstractArray{<:OneHotVector}) = OneHotMatrix(length(first(xs)), xs)
batch(xs::AbstractArray{<:OneHotVector{<:Any, L}}) where L = OneHotArray(L, _indices.(xs))

import Adapt: adapt, adapt_structure
Adapt.adapt_structure(T, x::OneHotArray{<:Any, L}) where L = OneHotArray(L, adapt(T, x.indices))

adapt_structure(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data))
Base.BroadcastStyle(::Type{<:OneHotArray{<:Any, <:Any, <:Any, N, <:CuArray}}) where N = CuArrayStyle{N}()

import .CUDA: CuArray, CuArrayStyle, cudaconvert
import Base.Broadcast: BroadcastStyle, ArrayStyle
BroadcastStyle(::Type{<:OneHotMatrix{<:CuArray}}) = CuArrayStyle{2}()
cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.data))
Base.argmax(x::OneHotArray; dims = Colon()) =
(dims == 1) ? reshape(CartesianIndex.(x.indices, CartesianIndices(x.indices)), 1, size(x.indices)...) :
argmax(convert(_onehot_bool_type(x), x); dims = dims)

"""
onehot(l, labels[, unk])
@@ -60,13 +65,13 @@ If `l` is not found in labels and `unk` is present, the function returns
# Examples
```jldoctest
julia> Flux.onehot(:b, [:a, :b, :c])
3-element Flux.OneHotVector:
3-element Flux.OneHotArray{UInt32,3,0,1,UInt32}:
0
1
0

julia> Flux.onehot(:c, [:a, :b, :c])
3-element Flux.OneHotVector:
3-element Flux.OneHotArray{UInt32,3,0,1,UInt32}:
0
0
1
@@ -75,13 +80,13 @@ julia> Flux.onehot(:c, [:a, :b, :c])
function onehot(l, labels)
i = something(findfirst(isequal(l), labels), 0)
i > 0 || error("Value $l is not in labels")
OneHotVector(i, length(labels))
OneHotVector{UInt32, length(labels)}(i)
end

function onehot(l, labels, unk)
i = something(findfirst(isequal(l), labels), 0)
i > 0 || return onehot(unk, labels)
OneHotVector(i, length(labels))
OneHotVector{UInt32, length(labels)}(i)
end

"""
@@ -95,16 +100,13 @@ return [`onehot(unk, labels)`](@ref); otherwise the function will raise an error
# Examples
```jldoctest
julia> Flux.onehotbatch([:b, :a, :b], [:a, :b, :c])
3×3 Flux.OneHotMatrix{Array{Flux.OneHotVector,1}}:
3×3 Flux.OneHotArray{UInt32,3,1,2,Array{UInt32,1}}:
0 1 0
1 0 1
0 0 0
```
"""
onehotbatch(ls, labels, unk...) =
OneHotMatrix(length(labels), [onehot(l, labels, unk...) for l in ls])

Base.argmax(xs::OneHotVector) = xs.ix
onehotbatch(ls, labels, unk...) = batch([onehot(l, labels, unk...) for l in ls])

"""
onecold(y[, labels = 1:length(y)])
@@ -120,11 +122,20 @@ julia> Flux.onecold([0.3, 0.2, 0.5], [:a, :b, :c])
:c
```
"""
onecold(y::AbstractVector, labels = 1:length(y)) = labels[Base.argmax(y)]
onecold(y::AbstractVector, labels = 1:length(y)) = labels[argmax(y)]
function onecold(y::AbstractArray, labels = 1:size(y, 1))
indices = _fast_argmax(y)
xs = isbits(labels) ? indices : collect(indices) # non-bit type cannot be handled by CUDA

onecold(y::AbstractMatrix, labels...) =
dropdims(mapslices(y -> onecold(y, labels...), y, dims=1), dims=1)
return map(xi -> labels[xi[1]], xs)
end

onecold(y::OneHotMatrix, labels...) = map(x -> Flux.onecold(x, labels...), y.data)
_fast_argmax(x::AbstractArray) = dropdims(argmax(x; dims = 1); dims = 1)
_fast_argmax(x::OneHotArray) = convert(AbstractArray, x.indices)

@nograd onecold, onehot, onehotbatch
@nograd OneHotArray, onecold, onehot, onehotbatch

function Base.:(*)(A::AbstractMatrix, B::OneHotArray{<:Any, L}) where L
size(A, 2) == L || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $L"))
return A[:, onecold(B)]
end
2 changes: 1 addition & 1 deletion test/cuda/cuda.jl
@@ -13,7 +13,7 @@ using LinearAlgebra: I, cholesky, Cholesky

x = Flux.onehotbatch([1, 2, 3], 1:3)
cx = gpu(x)
@test cx isa Flux.OneHotMatrix && cx.data isa CuArray
@test cx isa Flux.OneHotMatrix && cx.indices isa CuArray
@test (cx .+ 1) isa CuArray

m = Chain(Dense(10, 5, tanh), Dense(5, 2), softmax)
4 changes: 2 additions & 2 deletions test/onehot.jl
@@ -27,8 +27,8 @@ end

@testset "abstractmatrix onehotvector multiplication" begin
A = [1 3 5; 2 4 6; 3 6 9]
b1 = Flux.OneHotVector(1,3)
b2 = Flux.OneHotVector(3,5)
b1 = Flux.OneHotVector{eltype(A), 3}(1)
b2 = Flux.OneHotVector{eltype(A), 5}(3)

@test A*b1 == A[:,1]
@test_throws DimensionMismatch A*b2