Skip to content

[AutoDiff] Add Differentiable.zeroTangentVectorInitializer. #28416

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 50 additions & 17 deletions stdlib/public/Differentiation/Differentiable.swift
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,39 @@ public protocol Differentiable {
mutating func move(along direction: TangentVector)

// SWIFT_ENABLE_TENSORFLOW
/// A tangent vector such that `move(along: zeroTangentVector)` will not
/// modify `self`.
/// - Note: `zeroTangentVector` can be `TangentVector.zero` in most cases,
/// but types whose tangent vectors depend on instance properties of `self`
/// need to provide a different implementation. For example, the tangent
/// vector of an `Array` depends on the array's `count`.
@available(*, deprecated, message: """
`zeroTangentVector` derivation has not been implemented; do not use \
this property
""")
var zeroTangentVector: TangentVector { get }
/// A closure that produces a zero tangent vector, capturing minimal
/// necessary information from `self`.
///
/// `move(along: zeroTangentVectorInitializer())` should not modify
/// `self`.
///
/// In some cases, the zero tangent vector of `self` is equal to
/// `TangentVector.zero`. In other cases, the zero tangent vector depends on
/// information in `self`, such as shape for an n-dimensional array type.
/// For differentiable programming, it is more memory-efficient to define a
/// custom `zeroTangentVectorInitializer` property which returns a closure
/// that captures and uses only the necessary information to create a zero
/// tangent vector. For example:
///
/// struct Vector {
/// var scalars: [Float]
/// var count: Int { scalars.count }
/// init(scalars: [Float]) { ... }
/// init(repeating repeatedElement: Float, count: Int) { ... }
/// }
///
/// extension Vector: AdditiveArithmetic { ... }
///
/// extension Vector: Differentiable {
/// typealias TangentVector = Vector
///
/// @noDerivative
/// var zeroTangentVectorInitializer: () -> TangentVector {
/// let count = self.count
/// return { TangentVector(repeating: 0, count: count) }
/// }
/// }
var zeroTangentVectorInitializer: () -> TangentVector { get }
// SWIFT_ENABLE_TENSORFLOW END
}

Expand All @@ -59,12 +81,23 @@ public extension Differentiable where TangentVector == Self {

// SWIFT_ENABLE_TENSORFLOW
public extension Differentiable {
// This is a temporary solution that allows us to add `zeroTangentVector`
// without implementing derived conformances. This property is marked
// unavailable because it will produce incorrect results when tangent vectors
// depend on instance properties of `self`.
// FIXME: Implement derived conformance and remove this default
// This is a temporary solution enabling the addition of
// `zeroTangentVectorInitializer` without implementing derived conformances.
// This property will produce incorrect results when tangent vectors depend
// on instance-specific information from `self`.
// FIXME: Implement derived conformances and remove this default
// implementation.
var zeroTangentVector: TangentVector { .zero }
@available(*, deprecated, message: """
`zeroTangentVectorInitializer` derivation has not been implemented; this \
default implementation is not correct when tangent vectors depend on \
instance-specific information from `self` and should not be used
""")
var zeroTangentVectorInitializer: () -> TangentVector {
{ TangentVector.zero }
}

/// A tangent vector initialized using `zeroTangentVectorInitializer`.
/// `move(along: zeroTangentVector)` should not modify `self`.
var zeroTangentVector: TangentVector { zeroTangentVectorInitializer() }
}
// SWIFT_ENABLE_TENSORFLOW END