
Infinite time of gradient #2585

@mposysoev

Description


After updating Flux from version 0.16.1 to 0.16.2, the gradient function no longer returns: evaluation runs indefinitely, as if stuck in an infinite loop.

I suspect the issue is related to the Zygote version. Similar problems were discussed in #2574.
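One way to confirm the regression is to run the example once per Flux release, each in a throwaway environment. This also records which Zygote release the resolver picks for each one. The following is a minimal sketch under stated assumptions. The script name compare_flux.jl is hypothetical. The file mwe.jl is assumed to contain the minimal example from this issue, unchanged. Pkg.activate, Pkg.add and Pkg.status are standard Pkg functions. Each release needs a fresh Julia process, because a package that is already loaded cannot be swapped within the same session.

```julia
# compare_flux.jl: hypothetical helper script, not part of Flux.
# Run it in a fresh Julia process for each release, e.g.
#   julia compare_flux.jl 0.16.1
#   julia compare_flux.jl 0.16.2
using Pkg

# Flux release under test (defaults to 0.16.2 when no argument is given).
flux_version = isempty(ARGS) ? "0.16.2" : ARGS[1]

Pkg.activate(; temp = true)                      # temporary environment; the real project is untouched
Pkg.add(name = "Flux", version = flux_version)   # pin exactly this Flux release

# Zygote is an indirect dependency, so look it up in the manifest, not the project.
Pkg.status(["Flux", "Zygote"]; mode = Pkg.PKGMODE_MANIFEST)

# Assumption: mwe.jl (next to this script) holds the minimal example from this issue.
@time include("mwe.jl")
```

If the timing line prints for 0.16.1 but never appears for 0.16.2, the two Zygote versions reported by Pkg.status are the first thing to compare.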

Minimal Working Example

using Flux

function compute_atomic_energy(
    input_layer::AbstractVector{T},
    model::Flux.Chain,
)::T where {T<:AbstractFloat}
    return only(model(input_layer))
end

function compute_system_total_energy_scalar(
    symm_func_matrix::AbstractMatrix{T},
    model::Flux.Chain,
) where {T<:AbstractFloat}
    return sum(compute_atomic_energy(row, model) for row in eachrow(symm_func_matrix))
end

function compute_energy_gradients(
    symm_func_matrix::AbstractMatrix{T},
    model::Flux.Chain,
) where {T<:AbstractFloat}
    # THE PROBLEM IS HERE ##################################################################
    gs = gradient(compute_system_total_energy_scalar, symm_func_matrix, model)
    # ######################################################################################
    energy_gradients = gs[2]  # gradient with respect to the model parameters
    return energy_gradients
end

function resnet_block(dim::Int)
    skip_chain = Chain(
        Dense(dim => dim),
        LayerNorm(dim),
        relu,
        Dense(dim => dim),
        LayerNorm(dim),
        relu,
    )
    return Chain(SkipConnection(skip_chain, +), relu)
end

function create_resnet(input_dim::Int, hidden_dim::Int, n_blocks::Int)
    initial_layers = Chain(Dense(input_dim => hidden_dim), LayerNorm(hidden_dim), relu)

    res_blocks = Chain([resnet_block(hidden_dim) for _ = 1:n_blocks]...)
    final_layer = Dense(hidden_dim => 1)

    return Chain(initial_layers, res_blocks, final_layer)
end

function model_init()
    input_dim = 10
    hidden_dim = 3
    n_blocks = 2
    model = create_resnet(input_dim, hidden_dim, n_blocks)

    model = f64(model)
    @show model

    return model
end

model = model_init()

input_matrix = rand(512, 10)

resulting_gradient = compute_energy_gradients(input_matrix, model)

Environment

Julia Version: 1.11.3
Flux Version: 0.16.2
Zygote Version: 0.7.3
Operating System: macOS 15.1.1
