Skip to content

[SR-12151] Enable differentiation of class initializers #53546

@dan-zheng

Description

@dan-zheng
Previous ID SR-12151
Radar None
Original Reporter @dan-zheng
Type Sub-task
Status Resolved
Resolution Done
Additional Detail from JIRA
Votes 0
Component/s
Labels Sub-task
Assignee @dan-zheng
Priority Medium

md5: 7cd503fb0d18c97a1b671a7f91c7b875

Parent-Task:

  • SR-12148 Make differentiation work with class types

Issue Description:

(previously TF-654)

Support differentiation of class initializers:

class Super : Differentiable {
  var base: Float
  // FIXME(TF-648): Dummy to make `Super.AllDifferentiableVariables` be nontrivial.
  var _nontrivial: [Float] = []

  // TODO(TF-645): Remove `vjpInit` when differentiation supports `ref_element_addr`.
  @differentiable(vjp: vjpInit)
  required init(base: Float) {
    self.base = base
  }

  /// Custom VJP registered for `init(base:)`.
  ///
  /// The initializer stores `base` unchanged, so the pullback maps a tangent of
  /// the constructed instance to its `base` component.
  static func vjpInit(base: Float) -> (Super, (TangentVector) -> Float) {
    return (Super(base: base), { x in x.base })
    // Alternative constant pullback (presumably for checking which derivative
    // actually gets invoked):
    // return (Super(base: base), { _ in 300 })
  }

  /// Returns `base * x`; differentiable with respect to both `self` and `x`.
  @differentiable(wrt: (self, x), jvp: jvpf, vjp: vjpf)
  func f(_ x: Float) -> Float {
    return base * x
  }

  /// Custom JVP for `f(_:)` with respect to `(self, x)`.
  ///
  /// Product rule: d(base * x) = dbase * x + base * dx.
  final func jvpf(_ x: Float) -> (Float, (TangentVector, Float) -> Float) {
    // Read `base` up front so the differential captures the value, not `self`
    // (same pattern as `vjpf`).
    let base = self.base
    return (f(x), { (dself, dx) in dself.base * x + base * dx })
  }

  /// Custom VJP for `f(_:)` with respect to `(self, x)`.
  ///
  /// The pullback maps a cotangent `v` to `(d/dself, d/dx)`, which is
  /// `(TangentVector(base: v * x), base * v)`.
  final func vjpf(_ x: Float) -> (Float, (Float) -> (TangentVector, Float)) {
    // Read `base` up front so the pullback captures the value, not `self`.
    let base = self.base
    return (f(x), { v in
      (TangentVector(base: v * x, _nontrivial: []), base * v)
    })
  }
}

/// Subclass that overrides `f(_:)` with `3 * x`, which does not read `base`.
///
/// The override declares `@differentiable(wrt: (self, x))` without a custom
/// JVP/VJP, so its derivative is presumably expected to be synthesized by the
/// compiler.
class SubOverride : Super {
  // Does not depend on `self`; only the derivative with respect to `x` (3) is nonzero.
  @differentiable(wrt: (self, x))
  override func f(_ x: Float) -> Float {
    return 3 * x
  }
}

/// Subclass that overrides `f(_:)` with `3 * x` and registers two
/// `@differentiable` attributes on it:
/// - one with respect to `(self, x)`, with no custom derivatives;
/// - one with respect to `x` only, using the custom `jvpf2`/`vjpf2` below.
class SubOverrideCustomDerivatives : Super {
  @differentiable(wrt: (self, x))
  @differentiable(wrt: x, jvp: jvpf2, vjp: vjpf2)
  override func f(_ x: Float) -> Float {
    return 3 * x
  }
  /// JVP of `f(_:)` with respect to `x`: the differential scales the input tangent by 3.
  final func jvpf2(_ x: Float) -> (Float, (Float) -> Float) {
    return (f(x), { v in 3 * v })
  }
  /// VJP of `f(_:)` with respect to `x`: the pullback scales the cotangent by 3.
  final func vjpf2(_ x: Float) -> (Float, (Float) -> Float) {
    return (f(x), { v in 3 * v })
  }
}

// Tangent of a `Super` instance with `base` component 100, used as the
// cotangent seed for each pullback below.
let v = Super.TangentVector(base: 100, _nontrivial: [])
// Pull `v` back through each class's `init(base:)` at base = 1337.
// For `Super`, `vjpInit`'s pullback `{ x in x.base }` yields 100.
print(pullback(at: 1337) { x in Super(base: x) }(v))
// The subclasses inherit the `required init(base:)` from `Super`.
print(pullback(at: 1337) { x in SubOverride(base: x) }(v))
print(pullback(at: 1337) { x in SubOverrideCustomDerivatives(base: x) }(v))

Things are a bit tricky because classes have allocators in addition to initializers.
SIL allocators take a `@thick Class.Type` metatype.
SIL initializers take an existing class instance.
An AST-level `ConstructorDecl` lowers to both an allocator and an initializer, so both SIL functions get a `[differentiable]` attribute lowered from the same `@differentiable` attribute.

// Super.__allocating_init(base:)
// Allocating entry point. It takes the class metatype (%1, which has no users
// in the body), allocates a fresh instance with alloc_ref, and forwards it to
// the initializing entry point Super.init(base:) (@$s5error5SuperC4baseACSf_tcfc).
// Its parameter list (Float, @thick Super.Type) matches that of the registered
// VJP Super.vjpInit(base:); the initializer's (Float, @owned Super) does not.
sil hidden [differentiable source 0 wrt 0 vjp @$s5error5SuperC7vjpInit4baseAC_SfAC26AllDifferentiableVariablesVctSf_tFZ]
 [ossa] @$s5error5SuperC4baseACSf_tcfC : $@convention(method) (Float, @thick Super.Type) -> @owned Super {
// %0                                             // user: %4
bb0(%0 : $Float, %1 : $@thick Super.Type):
  // Allocate the instance that the initializer will populate.
  %2 = alloc_ref $Super                           // user: %4
  // function_ref Super.init(base:)
  %3 = function_ref @$s5error5SuperC4baseACSf_tcfc : $@convention(method) (Float, @owned Super) -> @owned Super // user: %4
  %4 = apply %3(%0, %2) : $@convention(method) (Float, @owned Super) -> @owned Super // user: %5
  return %4 : $Super                              // id: %5
} // end sil function '$s5error5SuperC4baseACSf_tcfC'

// Super.init(base:)
// Initializing entry point. It receives an already-allocated instance (%1)
// instead of a metatype, so its signature differs from the allocator's, yet
// it carries the same [differentiable ... vjp @...vjpInit...] attribute.
// Body: marks self as uninitialized, stores the default value of `_nontrivial`
// and then the `base` argument through ref_element_addr, and returns a copy
// of self.
sil hidden [differentiable source 0 wrt 0 vjp @$s5error5SuperC7vjpInit4baseAC_SfAC26AllDifferentiableVariablesVctSf_tFZ] [ossa] @$s5error5SuperC4baseACSf_tcfc : $@convention(method) (Float, @owned Super) -> @owned Super {
// %0                                             // users: %16, %2
// %1                                             // users: %4, %3
bb0(%0 : $Float, %1 : @owned $Super):
  debug_value %0 : $Float, let, name "base", argno 1 // id: %2
  debug_value %1 : $Super, let, name "self", argno 2 // id: %3
  %4 = mark_uninitialized [rootself] %1 : $Super  // users: %20, %19, %13, %7
  // Store the initial value of `_nontrivial` into the new instance.
  // function_ref variable initialization expression of Super._nontrivial
  %5 = function_ref @$s5error5SuperC11_nontrivialSaySfGvpfi : $@convention(thin) () -> @owned Array<Float> // user: %6
  %6 = apply %5() : $@convention(thin) () -> @owned Array<Float> // user: %10
  %7 = begin_borrow %4 : $Super                   // users: %12, %8
  %8 = ref_element_addr %7 : $Super, #Super._nontrivial // user: %9
  %9 = begin_access [modify] [dynamic] %8 : $*Array<Float> // users: %11, %10
  assign %6 to %9 : $*Array<Float>                // id: %10
  end_access %9 : $*Array<Float>                  // id: %11
  end_borrow %7 : $Super                          // id: %12
  // Store the `base` argument (%0) into `self.base`.
  %13 = begin_borrow %4 : $Super                  // users: %18, %14
  %14 = ref_element_addr %13 : $Super, #Super.base // user: %15
  %15 = begin_access [modify] [dynamic] %14 : $*Float // users: %17, %16
  assign %0 to %15 : $*Float                      // id: %16
  end_access %15 : $*Float                        // id: %17
  end_borrow %13 : $Super                         // id: %18
  // Return an owned copy of self and destroy the marked value.
  %19 = copy_value %4 : $Super                    // user: %21
  destroy_value %4 : $Super                       // id: %20
  return %19 : $Super                             // id: %21
} // end sil function '$s5error5SuperC4baseACSf_tcfc'

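However, if a custom JVP/VJP is registered in the `@differentiable` attribute, a type mismatch occurs for the initializer. The registered VJP (`vjpInit`) has the allocator's signature, which takes a `@thick Super.Type` metatype. The initializer instead takes an existing `Super` instance, so the SIL verifier rejects the pairing: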

SIL verification failed: VJP type does not match expected VJP type
  $@convention(method) (Float, @thick Super.Type) -> (@owned Super, @owned @callee_guaranteed (@guaranteed Super.AllDifferentiableVariables) -> Float)
  $@convention(method) (Float, @owned Super) -> (@owned Super, @owned @callee_guaranteed (@guaranteed Super.AllDifferentiableVariables) -> Float)
Verifying instruction:
     // function_ref Super.init(base:)
  %3 = function_ref @$s5error5SuperC4baseACSf_tcfc : $@convention(method) (Float, @owned Super) -> @owned Super // users: %7, %4
     // function_ref AD__$s5error5SuperC4baseACSf_tcfc__jvp_src_0_wrt_0
  %5 = function_ref @AD__$s5error5SuperC4baseACSf_tcfc__jvp_src_0_wrt_0 : $@convention(method) (Float, @owned Super) -> (@owned Super, @owned @callee_guaranteed (Float) -> @owned Super.AllDifferentiableVariables) // user: %7
     // function_ref static Super.vjpInit(base:)
  %6 = function_ref @$s5error5SuperC7vjpInit4baseAC_SfAC26AllDifferentiableVariablesVctSf_tFZ : $@convention(method) (Float, @thick Super.Type) -> (@owned Super, @owned @callee_guaranteed (@guaranteed Super.AllDifferentiableVariables) -> Float) // users: %8, %7
->   %7 = autodiff_function [wrt 0] [order 1] %3 : $@convention(method) (Float, @owned Super) -> @owned Super with {%5 : $@convention(method) (Float, @owned Super) -> (@owned Super, @owned @callee_guaranteed (Float) -> @owned Super.AllDifferentiableVariables), %6 : $@convention(method) (Float, @thick Super.Type) -> (@owned Super, @owned @callee_guaranteed (@guaranteed Super.AllDifferentiableVariables) -> Float)} // user: %9
     release_value %7 : $@differentiable @convention(method) (Float, @nondiff @owned Super) -> @owned Super // id: %9

Metadata

Metadata

Assignees

Labels

No labels
No labels

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions