diff --git a/stdlib/public/Differentiation/DifferentiationSupport.swift b/stdlib/public/Differentiation/DifferentiationSupport.swift index 01bdd57cedef6..180a7e75f6f80 100644 --- a/stdlib/public/Differentiation/DifferentiationSupport.swift +++ b/stdlib/public/Differentiation/DifferentiationSupport.swift @@ -1262,7 +1262,36 @@ public extension Array where Element: Differentiable { } @usableFromInline - @derivative(of: differentiableReduce) + @derivative(of: differentiableReduce, wrt: (self, initialResult)) + func _jvpDifferentiableReduce( + _ initialResult: Result, + _ nextPartialResult: @differentiable (Result, Element) -> Result + ) -> (value: Result, + differential: (Array.TangentVector, Result.TangentVector) + -> Result.TangentVector) { + var differentials: + [(Result.TangentVector, Element.TangentVector) -> Result.TangentVector] + = [] + let count = self.count + differentials.reserveCapacity(count) + var result = initialResult + for element in self { + let (y, df) = + Swift.valueWithDifferential(at: result, element, in: nextPartialResult) + result = y + differentials.append(df) + } + return (value: result, differential: { dSelf, dInitial in + var dResult = dInitial + for (dElement, df) in zip(dSelf.base, differentials) { + dResult = df(dResult, dElement) + } + return dResult + }) + } + + @usableFromInline + @derivative(of: differentiableReduce, wrt: (self, initialResult)) internal func _vjpDifferentiableReduce( _ initialResult: Result, _ nextPartialResult: @differentiable (Result, Element) -> Result