
Usage of OneHotMatrix for input to neural network is very slow. #1355


Description

@racinmat

Multiplication of a OneHotMatrix by a dense layer's weight matrix could be optimized further, e.g. by:

using Flux
using LinearAlgebra

# A * B with a one-hot B: each one-hot column just selects column ohv.ix of A,
# so no dense matmul is needed.
function Base.:*(A::AbstractMatrix, B::Flux.OneHotMatrix)
	m = size(A,1)
	Y = similar(A, m, size(B,2))
	for (j,ohv) in enumerate(B.data)
		ix = ohv.ix
		for i in 1:m
			@inbounds Y[i,j] = A[i,ix]
		end
	end
	Y
end

# A * B' with a one-hot parent: column j of A is accumulated into column ohv.ix
# of the result (a scatter-add instead of a dense matmul).
function Base.:*(A::AbstractMatrix, B::Adjoint{Bool,<: Flux.OneHotMatrix})
	m = size(A,1)
	Y = similar(A, m, size(B,2))
	Y .= 0
	BT = B'
	for (j,ohv) in enumerate(BT.data)
		ix = ohv.ix
		for i in 1:m
			@inbounds Y[i,ix] += A[i,j]
		end
	end
	Y
end
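
As a rough sanity check of the speedup, something like the BenchmarkTools comparison below could be used; the sizes are arbitrary and only meant for illustration.

using Flux, BenchmarkTools

nlabels, ncols, m = 10_000, 512, 128                       # hypothetical vocabulary / batch / feature sizes

A = rand(Float32, m, nlabels)                               # dense weight matrix
B = Flux.onehotbatch(rand(1:nlabels, ncols), 1:nlabels)     # nlabels × ncols one-hot batch
C = rand(Float32, m, ncols)

@btime $A * $B      # exercises the OneHotMatrix method
@btime $C * $B'     # exercises the Adjoint{Bool,<:OneHotMatrix} method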

Should I make a PR for this?
