Description
Consider the reproduction below. If `validateVJPWithError` uses the more generic `S.TangentVector: FloatingPoint` constraint instead of the same-type requirement `S.TangentVector == Double`, the printed pullback value is correct. To see this, comment out the `S.TangentVector == Double` line in the reproduction and uncomment the preceding one. It turns out that abstraction differences in derivatives are not taken into account. The function expects the pullback to return its value indirectly, while the pullback actually returns it directly.
Indeed, the SIL at the call site is:
%28 = function_ref @$s4conf20validateVJPWithError2of2atyq_xYjrXE_xt16_Differentiation14DifferentiableRzSFR_AeFR_13TangentVectorAeFPQy_Rs_SdAGRtzr0_lF : $@convention(thin) <τ_0_0, τ_0_1 where τ_0_0 : Differentiable, τ_0_1 : FloatingPoint, τ_0_1 : Differentiable, τ_0_1 == τ_0_1.TangentVector, τ_0_0.TangentVector == Double> (@guaranteed @differentiable(reverse) @noescape @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <τ_0_0, τ_0_1>, @in_guaranteed τ_0_0) -> () // user: %29
%29 = apply %28<Double, Double>(%25, %26) : $@convention(thin) <τ_0_0, τ_0_1 where τ_0_0 : Differentiable, τ_0_1 : FloatingPoint, τ_0_1 : Differentiable, τ_0_1 == τ_0_1.TangentVector, τ_0_0.TangentVector == Double> (@guaranteed @differentiable(reverse) @noescape @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <τ_0_0, τ_0_1>, @in_guaranteed τ_0_0) -> ()
Note that `validateVJPWithError` accepts a differentiable function that returns its result indirectly (`@out τ_0_1`). We pass the following function (`%25`) to it:
%20 = differentiable_function_extract [vjp] %9 // user: %22
// function_ref thunk for @callee_guaranteed (@unowned Double) -> (@unowned Double, @owned @escaping @callee_guaranteed (@unowned Double) -> (@unowned Double))
%21 = function_ref @$sS4dIegyd_Igydo_S2dxSdRi_zRi0_zlySdIsegnd_Iegnro_TR : $@convention(thin) (@in_guaranteed Double, @guaranteed @noescape @callee_guaranteed (Double) -> (Double, @owned @callee_guaranteed (Double) -> Double)) -> (@out Double, @owned @callee_guaranteed @substituted <τ_0_0> (@in_guaranteed τ_0_0) -> Double for <Double>) // user: %22
%22 = partial_apply [callee_guaranteed] %21(%20) : $@convention(thin) (@in_guaranteed Double, @guaranteed @noescape @callee_guaranteed (Double) -> (Double, @owned @callee_guaranteed (Double) -> Double)) -> (@out Double, @owned @callee_guaranteed @substituted <τ_0_0> (@in_guaranteed τ_0_0) -> Double for <Double>) // user: %23
%23 = convert_function %22 to $@callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @owned @callee_guaranteed @substituted <τ_0_0> (@in_guaranteed τ_0_0) -> Double for <τ_0_2>) for <Double, Double, Double> // users: %34, %24
%24 = convert_escape_to_noescape %23 to $@noescape @callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @owned @callee_guaranteed @substituted <τ_0_0> (@in_guaranteed τ_0_0) -> Double for <τ_0_2>) for <Double, Double, Double> // user: %25
%25 = differentiable_function [parameters 0] [results 0] %14 with_derivative {%19, %24} // user: %29
Note that:
- The VJP type is `$@noescape @callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @owned @callee_guaranteed @substituted <τ_0_0> (@in_guaranteed τ_0_0) -> Double for <τ_0_2>) for <Double, Double, Double>`. The VJP therefore returns its original result indirectly (`@out τ_0_1`), but the pullback it returns produces its result directly (`-> Double`).
- However, `validateVJPWithError<A, B>(of:at:)` expects the pullback to return its value indirectly (`@out Double`):
// function_ref valueWithPullback<A, B>(at:of:)
%7 = function_ref @$s16_Differentiation17valueWithPullback2at2ofq_0B0_13TangentVectorQzAFQy_c8pullbacktx_q_xYjrXEtAA14DifferentiableRzAaJR_r0_lF : $@convention(thin) <τ_0_0, τ_0_1 where τ_0_0 : Differentiable, τ_0_1 : Differentiable> (@in_guaranteed τ_0_0, @guaranteed @differentiable(reverse) @noescape @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <τ_0_0, τ_0_1>) -> (@out τ_0_1, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <τ_0_1.TangentVector, τ_0_0.TangentVector>) // user: %8
%8 = apply %7<S, T>(%5, %1, %0) : $@convention(thin) <τ_0_0, τ_0_1 where τ_0_0 : Differentiable, τ_0_1 : Differentiable> (@in_guaranteed τ_0_0, @guaranteed @differentiable(reverse) @noescape @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <τ_0_0, τ_0_1>) -> (@out τ_0_1, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <τ_0_1.TangentVector, τ_0_0.TangentVector>) // user: %9
%9 = convert_function %8 to $@callee_guaranteed (@in_guaranteed T) -> @out Double // user: %11
// function_ref thunk for @escaping @callee_guaranteed (@in_guaranteed B) -> (@out Double)
%10 = function_ref @$sq_SdIegnr_q_SdIegnd_16_Differentiation14DifferentiableRzSFR_AaBR_13TangentVectorAaBPQy_Rs_SdACRtzr0_lTR : $@convention(thin) <τ_0_0, τ_0_1 where τ_0_0 : Differentiable, τ_0_1 : FloatingPoint, τ_0_1 : Differentiable, τ_0_1 == τ_0_1.TangentVector, τ_0_0.TangentVector == Double> (@in_guaranteed τ_0_1, @guaranteed @callee_guaranteed (@in_guaranteed τ_0_1) -> @out Double) -> Double // user: %11
%11 = partial_apply [callee_guaranteed] %10<S, T>(%9) : $@convention(thin) <τ_0_0, τ_0_1 where τ_0_0 : Differentiable, τ_0_1 : FloatingPoint, τ_0_1 : Differentiable, τ_0_1 == τ_0_1.TangentVector, τ_0_0.TangentVector == Double> (@in_guaranteed τ_0_1, @guaranteed @callee_guaranteed (@in_guaranteed τ_0_1) -> @out Double) -> Double // user: %12
%12 = convert_function %11 to $@callee_guaranteed @substituted <τ_0_0> (@in_guaranteed τ_0_0) -> Double for <T> // user: %13
As a result, there is an abstraction mismatch and a garbage value is returned. By contrast, removing the same-type requirement yields the following VJP type. Its thunk performs the necessary reabstraction, so the pullback returns its result indirectly (`@out`):
%20 = differentiable_function_extract [vjp] %9 // user: %22
// function_ref thunk for @callee_guaranteed (@unowned Double) -> (@unowned Double, @owned @escaping @callee_guaranteed (@unowned Double) -> (@unowned Double))
%21 = function_ref @$sS4dIegyd_Igydo_S2dxq_Ri_zRi0_zRi__Ri0__r0_lyS2dIsegnr_Iegnro_TR : $@convention(thin) (@in_guaranteed Double, @guaranteed @noescape @callee_guaranteed (Double) -> (Double, @owned @callee_guaranteed (Double) -> Double)) -> (@out Double, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <Double, Double>) // user: %22
%22 = partial_apply [callee_guaranteed] %21(%20) : $@convention(thin) (@in_guaranteed Double, @guaranteed @noescape @callee_guaranteed (Double) -> (Double, @owned @callee_guaranteed (Double) -> Double)) -> (@out Double, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <Double, Double>) // user: %23
%23 = convert_function %22 to $@callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2, τ_0_3> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <τ_0_2, τ_0_3>) for <Double, Double, Double, Double> // users: %34, %24
%24 = convert_escape_to_noescape %23 to $@noescape @callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2, τ_0_3> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <τ_0_2, τ_0_3>) for <Double, Double, Double, Double> // user: %25
Reproduction
import _Differentiation
/// Returns the reciprocal `1 / x`; this is the function whose derivative the
/// reproduction checks.
///
/// `@differentiable(reverse)` makes a reverse-mode derivative of this function
/// available. NOTE(review): no `@derivative(of: oneOverX)` attribute appears in
/// this reproduction, so presumably the compiler synthesizes the derivative
/// rather than using `_vjpOneOverX`. Confirm.
///
/// `@_silgen_name` fixes the SIL symbol name to `_oneOverX` (presumably so it
/// is easy to locate in SIL dumps).
@differentiable(reverse)
@_silgen_name("_oneOverX")
func oneOverX(_ x: Double) -> Double {
1 / x
}
/// Hand-written VJP for `oneOverX`.
///
/// Returns `1 / x` together with a pullback that maps a cotangent `v` to
/// `-v / (x * x)`, i.e. the derivative of `1 / x` scaled by `v`.
///
/// NOTE(review): this function has no `@derivative(of: oneOverX)` attribute,
/// so the differentiation transform presumably does not use it. Confirm
/// whether registering it was intended for the reproduction.
@_silgen_name("_vjpOneOverX")
func _vjpOneOverX(_ x: Double) -> (value: Double, pullback: (Double) -> Double) {
    // Nested function capturing `x`; it escapes as the returned pullback.
    func pullback(_ cotangent: Double) -> Double {
        return -cotangent / (x * x)
    }
    let reciprocal: Double = 1 / x
    return (value: reciprocal, pullback: pullback)
}
/// Writes the textual representation of `value` to standard output, followed
/// by a newline.
///
/// `@inline(never)` keeps this an out-of-line call (presumably so the printed
/// value is not optimized away at the call site).
@_silgen_name("_pp")
@inline(never)
func pp<T>(_ value: T) {
    Swift.print(value, terminator: "\n")
}
/// Computes the pullback of `function` at `point`, applies it to a cotangent
/// of 1, and prints the resulting `S.TangentVector`.
///
/// This is the core of the reproduction. With the `S.TangentVector == Double`
/// same-type requirement, this function expects the pullback to return its
/// result indirectly (`@out`). The VJP it receives, however, provides a
/// pullback that returns its result directly, so a garbage value is printed.
/// Using the commented-out `S.TangentVector: FloatingPoint` requirement
/// instead makes the printed value correct.
///
/// - Parameters:
///   - function: The differentiable function whose pullback is evaluated.
///   - point: The point at which `function` is differentiated.
@inline(never)
func validateVJPWithError<S, T>(
of function: @differentiable(reverse) (S) -> T,
at point: S
) where S: Differentiable, T: Differentiable, T: FloatingPoint
, T == T.TangentVector
// Working variant: enable this requirement and disable the next one.
//, S.TangentVector: FloatingPoint
// Failing variant: this same-type requirement triggers the abstraction mismatch.
, S.TangentVector == Double
{
// Original value and pullback of `function` at `point`.
let vwpb = valueWithPullback(at: point, of: function)
// `T == T.TangentVector`, so `.init(1)` builds a cotangent of 1 of type `T`.
let pullback = vwpb.pullback(.init(1))
pp(pullback)
}
/// Runs the reproduction. It checks the derivative of `oneOverX` at `x` via
/// `validateVJPWithError`, which prints the pullback applied to a cotangent
/// of 1.
///
/// For a correct derivative of `1 / x`, the printed value is `-1 / (x * x)`.
@_silgen_name("_testOneOverX")
func testOneOverX(_ x: Double) {
// NOTE(review): `oneOverX` is wrapped in a closure. The SIL quoted in the
// report was presumably produced from this exact form, so keep it as-is when
// reproducing.
validateVJPWithError(of: { x in oneOverX(x) },
at: x)
}
// Entry point. With a correct pullback this prints about -0.01 (= -1 / 10²);
// with the bug, an unrelated garbage value is printed instead.
testOneOverX(10.0)
Expected behavior
The correct pullback value is printed. It would also be great if the SIL verifier caught this abstraction mismatch.
Environment
Swift version 6.2-dev (LLVM e404f8897f17aff, Swift b47b157)
Target: arm64-apple-macosx15.0
Additional information
No response