Skip to content

[AutoDiff] Invalid derivative type calculation for same-type conformance #78358

Open
@asl

Description

@asl

Description

Consider the reproduction below. If the more generic requirement S.TangentVector : FloatingPoint is used in validateVJPWithError instead of the same-type requirement S.TangentVector == Double, then the pullback value printed is correct (just comment out the S.TangentVector == Double line in the reproduction and uncomment the preceding one). It turns out that abstraction differences in derivatives are not taken into account: the function expects the pullback to return its value indirectly, while the pullback actually returns it directly.

Indeed, we have:

  %28 = function_ref @$s4conf20validateVJPWithError2of2atyq_xYjrXE_xt16_Differentiation14DifferentiableRzSFR_AeFR_13TangentVectorAeFPQy_Rs_SdAGRtzr0_lF : $@convention(thin) <τ_0_0, τ_0_1 where τ_0_0 : Differentiable, τ_0_1 : FloatingPoint, τ_0_1 : Differentiable, τ_0_1 == τ_0_1.TangentVector, τ_0_0.TangentVector == Double> (@guaranteed @differentiable(reverse) @noescape @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <τ_0_0, τ_0_1>, @in_guaranteed τ_0_0) -> () // user: %29
  %29 = apply %28<Double, Double>(%25, %26) : $@convention(thin) <τ_0_0, τ_0_1 where τ_0_0 : Differentiable, τ_0_1 : FloatingPoint, τ_0_1 : Differentiable, τ_0_1 == τ_0_1.TangentVector, τ_0_0.TangentVector == Double> (@guaranteed @differentiable(reverse) @noescape @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <τ_0_0, τ_0_1>, @in_guaranteed τ_0_0) -> ()

Note that validateVJPWithError accepts a differentiable function that returns its result indirectly (@out τ_0_1). We are passing the following function to it (%25):

  %20 = differentiable_function_extract [vjp] %9  // user: %22
  // function_ref thunk for @callee_guaranteed (@unowned Double) -> (@unowned Double, @owned @escaping @callee_guaranteed (@unowned Double) -> (@unowned Double))
  %21 = function_ref @$sS4dIegyd_Igydo_S2dxSdRi_zRi0_zlySdIsegnd_Iegnro_TR : $@convention(thin) (@in_guaranteed Double, @guaranteed @noescape @callee_guaranteed (Double) -> (Double, @owned @callee_guaranteed (Double) -> Double)) -> (@out Double, @owned @callee_guaranteed @substituted <τ_0_0> (@in_guaranteed τ_0_0) -> Double for <Double>) // user: %22
  %22 = partial_apply [callee_guaranteed] %21(%20) : $@convention(thin) (@in_guaranteed Double, @guaranteed @noescape @callee_guaranteed (Double) -> (Double, @owned @callee_guaranteed (Double) -> Double)) -> (@out Double, @owned @callee_guaranteed @substituted <τ_0_0> (@in_guaranteed τ_0_0) -> Double for <Double>) // user: %23
  %23 = convert_function %22 to $@callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @owned @callee_guaranteed @substituted <τ_0_0> (@in_guaranteed τ_0_0) -> Double for <τ_0_2>) for <Double, Double, Double> // users: %34, %24
  %24 = convert_escape_to_noescape %23 to $@noescape @callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @owned @callee_guaranteed @substituted <τ_0_0> (@in_guaranteed τ_0_0) -> Double for <τ_0_2>) for <Double, Double, Double> // user: %25
  %25 = differentiable_function [parameters 0] [results 0] %14 with_derivative {%19, %24} // user: %29

Note that:

  • The VJP type is $@noescape @callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @owned @callee_guaranteed @substituted <τ_0_0> (@in_guaranteed τ_0_0) -> Double for <τ_0_2>) for <Double, Double, Double>, so the value returned by the VJP is returned indirectly. However, the pullback's result is returned directly.
  • validateVJPWithError<A, B>(of:at:) expects the pullback to return its value indirectly:
  // function_ref valueWithPullback<A, B>(at:of:)
  %7 = function_ref @$s16_Differentiation17valueWithPullback2at2ofq_0B0_13TangentVectorQzAFQy_c8pullbacktx_q_xYjrXEtAA14DifferentiableRzAaJR_r0_lF : $@convention(thin) <τ_0_0, τ_0_1 where τ_0_0 : Differentiable, τ_0_1 : Differentiable> (@in_guaranteed τ_0_0, @guaranteed @differentiable(reverse) @noescape @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <τ_0_0, τ_0_1>) -> (@out τ_0_1, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <τ_0_1.TangentVector, τ_0_0.TangentVector>) // user: %8
  %8 = apply %7<S, T>(%5, %1, %0) : $@convention(thin) <τ_0_0, τ_0_1 where τ_0_0 : Differentiable, τ_0_1 : Differentiable> (@in_guaranteed τ_0_0, @guaranteed @differentiable(reverse) @noescape @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <τ_0_0, τ_0_1>) -> (@out τ_0_1, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <τ_0_1.TangentVector, τ_0_0.TangentVector>) // user: %9
  %9 = convert_function %8 to $@callee_guaranteed (@in_guaranteed T) -> @out Double // user: %11
  // function_ref thunk for @escaping @callee_guaranteed (@in_guaranteed B) -> (@out Double)
  %10 = function_ref @$sq_SdIegnr_q_SdIegnd_16_Differentiation14DifferentiableRzSFR_AaBR_13TangentVectorAaBPQy_Rs_SdACRtzr0_lTR :$@convention(thin) <τ_0_0, τ_0_1 where τ_0_0 : Differentiable, τ_0_1 : FloatingPoint, τ_0_1 : Differentiable, τ_0_1 == τ_0_1.TangentVector, τ_0_0.TangentVector == Double> (@in_guaranteed τ_0_1, @guaranteed @callee_guaranteed (@in_guaranteed τ_0_1) -> @out Double) -> Double // user: %11
  %11 = partial_apply [callee_guaranteed] %10<S, T>(%9) : $@convention(thin) <τ_0_0, τ_0_1 where τ_0_0 : Differentiable, τ_0_1 : FloatingPoint, τ_0_1 : Differentiable, τ_0_1 == τ_0_1.TangentVector, τ_0_0.TangentVector == Double> (@in_guaranteed τ_0_1, @guaranteed @callee_guaranteed (@in_guaranteed τ_0_1) -> @out Double) -> Double // user: %12
  %12 = convert_function %11 to $@callee_guaranteed @substituted <τ_0_0> (@in_guaranteed τ_0_0) -> Double for <T> // user: %13

As a result, we have an abstraction difference and a junk value is returned. By contrast, removing the same-type requirement yields the following VJP type, which introduces the necessary reabstraction conversion:

  %20 = differentiable_function_extract [vjp] %9  // user: %22
  // function_ref thunk for @callee_guaranteed (@unowned Double) -> (@unowned Double, @owned @escaping @callee_guaranteed (@unowned Double) -> (@unowned Double))
  %21 = function_ref @$sS4dIegyd_Igydo_S2dxq_Ri_zRi0_zRi__Ri0__r0_lyS2dIsegnr_Iegnro_TR : $@convention(thin)(@in_guaranteed Double, @guaranteed @noescape @callee_guaranteed (Double) -> (Double, @owned @callee_guaranteed (Double) -> Double)) -> (@out Double, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <Double, Double>) // user: %22
  %22 = partial_apply [callee_guaranteed] %21(%20) : $@convention(thin) (@in_guaranteed Double, @guaranteed @noescape @callee_guaranteed (Double) -> (Double, @owned @callee_guaranteed (Double) -> Double)) -> (@out Double, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <Double, Double>) // user: %23
  %23 = convert_function %22 to $@callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2, τ_0_3> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <τ_0_2, τ_0_3>) for <Double, Double, Double, Double> // users: %34, %24
  %24 = convert_escape_to_noescape %23 to $@noescape @callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2, τ_0_3> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <τ_0_2, τ_0_3>) for <Double, Double, Double, Double> // user: %25

Reproduction

import _Differentiation

/// Returns the reciprocal of `x`, i.e. `1 / x`.
///
/// Declared `@differentiable(reverse)` so it can be used where a
/// reverse-mode differentiable function is required. `@_silgen_name`
/// gives the function the fixed symbol name `_oneOverX`.
@differentiable(reverse)
@_silgen_name("_oneOverX")
func oneOverX(_ x: Double) -> Double {
    1 / x
}

/// Hand-written value/pullback pair for `1 / x`.
///
/// - Parameter x: The point at which the value and pullback are formed.
/// - Returns: The value `1 / x` together with a pullback closure that maps
///   an incoming cotangent `v` to `-v / (x * x)`.
///
/// NOTE(review): no `@derivative(of:)` attribute is visible on this function,
/// so it does not appear to be registered as the derivative of `oneOverX`, and
/// nothing in the visible code calls it — confirm whether that is intentional
/// for this reproduction.
@_silgen_name("_vjpOneOverX")
func _vjpOneOverX(_ x: Double) -> (value: Double, pullback: (Double) -> Double) {
    (
        value: 1 / x,
        pullback: { v in
            -v / (x * x)
        }
    )
}

/// Prints `v` using `print`.
///
/// The function is generic over `T` and marked `@inline(never)`.
/// NOTE(review): presumably this keeps the call (and the way its argument is
/// passed) visible in the emitted SIL rather than being inlined away — confirm.
@_silgen_name("_pp")
@inline(never)
func pp<T>(_ v : T) {
  print(v)
}

/// Computes the pullback of `function` at `point`, applies it to a seed of
/// `1`, and prints the result.
///
/// - Parameters:
///   - function: A reverse-mode differentiable function from `S` to `T`.
///   - point: The input at which the value and pullback are computed.
///
/// Per the issue description, the same-type requirement
/// `S.TangentVector == Double` is what triggers the incorrect output; using
/// the commented-out `S.TangentVector: FloatingPoint` requirement instead
/// makes the printed value correct.
@inline(never)
func validateVJPWithError<S, T>(
    of function: @differentiable(reverse) (S) -> T,
    at point: S
) where S: Differentiable, T: Differentiable, T: FloatingPoint
    , T == T.TangentVector
    // Alternative (working) requirement; swap it with the line below to
    // compare behavior.
    //, S.TangentVector: FloatingPoint
    , S.TangentVector == Double
    {
    // Obtain the function value together with its pullback closure.
    let vwpb = valueWithPullback(at: point, of: function)
    // Apply the pullback to a seed of 1. Despite its name, `pullback` holds
    // the resulting value, not the closure.
    let pullback = vwpb.pullback(.init(1))

    pp(pullback)
}

/// Runs `validateVJPWithError` on `oneOverX` at the point `x`.
///
/// `oneOverX` is wrapped in a closure rather than passed directly; the closure
/// is what gets passed as the differentiable function argument.
/// `@_silgen_name` gives the function the fixed symbol name `_testOneOverX`.
@_silgen_name("_testOneOverX")
func testOneOverX(_ x: Double) {
    validateVJPWithError(of: { x in oneOverX(x) }, 
                        at: x)
}

// Entry point of the reproduction: evaluate at x = 10.
// Mathematically, d/dx (1 / x) at x = 10 is -1 / 100 = -0.01, which is the
// value a correct pullback seeded with 1 should produce.
testOneOverX(10.0)

Expected behavior

The correct pullback value is printed. It would also be great if the SIL verifier caught this abstraction difference.

Environment

Swift version 6.2-dev (LLVM e404f8897f17aff, Swift b47b157)
Target: arm64-apple-macosx15.0

Additional information

No response

Metadata

Metadata

Assignees

No one assigned

    Labels

    AutoDiff · bug: A deviation from expected or documented behavior. Also: expected but undesirable behavior.

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions