diff --git a/Project.toml b/Project.toml
index 94da4fc..73ca990 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
 name = "OneHotArrays"
 uuid = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
-version = "0.2.2"
+version = "0.2.3"
 
 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
diff --git a/src/onehot.jl b/src/onehot.jl
index ca2efa5..d2d5e9d 100644
--- a/src/onehot.jl
+++ b/src/onehot.jl
@@ -101,14 +101,24 @@ function _onehotbatch(data, labels, default)
 end
 
 function onehotbatch(data::AbstractArray{<:Integer}, labels::AbstractUnitRange{<:Integer})
-  # lo, hi = extrema(data)  # fails on Julia 1.6
-  lo, hi = minimum(data), maximum(data)
+  lo, hi = extrema(data)
   lo < first(labels) && error("Value $lo not found in labels")
   hi > last(labels) && error("Value $hi not found in labels")
   offset = 1 - first(labels)
   indices = UInt32.(data .+ offset)
   return OneHotArray(indices, length(labels))
 end
+# That bounds check with extrema synchronises on GPU, much slower than rest of the function,
+# hence add a special method, with a less helpful error message:
+function onehotbatch(data::AbstractGPUArray{<:Integer}, labels::AbstractUnitRange{<:Integer})
+  offset = 1 - first(labels)
+  indices = map(data) do datum
+    i = UInt32(datum + offset)
+    checkbounds(labels, i)
+    i
+  end
+  return OneHotArray(indices, length(labels))
+end
 
 """
     onecold(y::AbstractArray, labels = 1:size(y,1))
diff --git a/test/gpu.jl b/test/gpu.jl
index cd04815..681353f 100644
--- a/test/gpu.jl
+++ b/test/gpu.jl
@@ -30,10 +30,14 @@ end
   y1 = onehotbatch([1, 3, 0, 2], 0:9) |> cu
   y2 = onehotbatch([1, 3, 0, 2] |> cu, 0:9)
   @test y1.indices == y2.indices
-  @test_broken y1 == y2
+  @test_broken y1 == y2  # issue 28
 
-  @test_throws Exception onehotbatch([1, 3, 0, 2] |> cu, 1:10)
-  @test_throws Exception onehotbatch([1, 3, 0, 2] |> cu, -2:2)
+  if !CUDA.functional()
+    # Here CUDA gives an error which @test_throws does not notice,
+    # although with JLArrays @test_throws works fine.
+    @test_throws Exception onehotbatch([1, 3, 0, 2] |> cu, 1:10)
+    @test_throws Exception onehotbatch([1, 3, 0, 2] |> cu, -2:2)
+  end
 end
 
 @testset "onecold gpu" begin