diff --git a/stdlib/public/Differentiation/Differentiable.swift b/stdlib/public/Differentiation/Differentiable.swift
index 635005d750974..1d67347e8a43e 100644
--- a/stdlib/public/Differentiation/Differentiable.swift
+++ b/stdlib/public/Differentiation/Differentiable.swift
@@ -36,17 +36,39 @@ public protocol Differentiable {
   mutating func move(along direction: TangentVector)
 
   // SWIFT_ENABLE_TENSORFLOW
-  /// A tangent vector such that `move(along: zeroTangentVector)` will not
-  /// modify `self`.
-  /// - Note: `zeroTangentVector` can be `TangentVector.zero` in most cases,
-  ///   but types whose tangent vectors depend on instance properties of `self`
-  ///   need to provide a different implementation. For example, the tangent
-  ///   vector of an `Array` depends on the array's `count`.
-  @available(*, deprecated, message: """
-      `zeroTangentVector` derivation has not been implemented; do not use \
-      this property
-      """)
-  var zeroTangentVector: TangentVector { get }
+  /// A closure that produces a zero tangent vector, capturing minimal
+  /// necessary information from `self`.
+  ///
+  /// `move(along: zeroTangentVectorInitializer())` should not modify
+  /// `self`.
+  ///
+  /// In some cases, the zero tangent vector of `self` is equal to
+  /// `TangentVector.zero`. In other cases, the zero tangent vector depends on
+  /// information in `self`, such as shape for an n-dimensional array type.
+  /// For differentiable programming, it is more memory-efficient to define a
+  /// custom `zeroTangentVectorInitializer` property which returns a closure
+  /// that captures and uses only the necessary information to create a zero
+  /// tangent vector. For example:
+  ///
+  ///     struct Vector {
+  ///         var scalars: [Float]
+  ///         var count: Int { scalars.count }
+  ///         init(scalars: [Float]) { ... }
+  ///         init(repeating repeatedElement: Float, count: Int) { ... }
+  ///     }
+  ///
+  ///     extension Vector: AdditiveArithmetic { ... }
+  ///
+  ///     extension Vector: Differentiable {
+  ///         typealias TangentVector = Vector
+  ///
+  ///         @noDerivative
+  ///         var zeroTangentVectorInitializer: () -> TangentVector {
+  ///             let count = self.count
+  ///             return { TangentVector(repeating: 0, count: count) }
+  ///         }
+  ///     }
+  var zeroTangentVectorInitializer: () -> TangentVector { get }
   // SWIFT_ENABLE_TENSORFLOW END
 }
 
@@ -59,12 +81,23 @@ public extension Differentiable where TangentVector == Self {
 
 // SWIFT_ENABLE_TENSORFLOW
 public extension Differentiable {
-  // This is a temporary solution that allows us to add `zeroTangentVector`
-  // without implementing derived conformances. This property is marked
-  // unavailable because it will produce incorrect results when tangent vectors
-  // depend on instance properties of `self`.
-  // FIXME: Implement derived conformance and remove this default
+  // Temporary default that allows adding `zeroTangentVectorInitializer`
+  // without implementing derived conformances. It returns a closure producing
+  // `TangentVector.zero`, which is incorrect when tangent vectors depend on
+  // instance-specific information from `self`.
+  // FIXME: Implement derived conformances and remove this default
   // implementation.
-  var zeroTangentVector: TangentVector { .zero }
+  @available(*, deprecated, message: """
+      `zeroTangentVectorInitializer` derivation has not been implemented; this \
+      default implementation is not correct when tangent vectors depend on \
+      instance-specific information from `self` and should not be used
+      """)
+  var zeroTangentVectorInitializer: () -> TangentVector {
+    { TangentVector.zero }
+  }
+
+  /// A tangent vector initialized using `zeroTangentVectorInitializer`.
+  /// `move(along: zeroTangentVector)` should not modify `self`.
+  var zeroTangentVector: TangentVector { zeroTangentVectorInitializer() }
 }
 // SWIFT_ENABLE_TENSORFLOW END