Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 34 additions & 20 deletions src/onehot.jl
Original file line number Diff line number Diff line change
@@ -1,32 +1,44 @@
import Base: *

"""
    OneHotVector(ix, of)

Boolean vector described by two integers of a common type `T`: `ix`, the index
that element lookup compares against, and `of`, the length reported by `size`.
"""
struct OneHotVector{T<:Integer} <: AbstractVector{Bool}
    ix::T
    of::T
end

# The default constructor requires both arguments to share one type. Promote
# mixed integer arguments (e.g. a `UInt32` index with an `Int` length) so they
# are accepted too; same-typed calls still hit the more specific default.
OneHotVector(ix::Integer, of::Integer) = OneHotVector(promote(ix, of)...)

# Array interface: a `OneHotVector` is one-dimensional with `of` elements. The
# stored length may be any integer type, so it is converted to `Int` here.
Base.size(xs::OneHotVector) = (Int(xs.of),)

# Element lookup: the entry is `true` exactly when `i` equals the stored index.
function Base.getindex(xs::OneHotVector, i::Integer)
    return xs.ix == i
end

# Multiplying a matrix by a one-hot vector returns the column of `A` at `b.ix`.
function Base.:*(A::AbstractMatrix, b::OneHotVector)
    return A[:, b.ix]
end

"""
    OneHotMatrix(height, data)

A matrix of one-hot column vectors.

`data` holds one integer per column and `height` is the number of rows.
"""
struct OneHotMatrix{A<:AbstractVector{<:Integer}} <: AbstractMatrix{Bool}
    height::Int
    data::A
end

"""
    OneHotMatrix(xs::AbstractVector{<:OneHotVector})

Build a `OneHotMatrix` with one column per vector in `xs`, storing only each
vector's index. The height is the length of the first vector; an error is
raised if any vector has a different length.
"""
function OneHotMatrix(xs::AbstractVector{<:OneHotVector})
    height = length(first(xs))
    indices = map(xs) do x
        length(x) == height || error("All one hot vectors must be the same length")
        x.ix
    end
    return OneHotMatrix(height, indices)
end

# Array interface: `height` rows and one column per stored index.
Base.size(xs::OneHotMatrix) = (xs.height, length(xs.data))

# Column `i` as a `OneHotVector`. The stored index and the `Int` height may have
# different integer types, so they are promoted to a common one first.
Base.getindex(xs::OneHotMatrix, ::Colon, i::Integer) =
    OneHotVector(promote(xs.data[i], xs.height)...)

# Scalar lookup: take column `j`, then read its entry `i`.
Base.getindex(xs::OneHotMatrix, i::Integer, j::Integer) = xs[:, j][i]

# Selecting several columns yields another `OneHotMatrix` of the same height.
Base.getindex(xs::OneHotMatrix, ::Colon, i::AbstractArray) =
    OneHotMatrix(xs.height, xs.data[i])

# Multiplying a matrix by a one-hot matrix gathers the columns of `A` at the
# indices stored in `B.data`.
Base.:*(A::AbstractMatrix, B::OneHotMatrix) = A[:, B.data]

# Concatenating one-hot vectors side by side produces a `OneHotMatrix` built
# from the collected vectors.
Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix([x, xs...])

# Batching an array of one-hot vectors hands it to the `OneHotMatrix`
# constructor.
batch(xs::AbstractArray{<:OneHotVector}) = OneHotMatrix(xs)

import Adapt: adapt, adapt_structure

Expand All @@ -39,20 +51,22 @@ adapt_structure(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data)
cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.data))
end

"""
    onehotidx(l, labels)

Return the index of the first element of `labels` that is `isequal` to `l`.
Raise an error if `l` does not occur in `labels`.
"""
function onehotidx(l, labels)
    i = findfirst(isequal(l), labels)
    # `repr` with a limited context keeps the message short for large values.
    i !== nothing || error("Value $(repr(l; context=:limited=>true)) is not in labels")
    return i
end

"""
    onehotidx(l, labels, unk)

Return the index of the first element of `labels` that is `isequal` to `l`.
If `l` does not occur in `labels`, return the index of `unk` instead, looked up
with the two-argument `onehotidx`.
"""
function onehotidx(l, labels, unk)
    i = findfirst(isequal(l), labels)
    i !== nothing || return onehotidx(unk, labels)
    return i
end

"""
    onehot(l, labels[, unk])

Build a `OneHotVector` of length `length(labels)` whose index is
`onehotidx(l, labels, unk...)`.
"""
onehot(l, labels, unk...) = OneHotVector(onehotidx(l, labels, unk...), length(labels))

"""
    onehotbatch(ls, labels[, unk])

Build a `OneHotMatrix` of height `length(labels)` with one column per element
of `ls`, each column's index being `onehotidx(l, labels, unk...)`.
"""
onehotbatch(ls, labels, unk...) =
    OneHotMatrix(length(labels), [onehotidx(l, labels, unk...) for l in ls])

# Return the label at the position of the largest element of `y`. Without
# explicit `labels`, the positions `1:length(y)` serve as labels.
function onecold(y::AbstractVector, labels = 1:length(y))
    winner = Base.argmax(y)
    return labels[winner]
end

Expand Down