Description
| | |
|---|---|
| Previous ID | SR-13166 |
| Radar | rdar://problem/69987698 |
| Original Reporter | @dan-zheng |
| Type | New Feature |
Additional Detail from JIRA
| | |
|---|---|
| Votes | 1 |
| Component/s | Compiler |
| Labels | New Feature, AutoDiff |
| Assignee | @rxwei |
| Priority | Medium |
Sub-Tasks:
- [SR-13168] SILGen: add SIL default witness table entries for default derivatives of protocol requirements #55610
- [SR-13169] Sema: lift derivative registration restriction for protocol requirements #55611
- [SR-13167] Parse: add parsing/syntax support for `@differentiable(default, ...)` #55609
Issue Description:
Overview
Default derivative implementations enable protocol requirements (like requirements from `AdditiveArithmetic`, `FloatingPoint`, `ElementaryFunctions`, etc.) to be differentiable by default.
See "default derivatives and transposes" from the Differentiable Programming Manifesto for more info.
Example:
```swift
// In the standard library:
// public protocol AdditiveArithmetic: Equatable {
//   static func +(lhs: Self, rhs: Self) -> Self
//   ...
// }

extension AdditiveArithmetic where Self: Differentiable {
  @derivative(of: +)
  static func vjpAdd(_ lhs: Self, _ rhs: Self) ->
    (value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)) {
    return (lhs + rhs, { v in (v, v) })
  }
}
```
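The pullback `{ v in (v, v) }` reflects that ∂(lhs + rhs)/∂lhs = ∂(lhs + rhs)/∂rhs = 1, so the incoming cotangent `v` passes unchanged to both operands. Since default derivatives for protocol requirements are not yet supported, the sketch below is only a loose plain-Swift analogy that runs without the `_Differentiation` module: it models the derivative as an ordinary protocol requirement with a default implementation in a protocol extension, so conforming types get it without writing it. `ManualDifferentiable`, its `vjpAdd` requirement, and `Meters` are hypothetical names, and the tangent vector is simplified to `Self`.

```swift
// Loose analogy only: models a "default derivative" as a protocol requirement
// with a default implementation. Hypothetical names; tangent vector simplified to Self.
protocol ManualDifferentiable: AdditiveArithmetic {
  // Hypothetical stand-in for a registered derivative of `+`.
  static func vjpAdd(_ lhs: Self, _ rhs: Self)
    -> (value: Self, pullback: (Self) -> (Self, Self))
}

extension ManualDifferentiable {
  // Default derivative of `+`, written once and shared by every conforming type.
  static func vjpAdd(_ lhs: Self, _ rhs: Self)
    -> (value: Self, pullback: (Self) -> (Self, Self)) {
    // d(lhs + rhs)/d(lhs) = d(lhs + rhs)/d(rhs) = 1: the cotangent flows to both operands.
    return (lhs + rhs, { v in (v, v) })
  }
}

// Hypothetical conforming type: implements only the AdditiveArithmetic requirements.
struct Meters: ManualDifferentiable, Equatable {
  var length: Double
  static var zero: Meters { Meters(length: 0) }
  static func + (lhs: Meters, rhs: Meters) -> Meters { Meters(length: lhs.length + rhs.length) }
  static func - (lhs: Meters, rhs: Meters) -> Meters { Meters(length: lhs.length - rhs.length) }
}

// A standard library type picks up the default derivative with an empty conformance.
extension Double: ManualDifferentiable {}

let result = Meters.vjpAdd(Meters(length: 1), Meters(length: 2))
let (dLhs, dRhs) = result.pullback(Meters(length: 1))
print(result.value.length, dLhs.length, dRhs.length)  // 3.0 1.0 1.0
```

Neither `Meters` nor `Double` writes a derivative; both reuse the single default, which is the duplication this feature aims to remove.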
This is important for idiomatic protocol-oriented programming because it avoids extensive code duplication. Without this feature, every type conforming to the protocol must:

- Define a `@differentiable` concrete implementation of the original protocol requirement, if none already exists (because a default implementation was used).
- Register a derivative function for that concrete implementation.
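For contrast, a plain-Swift sketch of the status quo described above, under the loose analogy that a derivative is an ordinary protocol requirement with no default (`HasAddDerivative`, `Grams`, and `Seconds` are hypothetical names, and the tangent vector is simplified to `Self`). With no default available, each conforming type restates an identical derivative:

```swift
// Loose analogy only: no default implementation exists, so every conforming
// type must supply its own derivative. Hypothetical names; tangent vector is Self.
protocol HasAddDerivative: AdditiveArithmetic {
  static func vjpAdd(_ lhs: Self, _ rhs: Self)
    -> (value: Self, pullback: (Self) -> (Self, Self))
}

struct Grams: HasAddDerivative, Equatable {
  var amount: Double
  static var zero: Grams { Grams(amount: 0) }
  static func + (lhs: Grams, rhs: Grams) -> Grams { Grams(amount: lhs.amount + rhs.amount) }
  static func - (lhs: Grams, rhs: Grams) -> Grams { Grams(amount: lhs.amount - rhs.amount) }
  // Derivative of `+`, copy #1.
  static func vjpAdd(_ lhs: Grams, _ rhs: Grams)
    -> (value: Grams, pullback: (Grams) -> (Grams, Grams)) {
    return (lhs + rhs, { v in (v, v) })
  }
}

struct Seconds: HasAddDerivative, Equatable {
  var amount: Double
  static var zero: Seconds { Seconds(amount: 0) }
  static func + (lhs: Seconds, rhs: Seconds) -> Seconds { Seconds(amount: lhs.amount + rhs.amount) }
  static func - (lhs: Seconds, rhs: Seconds) -> Seconds { Seconds(amount: lhs.amount - rhs.amount) }
  // Derivative of `+`, copy #2: identical logic, duplicated per type.
  static func vjpAdd(_ lhs: Seconds, _ rhs: Seconds)
    -> (value: Seconds, pullback: (Seconds) -> (Seconds, Seconds)) {
    return (lhs + rhs, { v in (v, v) })
  }
}
```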
Details
There are two cases to consider:

1. Non-`@differentiable` protocol requirement.
```swift
protocol P {
  func foo(_ x: Float) -> Float
}

extension P {
  @derivative(of: foo)
  func vjpFoo(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
    // The value must be the original result; the pullback is a placeholder.
    return (foo(x), { $0 })
  }
}
```
2. `@differentiable` protocol requirement.
```swift
protocol P {
  @differentiable
  func foo(_ x: Float) -> Float
}

extension P {
  @derivative(of: foo)
  func vjpFoo(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
    // The value must be the original result; the pullback is a placeholder.
    return (foo(x), { $0 })
  }
}
```
Supporting this may require lifting the current restriction that "all implementations of `@differentiable` protocol requirements must themselves be marked as `@differentiable`". For example, the following code is rejected:
```swift
protocol AdditiveArithmetic: Equatable {
  static func +(lhs: Self, rhs: Self) -> Self
}

extension AdditiveArithmetic where Self: Differentiable {
  @derivative(of: +)
  static func vjpAdd(_ lhs: Self, _ rhs: Self) ->
    (value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)) {
    return (lhs + rhs, { v in (v, v) })
  }
}

struct Foo: AdditiveArithmetic {
  static func +(lhs: Self, rhs: Self) -> Self {
    lhs
  }
}
```
Compiling this code produces:

```
deriv.swift:13:8: error: type 'Foo' does not conform to protocol 'AdditiveArithmetic'
struct Foo: AdditiveArithmetic {
       ^
deriv.swift:14:15: note: candidate is missing attribute '@differentiable(wrt: (lhs, rhs) where Self : Differentiable)'
  static func +(lhs: Self, rhs: Self) -> Self {
              ^
```
SIL default witness table support may be needed.
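In plain Swift, a protocol requirement that has a default implementation in an extension is reached through each conforming type's witness table, whose entry is either the default or the type's own implementation. The sketch below is a loose analogy for what default witness entries for derivatives would provide; it does not use the `_Differentiation` module, and `HasFoo`, `vjpFoo`, `Identity`, `Doubler`, and `slope(of:at:)` are hypothetical names:

```swift
// Loose analogy only: `vjpFoo` stands in for a derivative of `foo`.
protocol HasFoo {
  func foo(_ x: Double) -> Double
  func vjpFoo(_ x: Double) -> (value: Double, pullback: (Double) -> Double)
}

extension HasFoo {
  // Default original: the identity function.
  func foo(_ x: Double) -> Double { x }

  // Default derivative; only correct for the default `foo` above (d/dx x = 1).
  func vjpFoo(_ x: Double) -> (value: Double, pullback: (Double) -> Double) {
    return (foo(x), { v in v })
  }
}

// Uses both defaults.
struct Identity: HasFoo {}

// Replaces both witnesses with its own implementations.
struct Doubler: HasFoo {
  func foo(_ x: Double) -> Double { 2 * x }
  func vjpFoo(_ x: Double) -> (value: Double, pullback: (Double) -> Double) {
    // d/dx (2x) = 2, so the pullback scales the cotangent by 2.
    return (foo(x), { v in 2 * v })
  }
}

// Generic code calls `vjpFoo` through the witness table, so each type's entry is used.
func slope<T: HasFoo>(of model: T, at x: Double) -> Double {
  return model.vjpFoo(x).pullback(1)
}

print(slope(of: Identity(), at: 3), slope(of: Doubler(), at: 3))  // 1.0 2.0
```

One caveat of this design: the default `vjpFoo` is correct only while `foo` also uses its default, so a type that replaces `foo` alone would silently inherit a wrong derivative.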