[AutoDiff] Derive Differentiable.zeroTangentVectorInitializer. #32064

Merged 1 commit on May 29, 2020

dan-zheng
Contributor

Cherry-pick of #31823 from master branch.


Differentiable conformance derivation now supports
Differentiable.zeroTangentVectorInitializer.

zeroTangentVectorInitializer is a closure that produces a zero tangent vector,
capturing minimal necessary information from self.

It is an instance property, unlike the static property AdditiveArithmetic.zero,
and should be used by the differentiation transform for correctness.
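
The difference between a static zero and an instance-derived zero can be sketched in plain Swift, without the `_Differentiation` module. The `Vector` type below is a hypothetical stand-in invented for illustration, not a type from this PR; it only mimics how a shape-dependent value such as an array needs information from `self` to build a correctly sized zero.

```swift
// Hypothetical stand-in type; not part of the standard library or this PR.
struct Vector {
  var elements: [Double]

  // A static zero has no instance to inspect, so it cannot know the count.
  static var zero: Vector { Vector(elements: []) }

  // An instance property can capture just the count (not the whole array)
  // and produce a zero with the matching shape.
  var zeroTangentVectorInitializer: () -> Vector {
    { [count = elements.count] in
      Vector(elements: Array(repeating: 0, count: count))
    }
  }
}

let v = Vector(elements: [1, 2, 3])
print(Vector.zero.elements)                        // []
print(v.zeroTangentVectorInitializer().elements)   // [0.0, 0.0, 0.0]
```

Capturing only `count` in the closure's capture list reflects the "minimal necessary information" idea: the closure does not keep the original elements alive.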

There are two potential derivation cases:

  1. Memberwise derivation: done when TangentVector can be initialized memberwise.
  2. { TangentVector.zero } derivation: done as a fallback.
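
As a rough, hand-written approximation of the two cases, the following plain-Swift sketch runs without a differentiation-enabled toolchain. All types here (`Vec`, `Pair`, `CustomPair`) are hypothetical and no `Differentiable` conformance is involved; the property bodies are written by hand to resemble what the derivation is described as producing.

```swift
// Hypothetical leaf type with a shape-dependent zero.
struct Vec {
  var elements: [Double]
  static var zero: Vec { Vec(elements: []) }
  var zeroTangentVectorInitializer: () -> Vec {
    { [n = elements.count] in Vec(elements: Array(repeating: 0, count: n)) }
  }
}

// Case 1 (memberwise): the tangent type has one member per stored property,
// so each member's zero initializer is captured and called.
struct Pair {
  var x: Vec
  var y: Vec
  struct TangentVector { var x: Vec; var y: Vec }

  var zeroTangentVectorInitializer: () -> TangentVector {
    { [xInit = x.zeroTangentVectorInitializer,
       yInit = y.zeroTangentVectorInitializer] in
      TangentVector(x: xInit(), y: yInit())
    }
  }
}

// Case 2 (fallback): the tangent type is custom and cannot be built
// memberwise from `x` and `y`, so the static zero is used.
struct CustomPair {
  var x: Vec
  var y: Vec
  typealias TangentVector = Vec

  var zeroTangentVectorInitializer: () -> TangentVector {
    { TangentVector.zero }
  }
}

let p = Pair(x: Vec(elements: [1, 2]), y: Vec(elements: [3]))
let t = p.zeroTangentVectorInitializer()
print(t.x.elements, t.y.elements)   // [0.0, 0.0] [0.0]

let c = CustomPair(x: Vec(elements: [1, 2]), y: Vec(elements: [3]))
print(c.zeroTangentVectorInitializer().elements)   // []
```

The fallback loses shape information, which is why memberwise derivation is attempted first.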

Remove Differentiable.zeroTangentVectorInitializer dummy default implementation.

Update stdlib Differentiable conformances and tests.
Clean up DerivedConformanceDifferentiable.cpp cruft.

Resolves TF-1007.
Progress towards TF-1008: differentiation correctness for projection operations.


Examples:

import _Differentiation

struct Struct<T: Differentiable, U: Differentiable>: Differentiable {
  var x: T
  var y: U

  // Compiler synthesizes:
  // var zeroTangentVectorInitializer: () -> TangentVector {
  //   { [xZeroTanInit = x.zeroTangentVectorInitializer,
  //      yZeroTanInit = y.zeroTangentVectorInitializer] in
  //     return TangentVector(x: xZeroTanInit(), y: yZeroTanInit())
  //   }
  // }
}

let s = Struct(x: [1, 2, 3], y: [[4, 5, 6], [], [2]])
// `Differentiable.zeroTangentVector` default implementation calls
// `zeroTangentVectorInitializer`.
print(s.zeroTangentVector)

// Before (via dummy `zeroTangentVectorInitializer` default implementation):
// TangentVector(x: [], y: [])

// After (via derived conformances):
// TangentVector(x: [0.0, 0.0, 0.0], y: [[0.0, 0.0, 0.0], [], [0.0]])

struct CustomTangentVector<T: Differentiable, U: Differentiable>: Differentiable {
  var x: T
  var y: U

  typealias TangentVector = T.TangentVector
  mutating func move(along direction: TangentVector) {}

  // Memberwise synthesis is not possible for custom `TangentVector` type.
  // Fallback synthesis is done instead:
  // var zeroTangentVectorInitializer: () -> TangentVector {
  //   { TangentVector.zero }
  // }
}

@dan-zheng dan-zheng added the tensorflow This is for "tensorflow" branch PRs. label May 28, 2020
@dan-zheng
Contributor Author

@swift-ci Please test tensorflow

@dan-zheng dan-zheng force-pushed the zero-tangent-vector-initializer-tf branch from e48c839 to ee23dcf Compare May 29, 2020 03:38
@dan-zheng
Contributor Author

@swift-ci Please clean test TensorFlow

@dan-zheng dan-zheng merged commit eedaab7 into tensorflow May 29, 2020
@dan-zheng dan-zheng deleted the zero-tangent-vector-initializer-tf branch May 29, 2020 08:58