diff --git a/stdlib/public/Differentiation/TgmathDerivatives.swift.gyb b/stdlib/public/Differentiation/TgmathDerivatives.swift.gyb index 92fde0c6fbfa3..1a770907f5a55 100644 --- a/stdlib/public/Differentiation/TgmathDerivatives.swift.gyb +++ b/stdlib/public/Differentiation/TgmathDerivatives.swift.gyb @@ -130,6 +130,7 @@ func _${derivative_kind}Trunc ( } %end # for derivative_kind in ['jvp', 'vjp']: +// Unary functions %for derivative_kind in ['jvp', 'vjp']: % linear_map_kind = 'differential' if derivative_kind == 'jvp' else 'pullback' % for T in ['Float', 'Double', 'Float80']: @@ -271,3 +272,30 @@ func _${derivative_kind}Erfc(_ x: ${T}) -> (value: ${T}, ${linear_map_kind}: (${ % end # if T == 'Float80': % end # for T in ['Float', 'Double', 'Float80']: %end # for derivative_kind in ['jvp', 'vjp']: + +// Binary functions +%for T in ['Float', 'Float80']: +% if T == 'Float80': +#if !(os(Windows) || os(Android)) && (arch(i386) || arch(x86_64)) +% end +@inlinable +@derivative(of: pow) +func _vjpPow(_ x: ${T}, _ y: ${T}) -> (value: ${T}, pullback: (${T}) -> (${T}, ${T})) { + let value = pow(x, y) + return (value, { v in ( + v * y * pow(x, y - 1), v * value * log(x.isLessThanOrEqualTo(0) ? ${T}(1) : x) + ) }) +} + +@inlinable +@derivative(of: pow) +func _jvpPow(_ x: ${T}, _ y: ${T}) -> (value: ${T}, differential: (${T}, ${T}) -> ${T}) { + let value = pow(x, y) + return (value, { (dx, dy) in + dx * y * pow(x, y - 1) + dy * value * log(x.isLessThanOrEqualTo(0) ? ${T}(1) : x) + }) +} +% if T == 'Float80': +#endif +% end # if T == 'Float80': +%end # for T in ['Float', 'Float80']: diff --git a/test/AutoDiff/stdlib/tgmath_derivatives.swift.gyb b/test/AutoDiff/stdlib/tgmath_derivatives.swift.gyb index 08d019f7b2a5e..331955bf4fe39 100644 --- a/test/AutoDiff/stdlib/tgmath_derivatives.swift.gyb +++ b/test/AutoDiff/stdlib/tgmath_derivatives.swift.gyb @@ -26,7 +26,7 @@ func expectEqualWithTolerance(_ expected: TestLiteralType, _ actual: T, ulps allowed: T = 3, file: String = #file, line: UInt = #line) where T: BinaryFloatingPoint { - if actual == T(expected) || actual.isNaN && expected.isNaN { + if actual == T(expected) || actual.isNaN && expected.isNaN || actual.isInfinite && expected.isInfinite { return } // Compute error in ulp, compare to tolerance. @@ -38,17 +38,40 @@ func expectEqualWithTolerance(_ expected: TestLiteralType, _ actual: T, file: file, line: line) } +func computeDividedDifference ( + _ f: (T, T) -> T, + _ x: T, + _ y: T, + eps: T = 0.01 +) -> (dfdx: T, dfdy: T) { + let dfdx = (f(x + eps, y) - f(x, y)) / eps + let dfdy = (f(x, y + eps) - f(x, y)) / eps + return (dfdx, dfdy) +} + func checkGradient( _ f: @differentiable (T, T) -> T, _ x: T, - _ y: T) + _ y: T, + ulps: T = 192) where T == T.TangentVector { let eps = T(0.01) let grad = gradient(at: x, y, in: f) - let dfdx = (f(x + eps, y) - f(x, y)) / eps - let dfdy = (f(x, y + eps) - f(x, y)) / eps - expectEqualWithTolerance(TestLiteralType(dfdx), grad.0, ulps: 192) - expectEqualWithTolerance(TestLiteralType(dfdy), grad.1, ulps: 192) + let (dfdx, dfdy) = computeDividedDifference(f, x, y, eps: eps) + expectEqualWithTolerance(TestLiteralType(dfdx), grad.0, ulps: ulps) + expectEqualWithTolerance(TestLiteralType(dfdy), grad.1, ulps: ulps) +} + +func checkDerivative( + _ f: @differentiable (T, T) -> T, + _ x: T, + _ y: T, + ulps: T = 192) +where T == T.TangentVector { + let eps = T(0.01) + let deriv = derivative(at: x, y, in: f) + let (dfdx, dfdy) = computeDividedDifference(f, x, y, eps: eps) + expectEqualWithTolerance(TestLiteralType(dfdx + dfdy), deriv, ulps: ulps) } %for op in ['derivative', 'gradient']: @@ -111,6 +134,68 @@ DerivativeTests.test("${op}_${T}") { checkGradient({ fmod($0, $1) }, x, y) %else: # if op == 'derivative' // TODO(TF-1108): Implement JVPs for `remainder` and `fmod`. +%end + } + } + + // pow + let eps:${T} = 0.01 + let ulps:${T} = eps/eps.ulp + + // Checks for negative base. + for a in -3..<0 { + let x = ${T}(a) + for b in -3...3 { + let y = ${T}(b) + let expectedDx = y * pow(x, y - 1) + let expectedDy = ${T}.zero + let dpow = ${op}(at: x, y, in: pow) +%if op == 'gradient': + expectEqualWithTolerance(TestLiteralType(expectedDx), dpow.0) + expectEqualWithTolerance(TestLiteralType(expectedDy), dpow.1) +%else: # if op == 'derivative' + expectEqualWithTolerance(TestLiteralType(expectedDx + expectedDy), dpow) +%end + } + } + + // Checks for 0 base. + for b in -3...3 { + let y = ${T}(b) + var expectedValues: (dx: ${T}, dy: ${T})? + if y.isLess(than: 0) { + expectedValues = (dx: ${T}.infinity, dy: ${T}.nan) + } else if y.isZero { + expectedValues = (dx: ${T}.nan, dy: ${T}.zero) + } else if !y.isEqual(to: 1) { + expectedValues = (dx: ${T}.zero, dy: ${T}.zero) + } + if let (expectedDx, expectedDy) = expectedValues { + let dpow = ${op}(at: 0.0, y, in: pow) +%if op == 'gradient': + expectEqualWithTolerance(TestLiteralType(expectedDx), dpow.0) + expectEqualWithTolerance(TestLiteralType(expectedDy), dpow.1) +%else: # if op == 'derivative' + expectEqualWithTolerance(TestLiteralType(expectedDx + expectedDy), dpow) +%end + } else { +%if op == 'gradient': + checkGradient({ pow($0, $1) }, 0.0, y, ulps: ulps) +%else: # if op == 'derivative' + checkDerivative({ pow($0, $1) }, 0.0, y, ulps: ulps) +%end + } + } + + // Checks for positive base. + for a in 1...3 { + let x = ${T}(a) + for b in -3...3 { + let y = ${T}(b) +%if op == 'gradient': + checkGradient({ pow($0, $1) }, x, y, ulps: ulps) +%else: # if op == 'derivative' + checkDerivative({ pow($0, $1) }, x, y, ulps: ulps) %end } }