-
-
Notifications
You must be signed in to change notification settings - Fork 615
Closed
FluxML/Zygote.jl
#1555
Description
After updating Flux from version 0.16.1 to 0.16.2, the `gradient` call in the example below never returns; it appears to enter an infinite loop during evaluation.
I think this issue is related to the Zygote version. Similar problems were discussed in #2574.
Minimal Working Example
using Flux
"""
    compute_atomic_energy(input_layer, model)

Evaluate `model` on the single input vector `input_layer` and return the sole element
of its output. The return annotation converts that element to the input element type
`T`. `only` raises an error unless the model output holds exactly one element.
"""
function compute_atomic_energy(
    input_layer::AbstractVector{T},
    model::Flux.Chain,
)::T where {T<:AbstractFloat}
    model_output = model(input_layer)
    return only(model_output)
end
# Sum the scalar energies that `compute_atomic_energy` returns for each row of
# `symm_func_matrix`, evaluating `model` once per row.
#
# Arguments:
# - `symm_func_matrix`: floating-point matrix; every row is passed to the model as one
#   input vector.
# - `model`: the `Flux.Chain` forwarded unchanged to `compute_atomic_energy`.
#
# Returns the total of the per-row energies.
#
# NOTE(review): the reduction has no `init`, so a matrix with zero rows presumably makes
# `sum` throw instead of returning zero — confirm before relying on empty input.
function compute_system_total_energy_scalar(
symm_func_matrix::AbstractMatrix{T},
model::Flux.Chain,
) where {T<:AbstractFloat}
# Rows are consumed through a generator expression rather than a materialized array.
return sum(compute_atomic_energy(row, model) for row in eachrow(symm_func_matrix))
end
"""
    compute_energy_gradients(symm_func_matrix, model)

Differentiate `compute_system_total_energy_scalar` with respect to both of its
arguments and return only the gradient belonging to `model` (the second entry of the
tuple produced by `gradient`).
"""
function compute_energy_gradients(
    symm_func_matrix::AbstractMatrix{T},
    model::Flux.Chain,
) where {T<:AbstractFloat}
    # The problem is here: this is the `gradient` call that is reported to never return.
    _, model_gradient = gradient(
        compute_system_total_energy_scalar, symm_func_matrix, model
    )
    return model_gradient
end
"""
    resnet_block(dim)

Build one residual block of width `dim`: a `SkipConnection` that adds its input to the
output of two consecutive `Dense(dim => dim)` → `LayerNorm(dim)` → `relu` stages,
followed by a final `relu`.
"""
function resnet_block(dim::Int)
    # Layers are constructed in the same order as they appear in the chain.
    first_dense = Dense(dim => dim)
    first_norm = LayerNorm(dim)
    second_dense = Dense(dim => dim)
    second_norm = LayerNorm(dim)
    residual_branch = Chain(
        first_dense, first_norm, relu, second_dense, second_norm, relu
    )
    return Chain(SkipConnection(residual_branch, +), relu)
end
"""
    create_resnet(input_dim, hidden_dim, n_blocks)

Assemble the network as a `Chain` of three parts, in order: an input stem
(`Dense(input_dim => hidden_dim)`, `LayerNorm(hidden_dim)`, `relu`), a `Chain` of
`n_blocks` residual blocks from `resnet_block(hidden_dim)`, and a
`Dense(hidden_dim => 1)` output layer.
"""
function create_resnet(input_dim::Int, hidden_dim::Int, n_blocks::Int)
    stem = Chain(Dense(input_dim => hidden_dim), LayerNorm(hidden_dim), relu)
    blocks = map(_ -> resnet_block(hidden_dim), 1:n_blocks)
    body = Chain(blocks...)
    head = Dense(hidden_dim => 1)
    return Chain(stem, body, head)
end
"""
    model_init()

Create the model used by the reproducer (10 inputs, hidden width 3, 2 residual
blocks), pass it through `f64`, print it with `@show`, and return it.
"""
function model_init()
    # Positional arguments are input_dim, hidden_dim, n_blocks.
    model = f64(create_resnet(10, 3, 2))
    @show model
    return model
end
# Driver for the reproducer. The model is built first, then the random input.
model = model_init()
# 512 rows of 10 random values each; the 10 columns match `input_dim = 10` in `model_init`.
input_matrix = rand(512, 10)
# This is the call that triggers the `gradient` evaluation reported to hang.
resulting_gradient = compute_energy_gradients(input_matrix, model)
Environment
Julia Version: 1.11.3
Flux Version: 0.16.2
Zygote Version: 0.7.3
Operating System: macOS 15.1.1