Skip to content

Commit aa5e49d

Browse files
authored
fit/solver small refactors & type stability (#294)
* codestable resize Refactor equalize_size function to improve handling of data dimensions and simplify logic. * Rename and update equalize_size function Renamed function from _equalize_size to equalize_size and updated its signature. * small refactors & type stability * format
1 parent 434e791 commit aa5e49d

File tree

1 file changed

+15
-14
lines changed

1 file changed

+15
-14
lines changed

src/solver.jl

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -22,18 +22,21 @@ function solver_default(
2222
multithreading = true,
2323
show_progress = true,
2424
) where {T<:Union{Missing,<:Number}}
25-
minfo = Array{IterativeSolvers.ConvergenceHistory,1}(undef, size(data, 1))
25+
n_channels = size(data, 1)
26+
n_predictors = size(X, 2)
27+
minfo = Array{IterativeSolvers.ConvergenceHistory,1}(undef, n_channels)
2628

27-
beta = zeros(T, size(data, 1), size(X, 2)) # had issues with undef
29+
beta = zeros(T, n_channels, n_predictors)
2830

29-
p = Progress(size(data, 1); enabled = show_progress)
30-
X = SparseMatrixCSC(X) # X s often a SubArray, lsmr really doesnt like indexing into subarrays, one copy needed.
31-
@maybe_threads multithreading for ch = 1:size(data, 1)
31+
p = Progress(n_channels; enabled = show_progress)
32+
X_sparse = SparseMatrixCSC(X) # X s often a SubArray, lsmr really doesnt like indexing into subarrays, one copy needed.
33+
34+
@maybe_threads multithreading for ch = 1:n_channels
3235

3336
# use the previous channel as a starting point
34-
ch == 1 || copyto!(view(beta, ch, :), view(beta, ch - 1, :))
37+
ch > 1 && copyto!(view(beta, ch, :), view(beta, ch - 1, :))
3538

36-
h = _lsmr!(beta, X, data, ch)
39+
h = _lsmr!(beta, X_sparse, data, ch)
3740
minfo[ch] = h
3841
next!(p)
3942
end
@@ -55,13 +58,11 @@ function solver_default(
5558
multithreading = true,
5659
show_progress = true,
5760
) where {T<:Union{Missing,<:Number}}
58-
#beta = Array{Union{Missing,Number}}(undef, size(data, 1), size(data, 2), size(X, 2))
59-
beta = zeros(T, size(data, 1), size(data, 2), size(X, 2))
60-
p = Progress(size(data, 1); enabled = show_progress)
61-
@maybe_threads multithreading for ch = 1:size(data, 1)
62-
for t = 1:size(data, 2)
63-
# @debug("$(ndims(data,)),$t,$ch")
64-
61+
n_channels, n_times, n_predictors = size(data, 1), size(data, 2), size(X, 2)
62+
beta = zeros(T, n_channels, n_times, n_predictors)
63+
p = Progress(n_channels; enabled = show_progress)
64+
@maybe_threads multithreading for ch = 1:n_channels
65+
for t = 1:n_times
6566
dd = view(data, ch, t, :)
6667
ix = @. !ismissing(dd)
6768

0 commit comments

Comments
 (0)