-
Notifications
You must be signed in to change notification settings - Fork 10.5k
Closed
Description
Previous ID | SR-12151 |
Radar | None |
Original Reporter | @dan-zheng |
Type | Sub-task |
Status | Resolved |
Resolution | Done |
Additional Detail from JIRA
Votes | 0 |
Component/s | |
Labels | Sub-task |
Assignee | @dan-zheng |
Priority | Medium |
md5: 7cd503fb0d18c97a1b671a7f91c7b875
Parent-Task:
- SR-12148 Make differentiation work with class types
Issue Description:
(previously TF-654)
Support differentiation of class initializers:
/// Base class for the class-initializer differentiation reproducer.
///
/// `base` is the only stored value read by `f`; `_nontrivial` exists only as a
/// workaround for TF-648 (see the FIXME below).
class Super : Differentiable {
    var base: Float
    // FIXME(TF-648): Dummy to make `Super.AllDifferentiableVariables` be nontrivial.
    var _nontrivial: [Float] = []

    // TODO(TF-645): Remove `vjpInit` when differentiation supports `ref_element_addr`.
    @differentiable(vjp: vjpInit)
    required init(base: Float) {
        self.base = base
    }

    /// Custom VJP registered for `init(base:)`.
    ///
    /// - Parameter base: Initial value for `base`.
    /// - Returns: The constructed instance, and a pullback that maps a cotangent
    ///   of the instance to its `base` component.
    static func vjpInit(base: Float) -> (Super, (TangentVector) -> Float) {
        return (Super(base: base), { x in x.base })
        // Disabled alternative pullback: ignores its input and returns a constant 300.
        // return (Super(base: base), { _ in 300 })
    }

    /// Returns `base * x`; differentiable with respect to both `self` and `x`.
    @differentiable(wrt: (self, x), jvp: jvpf, vjp: vjpf)
    func f(_ x: Float) -> Float {
        return base * x
    }

    /// Custom JVP for `f`.
    ///
    /// - Parameter x: The point at which `f` is evaluated.
    /// - Returns: `f(x)` and the differential `(dself, dx) -> Float`.
    ///   By the product rule, d(base * x) = dbase * x + base * dx. The original
    ///   returned only `dself.base * dx`, which disagreed with `vjpf`.
    final func jvpf(_ x: Float) -> (Float, (TangentVector, Float) -> Float) {
        // Copy `base` so the escaping differential does not capture `self` (same as `vjpf`).
        let base = self.base
        return (f(x), { (dself, dx) in dself.base * x + base * dx })
    }

    /// Custom VJP for `f`.
    ///
    /// - Parameter x: The point at which `f` is evaluated.
    /// - Returns: `f(x)` and a pullback mapping a cotangent `v` to
    ///   (`v * x` for `base`, `base * v` for `x`).
    final func vjpf(_ x: Float) -> (Float, (Float) -> (TangentVector, Float)) {
        let base = self.base
        return (f(x), { v in
            (TangentVector(base: v * x, _nontrivial: []), base * v)
        })
    }
}
// Subclass of `Super` that overrides `f` as 3 * x (it does not read `base`).
// It declares no initializer, so `SubOverride(base:)` below uses the `init(base:)`
// inherited from `Super`.
class SubOverride : Super {
// No jvp:/vjp: arguments here, so no custom derivative is registered for this override.
// Presumably the differentiation transform must synthesize one; confirm.
@differentiable(wrt: (self, x))
override func f(_ x: Float) -> Float {
return 3 * x
}
}
// Subclass of `Super` that overrides `f` as 3 * x and registers two derivative
// configurations for it: one w.r.t. (self, x) without custom derivatives, and
// one w.r.t. `x` only that uses `jvpf2` / `vjpf2`.
class SubOverrideCustomDerivatives : Super {
@differentiable(wrt: (self, x))
@differentiable(wrt: x, jvp: jvpf2, vjp: vjpf2)
override func f(_ x: Float) -> Float {
return 3 * x
}
// JVP of `f` w.r.t. `x`: returns f(x) and the differential dx -> 3 * dx.
final func jvpf2(_ x: Float) -> (Float, (Float) -> Float) {
return (f(x), { v in 3 * v })
}
// VJP of `f` w.r.t. `x`: returns f(x) and the pullback v -> 3 * v.
final func vjpf2(_ x: Float) -> (Float, (Float) -> Float) {
return (f(x), { v in 3 * v })
}
}
// Cotangent seed for the pullbacks below: its `base` component is 100.
let v = Super.TangentVector(base: 100, _nontrivial: [])
// Pull `v` back through each class's `init(base:)` (called via a closure) at base = 1337,
// and print the derivative with respect to the initializer argument.
// NOTE(review): if `vjpInit`'s pullback `{ x in x.base }` is used, the `Super` line prints 100.
// What the two subclass lines print depends on how the inherited initializer's derivative
// is resolved; confirm.
print(pullback(at: 1337) { x in Super(base: x) }(v))
print(pullback(at: 1337) { x in SubOverride(base: x) }(v))
print(pullback(at: 1337) { x in SubOverrideCustomDerivatives(base: x) }(v))
Things are a bit tricky because classes have allocators in addition to initializers.
SIL allocators take a `@thick Class.Type` metatype.
SIL initializers take an existing class instance.
An AST-level `ConstructorDecl` lowers to both an allocator and an initializer, so both SIL functions get a `[differentiable]` attribute lowered from the same `@differentiable` attribute.
// Super.__allocating_init(base:)
// Allocator: takes the `base` argument plus a `@thick Super.Type` metatype,
// allocates the object with `alloc_ref`, and forwards to the initializer
// `Super.init(base:)` ($s5error5SuperC4baseACSf_tcfc). The metatype argument %1
// is not used by the body.
// Its `[differentiable]` attribute registers `Super.vjpInit` as the VJP; vjpInit's
// (Float, @thick Super.Type) signature matches this allocator's signature.
sil hidden [differentiable source 0 wrt 0 vjp @$s5error5SuperC7vjpInit4baseAC_SfAC26AllDifferentiableVariablesVctSf_tFZ]
[ossa] @$s5error5SuperC4baseACSf_tcfC : $@convention(method) (Float, @thick Super.Type) -> @owned Super {
// %0 // user: %4
bb0(%0 : $Float, %1 : $@thick Super.Type):
%2 = alloc_ref $Super // user: %4
// function_ref Super.init(base:)
%3 = function_ref @$s5error5SuperC4baseACSf_tcfc : $@convention(method) (Float, @owned Super) -> @owned Super // user: %4
%4 = apply %3(%0, %2) : $@convention(method) (Float, @owned Super) -> @owned Super // user: %5
return %4 : $Super // id: %5
} // end sil function '$s5error5SuperC4baseACSf_tcfC'
// Super.init(base:)
// Initializer: takes the `base` argument and an existing `@owned Super` instance
// (no metatype). It initializes `_nontrivial` from its default-value expression,
// assigns `base`, and returns a copy of `self`.
// NOTE: this carries the same `[differentiable ... vjp @...vjpInit...]` attribute as
// the allocator. vjpInit takes `@thick Super.Type`, not `@owned Super`, which is the
// type mismatch reported by the SIL verifier below.
sil hidden [differentiable source 0 wrt 0 vjp @$s5error5SuperC7vjpInit4baseAC_SfAC26AllDifferentiableVariablesVctSf_tFZ] [ossa] @$s5error5SuperC4baseACSf_tcfc : $@convention(method) (Float, @owned Super) -> @owned Super {
// %0 // users: %16, %2
// %1 // users: %4, %3
bb0(%0 : $Float, %1 : @owned $Super):
debug_value %0 : $Float, let, name "base", argno 1 // id: %2
debug_value %1 : $Super, let, name "self", argno 2 // id: %3
%4 = mark_uninitialized [rootself] %1 : $Super // users: %20, %19, %13, %7
// self._nontrivial = <default value> (stored through ref_element_addr)
// function_ref variable initialization expression of Super._nontrivial
%5 = function_ref @$s5error5SuperC11_nontrivialSaySfGvpfi : $@convention(thin) () -> @owned Array<Float> // user: %6
%6 = apply %5() : $@convention(thin) () -> @owned Array<Float> // user: %10
%7 = begin_borrow %4 : $Super // users: %12, %8
%8 = ref_element_addr %7 : $Super, #Super._nontrivial // user: %9
%9 = begin_access [modify] [dynamic] %8 : $*Array<Float> // users: %11, %10
assign %6 to %9 : $*Array<Float> // id: %10
end_access %9 : $*Array<Float> // id: %11
end_borrow %7 : $Super // id: %12
// self.base = base (stored through ref_element_addr)
%13 = begin_borrow %4 : $Super // users: %18, %14
%14 = ref_element_addr %13 : $Super, #Super.base // user: %15
%15 = begin_access [modify] [dynamic] %14 : $*Float // users: %17, %16
assign %0 to %15 : $*Float // id: %16
end_access %15 : $*Float // id: %17
end_borrow %13 : $Super // id: %18
// Return a copy of self and destroy the original owned value.
%19 = copy_value %4 : $Super // user: %21
destroy_value %4 : $Super // id: %20
return %19 : $Super // id: %21
} // end sil function '$s5error5SuperC4baseACSf_tcfc'
However, if a JVP/VJP is registered in the `@differentiable` attribute, a type mismatch occurs for the initializer (which doesn't take a metatype):
SIL verification failed: VJP type does not match expected VJP type
$@convention(method) (Float, @thick Super.Type) -> (@owned Super, @owned @callee_guaranteed (@guaranteed Super.AllDifferentiableVariables) -> Float)
$@convention(method) (Float, @owned Super) -> (@owned Super, @owned @callee_guaranteed (@guaranteed Super.AllDifferentiableVariables) -> Float)
Verifying instruction:
// function_ref Super.init(base:)
%3 = function_ref @$s5error5SuperC4baseACSf_tcfc : $@convention(method) (Float, @owned Super) -> @owned Super // users: %7, %4
// function_ref AD__$s5error5SuperC4baseACSf_tcfc__jvp_src_0_wrt_0
%5 = function_ref @AD__$s5error5SuperC4baseACSf_tcfc__jvp_src_0_wrt_0 : $@convention(method) (Float, @owned Super) -> (@owned Super, @owned @callee_guaranteed (Float) -> @owned Super.AllDifferentiableVariables) // user: %7
// function_ref static Super.vjpInit(base:)
%6 = function_ref @$s5error5SuperC7vjpInit4baseAC_SfAC26AllDifferentiableVariablesVctSf_tFZ : $@convention(method) (Float, @thick Super.Type) -> (@owned Super, @owned @callee_guaranteed (@guaranteed Super.AllDifferentiableVariables) -> Float) // users: %8, %7
-> %7 = autodiff_function [wrt 0] [order 1] %3 : $@convention(method) (Float, @owned Super) -> @owned Super with {%5 : $@convention(method) (Float, @owned Super) -> (@owned Super, @owned @callee_guaranteed (Float) -> @owned Super.AllDifferentiableVariables), %6 : $@convention(method) (Float, @thick Super.Type) -> (@owned Super, @owned @callee_guaranteed (@guaranteed Super.AllDifferentiableVariables) -> Float)} // user: %9
release_value %7 : $@differentiable @convention(method) (Float, @nondiff @owned Super) -> @owned Super // id: %9
Metadata
Metadata
Assignees
Labels
No labels