Description
| | |
|---|---|
| Previous ID | SR-12153 |
| Radar | None |
| Original Reporter | @dan-zheng |
| Type | Sub-task |
| Status | Closed |
| Resolution | Done |
Additional Detail from JIRA
| | |
|---|---|
| Votes | 0 |
| Component/s | |
| Labels | Sub-task |
| Assignee | @dan-zheng |
| Priority | Medium |
Parent-Task:
- SR-12148 Make differentiation work with class types
Issue Description:
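Casting instructions such as `upcast` and `unchecked_ref_cast` are generated in subclass initializers, among other places.
They should be handled during pullback generation; currently the pullback emitter rejects them as non-differentiable.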
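The adjoint rule these casts need is short. Both instructions return the same object reference with only its static type changed, so the derivative is the identity and the operand's adjoint simply accumulates the result's adjoint. This sketch assumes the two class types share one tangent vector type, which holds when a subclass such as `Sub` inherits its superclass's `Differentiable` conformance without redeclaring `TangentVector` (here, bars denote adjoints):

```math
y = \mathtt{upcast}\ x \quad\text{or}\quad y = \mathtt{unchecked\_ref\_cast}\ x
\qquad\Longrightarrow\qquad
\bar{x} \mathrel{+}= \bar{y}
```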
```swift
class Super : Differentiable {
  var base: Float

  // TODO(TF-645): Remove `vjpInit` when differentiation supports
  // `ref_element_addr`.
  @differentiable(vjp: vjpInit)
  init(base: Float) { self.base = base }

  static func vjpInit(base: Float) -> (Super, (TangentVector) -> Float) {
    return (Super(base: base), { x in x.base })
  }
}

// `Sub.init(base:)` is automatically synthesized for `Sub` from `Super.init(base:)`.
// In SIL, it uses `upcast` to upcast and `unchecked_ref_cast` to downcast.
class Sub : Super {}

print(pullback(at: 2) { x in Super(base: x) }(Super.TangentVector(base: 1)))
print(pullback(at: 2) { x in Sub(base: x) }(Super.TangentVector(base: 1)))
```
```
[AD] Running PullbackEmitter on
// Sub.init(base:)
sil hidden [differentiable source 0 wrt 0 jvp @AD__$s4main3SubC4baseACSf_tcfc__jvp_src_0_wrt_0 vjp @AD__$s4main3SubC4baseACSf_tcfc__vjp_src_0_wrt_0] @$s4main3SubC4baseACSf_tcfc : $@convention(method) (Float, @owned Sub) -> @owned Sub {
// %0 // users: %8, %3
// %1 // user: %4
bb0(%0 : $Float, %1 : $Sub):
  %2 = alloc_stack $Sub, let, name "self" // users: %11, %10, %5, %4, %13, %14
  debug_value %0 : $Float, let, name "base", argno 1 // id: %3
  store %1 to %2 : $*Sub // id: %4
  %5 = load %2 : $*Sub // user: %6
  %6 = upcast %5 : $Sub to $Super // user: %8
  // function_ref Super.init(base:)
  %7 = function_ref @$s4main5SuperC4baseACSf_tcfc : $@convention(method) (Float, @owned Super) -> @owned Super // user: %8
  %8 = apply %7(%0, %6) : $@convention(method) (Float, @owned Super) -> @owned Super // user: %9
  %9 = unchecked_ref_cast %8 : $Super to $Sub // user: %10
  store %9 to %2 : $*Sub // id: %10
  %11 = load %2 : $*Sub // users: %15, %12
  strong_retain %11 : $Sub // id: %12
  destroy_addr %2 : $*Sub // id: %13
  dealloc_stack %2 : $*Sub // id: %14
  return %11 : $Sub // id: %15
} // end sil function '$s4main3SubC4baseACSf_tcfc'
...
[AD] Unhandled instruction in adjoint emitter: %9 = unchecked_ref_cast %8 : $Super to $Sub // user: %10
[AD] Diagnosing non-differentiability.
[AD] For instruction:
  %9 = unchecked_ref_cast %8 : $Super to $Sub // user: %10
[AD] With invoker:
(differentiation_invoker autodiff_function_inst=( %5 = autodiff_function [wrt 0] [order 1] %3 : $@convention(method) (Float, @owned Sub) -> @owned Sub // users: %8, %6
))
<unknown>:0: error: expression is not differentiable
<unknown>:0: note: expression is not differentiable
```
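To illustrate what handling the cast would mean for `Sub.init(base:)`, here is a toy reverse pass over a simplified version of the SIL in the log. It is a hypothetical model, not the compiler's `PullbackEmitter`: the `alloc_stack`/`store`/`load` bookkeeping is dropped, register names are plain strings, `TangentVector` is a stand-in struct, `Inst` and `pullback(seed:)` are illustrative names, and the pullback of `Super.init(base:)` is taken from `vjpInit` (`{ x in x.base }`). The `unchecked_ref_cast` case, the one reported as unhandled, just forwards the adjoint.

```swift
// Hypothetical stand-in for `Super.TangentVector`; `Sub` shares it by inheritance.
struct TangentVector {
  var base: Float
  static let zero = TangentVector(base: 0)
  static func + (lhs: TangentVector, rhs: TangentVector) -> TangentVector {
    TangentVector(base: lhs.base + rhs.base)
  }
}

// Simplified instructions from the `Sub.init(base:)` SIL (hypothetical model).
enum Inst {
  case upcast(result: String, operand: String)            // %6 = upcast %5
  case applySuperInit(result: String, base: String)       // %8 = apply Super.init(base:)(%0, %6)
  case uncheckedRefCast(result: String, operand: String)  // %9 = unchecked_ref_cast %8
}

let original: [Inst] = [
  .upcast(result: "%6", operand: "%5"),
  .applySuperInit(result: "%8", base: "%0"),
  .uncheckedRefCast(result: "%9", operand: "%8"),
]

/// Walks the instructions in reverse, seeding the adjoint of the returned
/// value %9, and returns the adjoint of the differentiated argument %0.
func pullback(seed: TangentVector) -> Float {
  var refAdjoints: [String: TangentVector] = ["%9": seed]
  var floatAdjoints: [String: Float] = [:]
  for inst in original.reversed() {
    switch inst {
    case let .upcast(result, operand), let .uncheckedRefCast(result, operand):
      // A cast returns the same reference: adj[operand] += adj[result].
      refAdjoints[operand] = (refAdjoints[operand] ?? .zero) + (refAdjoints[result] ?? .zero)
    case let .applySuperInit(result, base):
      // Pullback of `Super.init(base:)`, as in `vjpInit`: { x in x.base }.
      floatAdjoints[base, default: 0] += (refAdjoints[result] ?? .zero).base
    }
  }
  return floatAdjoints["%0"] ?? 0
}

print(pullback(seed: TangentVector(base: 1)))  // 1.0
```

With a seed of `TangentVector(base: 1)` the sketch returns `1.0`, the same value that `vjpInit`'s pullback gives for `Super(base:)` with that seed. The adjoint reaching `%6` through the `upcast` is zero here, because the `self` argument of `Super.init(base:)` is not a differentiation parameter.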