diff --git a/Project.toml b/Project.toml
index a92b7e4ff..b31689afd 100644
--- a/Project.toml
+++ b/Project.toml
@@ -27,6 +27,7 @@ SciMLLogging = "a6db7da4-7206-11f0-1eab-35f2a5dbe1d1"
 SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
 Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
 StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
+TriangularSolve = "d5829a12-d9aa-46ab-831f-fb7c9ab06edf"
 UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
 
 [weakdeps]
@@ -124,7 +125,7 @@ PrecompileTools = "1.2"
 Preferences = "1.4"
 Random = "1.10"
 RecursiveArrayTools = "3.37"
-RecursiveFactorization = "0.2.23"
+RecursiveFactorization = "0.2.26"
 Reexport = "1.2.2"
 SafeTestsets = "0.1"
 SciMLBase = "2.70"
@@ -137,6 +138,7 @@ StableRNGs = "1.0"
 StaticArrays = "1.9"
 StaticArraysCore = "1.4.3"
 Test = "1.10"
+TriangularSolve = "0.2.1"
 UnPack = "1.0.2"
 Zygote = "0.7"
 blis_jll = "0.9.0"
diff --git a/benchmarks/lu.jl b/benchmarks/lu.jl
index 896ee952e..d75e838f3 100644
--- a/benchmarks/lu.jl
+++ b/benchmarks/lu.jl
@@ -1,7 +1,8 @@
 using BenchmarkTools, Random, VectorizationBase
 using LinearAlgebra, LinearSolve, MKL_jll
+using RecursiveFactorization
+
 nc = min(Int(VectorizationBase.num_cores()), Threads.nthreads())
-BLAS.set_num_threads(nc)
 BenchmarkTools.DEFAULT_PARAMETERS.seconds = 0.5
 
 function luflop(m, n = m; innerflop = 2)
@@ -24,10 +25,10 @@ algs = [
     RFLUFactorization(),
     MKLLUFactorization(),
     FastLUFactorization(),
-    SimpleLUFactorization()
+    SimpleLUFactorization(),
+    ButterflyFactorization(Val(true))
 ]
 res = [Float64[] for i in 1:length(algs)]
-
 ns = 4:8:500
 for i in 1:length(ns)
     n = ns[i]
@@ -65,3 +66,4 @@ p
 
 savefig("lubench.png")
 savefig("lubench.pdf")
+
diff --git a/ext/LinearSolveRecursiveFactorizationExt.jl b/ext/LinearSolveRecursiveFactorizationExt.jl
index a5c62b4c0..2bd0b0f73 100644
--- a/ext/LinearSolveRecursiveFactorizationExt.jl
+++ b/ext/LinearSolveRecursiveFactorizationExt.jl
@@ -1,11 +1,12 @@
 module LinearSolveRecursiveFactorizationExt
 
 using LinearSolve: LinearSolve, userecursivefactorization, LinearCache, @get_cacheval,
-                   RFLUFactorization, RF32MixedLUFactorization, default_alias_A,
-                   default_alias_b, LinearVerbosity
+                   RFLUFactorization, ButterflyFactorization, RF32MixedLUFactorization,
+                   default_alias_A, default_alias_b, LinearVerbosity
 using LinearSolve.LinearAlgebra, LinearSolve.ArrayInterface, RecursiveFactorization
 using SciMLBase: SciMLBase, ReturnCode
 using SciMLLogging: @SciMLMessage
+using TriangularSolve
 
 LinearSolve.userecursivefactorization(A::Union{Nothing, AbstractMatrix}) = true
 
@@ -20,7 +21,6 @@ function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::RFLUFactorization
     end
     fact = RecursiveFactorization.lu!(A, ipiv, Val(P), Val(T), check = false)
     cache.cacheval = (fact, ipiv)
-
     if !LinearAlgebra.issuccess(fact)
         @SciMLMessage("Solver failed", cache.verbose, :solver_failure)
         return SciMLBase.build_linear_solution(
@@ -107,4 +107,49 @@ function SciMLBase.solve!(
         alg, cache.u, nothing, cache; retcode = ReturnCode.Success)
 end
 
+function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::ButterflyFactorization;
+        kwargs...)
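+    # Butterfly solve: on a fresh A, build (or reuse) a size-matched workspace, apply
+    # the random butterfly transforms (🦋mul!) in place, and factor the transformed
+    # matrix with an unpivoted LU. Each solve then computes x = V * (F \ (U' * b)),
+    # so no row pivoting is required.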
+    cache_A = cache.A
+    cache_A = convert(AbstractMatrix, cache_A)
+    cache_b = cache.b
+    M, N = size(cache_A)
+    workspace = cache.cacheval[1]
+    thread = alg.thread
+
+    if cache.isfresh
+        @assert M == N "A must be square"
+        # Rebuild the workspace if the cached one was sized for a different A.
+        if size(workspace.A, 1) != M
+            workspace = RecursiveFactorization.🦋workspace(cache_A, cache_b)
+        end
+        (; A, b, ws, U, V, out, tmp, n) = workspace
+        RecursiveFactorization.🦋mul!(A, ws)
+        F = RecursiveFactorization.lu!(A, Val(false), thread)
+        cache.cacheval = (workspace, F)
+        cache.isfresh = false
+    end
+
+    workspace, F = cache.cacheval
+    (; A, b, ws, U, V, out, tmp, n) = workspace
+    b[1:M] .= cache_b
+    mul!(tmp, U', b)
+    TriangularSolve.ldiv!(F, tmp, thread)
+    mul!(b, V, tmp)
+    out .= @view b[1:n]
+    SciMLBase.build_linear_solution(alg, out, nothing, cache)
+end
+
+function LinearSolve.init_cacheval(alg::ButterflyFactorization, A, b, u, Pl, Pr,
+        maxiters::Int, abstol, reltol, verbose::Bool,
+        assumptions::LinearSolve.OperatorAssumptions)
+    # Return (workspace, placeholder LU); the real factorization is computed on
+    # the first solve!, once `cache.isfresh` is handled there.
+    RecursiveFactorization.🦋workspace(A, b),
+    RecursiveFactorization.lu!(rand(1, 1), Val(false), alg.thread)
+end
+
 end
diff --git a/src/LinearSolve.jl b/src/LinearSolve.jl
index fa2eadaa0..4c94150df 100644
--- a/src/LinearSolve.jl
+++ b/src/LinearSolve.jl
@@ -446,7 +446,7 @@ for kralg in (Krylov.lsmr!, Krylov.craigmr!)
 end
 for alg in (:LUFactorization, :FastLUFactorization, :SVDFactorization, :GenericFactorization,
         :GenericLUFactorization, :SimpleLUFactorization,
-        :RFLUFactorization, :UMFPACKFactorization, :KLUFactorization, :SparspakFactorization,
+        :RFLUFactorization, :ButterflyFactorization, :UMFPACKFactorization, :KLUFactorization, :SparspakFactorization,
         :DiagonalFactorization, :CholeskyFactorization, :BunchKaufmanFactorization,
         :CHOLMODFactorization, :LDLtFactorization, :AppleAccelerateLUFactorization,
         :MKLLUFactorization, :MetalLUFactorization, :CUSOLVERRFFactorization)
@@ -480,7 +480,7 @@ cudss_loaded(A) = false
 is_cusparse(A) = false
 
 export LUFactorization, SVDFactorization, QRFactorization, GenericFactorization,
-       GenericLUFactorization, SimpleLUFactorization, RFLUFactorization,
+       GenericLUFactorization, SimpleLUFactorization, RFLUFactorization, ButterflyFactorization,
        NormalCholeskyFactorization, NormalBunchKaufmanFactorization,
        UMFPACKFactorization, KLUFactorization, FastLUFactorization, FastQRFactorization,
        SparspakFactorization, DiagonalFactorization, CholeskyFactorization,
diff --git a/src/extension_algs.jl b/src/extension_algs.jl
index 51cdb901f..70d373ad2 100644
--- a/src/extension_algs.jl
+++ b/src/extension_algs.jl
@@ -254,6 +254,29 @@ function RFLUFactorization(; pivot = Val(true), thread = Val(true), throwerror =
     RFLUFactorization(pivot, thread; throwerror)
 end
 
+"""
+`ButterflyFactorization()`
+
+A fast, pure-Julia factorization using RecursiveFactorization.jl. Instead of
+pivoting, this method applies random butterfly transformations to `A` and then
+performs an unpivoted LU factorization of the transformed matrix.
+"""
+struct ButterflyFactorization{T} <: AbstractDenseFactorization
+    thread::Val{T}
+    function ButterflyFactorization(thread::Val{T}; throwerror = true) where {T}
+        if !userecursivefactorization(nothing)
+            throwerror &&
+                error("ButterflyFactorization requires that RecursiveFactorization.jl is loaded, i.e. `using RecursiveFactorization`")
+        end
+        new{T}(thread)
+    end
+end
+
+function ButterflyFactorization(; thread = Val(true), throwerror = true)
+    ButterflyFactorization(thread; throwerror)
+end
+
+
 # There's no options like pivot here.
 # But I'm not sure it makes sense as a GenericFactorization
 # since it just uses `LAPACK.getrf!`.
diff --git a/test/butterfly.jl b/test/butterfly.jl
new file mode 100644
index 000000000..9e10ae43d
--- /dev/null
+++ b/test/butterfly.jl
@@ -0,0 +1,35 @@
+using LinearAlgebra, LinearSolve
+using Test
+using RecursiveFactorization
+
+@testset "Random Matrices" begin
+    for i in 490:510
+        A = rand(i, i)
+        b = rand(i)
+        prob = LinearProblem(A, b)
+        x = solve(prob, ButterflyFactorization())
+        @test norm(A * x .- b) <= 1e-6
+    end
+end
+
+function wilkinson(N)
+    A = zeros(N, N)
+    A[1:(N + 1):(N * N)] .= 1
+    A[:, end] .= 1
+    for n in 1:(N - 1)
+        for r in (n + 1):N
+            @inbounds A[r, n] = -1
+        end
+    end
+    A
+end
+
+@testset "Wilkinson" begin
+    for i in 790:810
+        A = wilkinson(i)
+        b = rand(i)
+        prob = LinearProblem(A, b)
+        x = solve(prob, ButterflyFactorization())
+        @test norm(A * x .- b) <= 1e-10
+    end
+end
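For reviewers who want to try the new algorithm, here is a minimal usage sketch. It assumes this PR is applied; `using RecursiveFactorization` is required because the solver lives in the package extension, and the matrix and sizes here are arbitrary:

```julia
using LinearAlgebra, LinearSolve
using RecursiveFactorization  # loads LinearSolveRecursiveFactorizationExt

A = rand(500, 500)
b = rand(500)
prob = LinearProblem(A, b)

# Val(true) requests the threaded LU and triangular solves.
sol = solve(prob, ButterflyFactorization(Val(true)))
norm(A * sol.u .- b)  # residual check, mirroring test/butterfly.jl
```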
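A note on the `wilkinson(N)` test matrix: it has ones on the diagonal and in the last column and -1 below the diagonal, the classic worst case for LU with partial pivoting (element growth of order 2^(N-1)), which makes it a good stress test for a method that skips pivoting entirely. For illustration, the N = 4 instance:

```julia
julia> wilkinson(4)
4×4 Matrix{Float64}:
  1.0   0.0   0.0  1.0
 -1.0   1.0   0.0  1.0
 -1.0  -1.0   1.0  1.0
 -1.0  -1.0  -1.0  1.0
```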