
[SR-13166] Default derivative implementations for protocol requirements #54231

@dan-zheng

Previous ID: SR-13166
Radar: rdar://problem/69987698
Original Reporter: @dan-zheng
Type: New Feature

Additional Detail from JIRA
Votes: 1
Component/s: Compiler
Labels: New Feature, AutoDiff
Assignee: @rxwei
Priority: Medium


Sub-Tasks:

Issue Description:

Overview

Default derivative implementations enable protocol requirements (such as those of AdditiveArithmetic, FloatingPoint, ElementaryFunctions, etc.) to be differentiable by default: a derivative registered once in a protocol extension applies to every conforming type.

See the "default derivatives and transposes" section of the Differentiable Programming Manifesto for more information.

Example:

// In the standard library:
// public protocol AdditiveArithmetic: Equatable {
//   static func +(lhs: Self, rhs: Self) -> Self
//   ...
// }

extension AdditiveArithmetic where Self: Differentiable {
  @derivative(of: +)
  static func vjpAdd(_ lhs: Self, _ rhs: Self) ->
    (value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)) {
    return (lhs + rhs, { v in (v, v) })
  }
}
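The pullback above returns `(v, v)` because addition has a partial derivative of 1 with respect to each operand, so an incoming cotangent `v` passes unchanged to both `lhs` and `rhs`. The plain-Swift sketch below writes that VJP out by hand for `Double`. It does not use the compiler's differentiation support, and the names `vjpAddDouble` and `gradientOfDoubling` are hypothetical and chosen only for illustration.

```swift
// A hand-written VJP for Double addition: returns the sum together
// with a pullback that maps an output cotangent to input cotangents.
func vjpAddDouble(_ lhs: Double, _ rhs: Double)
    -> (value: Double, pullback: (Double) -> (Double, Double)) {
  // d(lhs + rhs)/dlhs = 1 and d(lhs + rhs)/drhs = 1, so the pullback
  // passes the incoming cotangent v unchanged to both operands.
  return (lhs + rhs, { v in (v, v) })
}

// Gradient of f(x) = x + x: seed the pullback with 1 and sum the
// cotangents of both uses of x (the chain rule for fan-out).
func gradientOfDoubling(at x: Double) -> Double {
  let (_, pullback) = vjpAddDouble(x, x)
  let (dLhs, dRhs) = pullback(1)
  return dLhs + dRhs
}

let result = vjpAddDouble(2, 3)
print(result.value)               // 5.0
print(result.pullback(1))         // (1.0, 1.0)
print(gradientOfDoubling(at: 4))  // 2.0
```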

This is important for idiomatic protocol-oriented programming because it avoids large amounts of code duplication. Without this feature, every type conforming to the protocol must:

  • Define a @differentiable concrete implementation of the original protocol requirement, if none exists (for example, because the type relied on a default implementation).

  • Register a derivative function for that concrete implementation.
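The following plain-Swift sketch models the intended effect without the compiler's AutoDiff machinery. It makes several assumptions. `Summable`, `Meters`, and `Vec2` are hypothetical names. A VJP is written by hand as a function that returns a value and a pullback closure. Each type also serves as its own tangent vector. Under these assumptions, one `vjpAdd` in a protocol extension covers every conforming type, and neither type has to repeat it.

```swift
// Hypothetical stand-in for a protocol such as AdditiveArithmetic.
protocol Summable {
  static func + (lhs: Self, rhs: Self) -> Self
}

extension Summable {
  // Written once, available to every conforming type. Self doubles as
  // its own tangent (cotangent) type here, an assumption that only
  // suits simple vector-like types.
  static func vjpAdd(_ lhs: Self, _ rhs: Self)
      -> (value: Self, pullback: (Self) -> (Self, Self)) {
    return (lhs + rhs, { v in (v, v) })
  }
}

struct Meters: Summable {
  var value: Double
  static func + (lhs: Meters, rhs: Meters) -> Meters {
    Meters(value: lhs.value + rhs.value)
  }
}

struct Vec2: Summable {
  var x: Double
  var y: Double
  static func + (lhs: Vec2, rhs: Vec2) -> Vec2 {
    Vec2(x: lhs.x + rhs.x, y: lhs.y + rhs.y)
  }
}

// Neither type declares vjpAdd; both use the single default.
let meters = Meters.vjpAdd(Meters(value: 1), Meters(value: 2))
print(meters.value.value)  // 3.0
let vec = Vec2.vjpAdd(Vec2(x: 1, y: 2), Vec2(x: 3, y: 4))
let (dLhs, dRhs) = vec.pullback(Vec2(x: 0.5, y: 1))
print(dLhs.x, dRhs.y)      // 0.5 1.0
```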


Details

There are two cases to consider:

1. Non-@differentiable protocol requirement.

protocol P {
  func foo(_ x: Float) -> Float
}
extension P {
  @derivative(of: foo)
  func vjpFoo(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
    return (x, { $0 })
  }
}

2. @differentiable protocol requirement.

protocol P {
  @differentiable
  func foo(_ x: Float) -> Float
}
extension P {
  @derivative(of: foo)
  func vjpFoo(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
    return (x, { $0 })
  }
}
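The plain-Swift analogy below offers one way to think about the difference between the two cases. It is an assumption for illustration, not a description of how the compiler would implement either case.

In case 2 the derivative is effectively part of the conformance. A default then resembles a default implementation of a protocol requirement: a conforming type can replace it, and generic code reaches the replacement through the witness table.

In case 1 the derivative hangs off an extension. It behaves more like a non-requirement extension method, which generic code dispatches statically even when the conforming type's `foo` needs a different derivative.

All names here (`HasRequiredVJP`, `HasExtensionVJP`, `Doubler`, `pullbackViaRequirement`, `pullbackViaExtension`) are hypothetical. Unlike the snippets above, the value is computed as `foo(x)`.

```swift
// Case 2 analogue: the derivative is itself a requirement, so the
// extension method acts as a default witness that a type may replace.
protocol HasRequiredVJP {
  func foo(_ x: Float) -> Float
  func vjpFoo(_ x: Float) -> (value: Float, pullback: (Float) -> Float)
}

extension HasRequiredVJP {
  // Default derivative with a placeholder identity pullback.
  func vjpFoo(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
    (foo(x), { $0 })
  }
}

// Case 1 analogue: the derivative exists only in an extension.
protocol HasExtensionVJP {
  func foo(_ x: Float) -> Float
}

extension HasExtensionVJP {
  func vjpFoo(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
    (foo(x), { $0 })
  }
}

// f(x) = 2x, so the correct pullback is v -> 2v.
struct Doubler: HasRequiredVJP, HasExtensionVJP {
  func foo(_ x: Float) -> Float { 2 * x }
  func vjpFoo(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
    (foo(x), { 2 * $0 })
  }
}

// Dispatched through the witness table: finds Doubler's derivative.
func pullbackViaRequirement<T: HasRequiredVJP>(_ t: T) -> Float {
  t.vjpFoo(1).pullback(1)
}

// Statically dispatched: always uses the extension's derivative.
func pullbackViaExtension<T: HasExtensionVJP>(_ t: T) -> Float {
  t.vjpFoo(1).pullback(1)
}

print(pullbackViaRequirement(Doubler()))  // 2.0
print(pullbackViaExtension(Doubler()))    // 1.0
```

The mismatch in the second call shows why a default derivative for a requirement is only safe if it is valid for every possible witness, or if conforming types can override it.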

Supporting this may require lifting the current restriction that "all implementations of @differentiable protocol requirements must themselves be marked as @differentiable":

protocol AdditiveArithmetic: Equatable {
  static func +(lhs: Self, rhs: Self) -> Self
}

extension AdditiveArithmetic where Self: Differentiable {
  @derivative(of: +)
  static func vjpAdd(_ lhs: Self, _ rhs: Self) ->
    (value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)) {
    return (lhs + rhs, { v in (v, v) })
  }
}

struct Foo: AdditiveArithmetic {
  static func +(lhs: Self, rhs: Self) -> Self {
    lhs
  }
}

Compiling this currently fails with:
deriv.swift:13:8: error: type 'Foo' does not conform to protocol 'AdditiveArithmetic'
struct Foo: AdditiveArithmetic {
       ^
deriv.swift:14:15: note: candidate is missing attribute '@differentiable(wrt: (lhs, rhs) where Self : Differentiable)'
  static func +(lhs: Self, rhs: Self) -> Self {
              ^

Implementing this may also require support for SIL default witness tables.

Labels: AutoDiff, compiler, conformances, default implementations, feature