Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "OneHotArrays"
uuid = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
version = "0.2.2"
version = "0.2.3"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
14 changes: 12 additions & 2 deletions src/onehot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,14 +101,24 @@ function _onehotbatch(data, labels, default)
end

function onehotbatch(data::AbstractArray{<:Integer}, labels::AbstractUnitRange{<:Integer})
# lo, hi = extrema(data) # fails on Julia 1.6
lo, hi = minimum(data), maximum(data)
lo, hi = extrema(data) # fails on Julia 1.6
lo < first(labels) && error("Value $lo not found in labels")
hi > last(labels) && error("Value $hi not found in labels")
offset = 1 - first(labels)
indices = UInt32.(data .+ offset)
return OneHotArray(indices, length(labels))
end
# That bounds check with extrema synchronises on GPU, much slower than rest of the function,
# hence add a special method, with a less helpful error message:
function onehotbatch(data::AbstractGPUArray{<:Integer}, labels::AbstractUnitRange{<:Integer})
offset = 1 - first(labels)
indices = map(data) do datum
i = UInt32(datum + offset)
checkbounds(labels, i)
i
end
return OneHotArray(indices, length(labels))
end

"""
onecold(y::AbstractArray, labels = 1:size(y,1))
Expand Down
10 changes: 7 additions & 3 deletions test/gpu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,14 @@ end
y1 = onehotbatch([1, 3, 0, 2], 0:9) |> cu
y2 = onehotbatch([1, 3, 0, 2] |> cu, 0:9)
@test y1.indices == y2.indices
@test_broken y1 == y2
@test_broken y1 == y2 # issue 28

@test_throws Exception onehotbatch([1, 3, 0, 2] |> cu, 1:10)
@test_throws Exception onehotbatch([1, 3, 0, 2] |> cu, -2:2)
if !CUDA.functional()
# Here CUDA gives an error which @test_throws does not notice,
# although with JLArrays @test_throws it's fine.
@test_throws Exception onehotbatch([1, 3, 0, 2] |> cu, 1:10)
@test_throws Exception onehotbatch([1, 3, 0, 2] |> cu, -2:2)
end
end

@testset "onecold gpu" begin
Expand Down