diff --git a/stdlib/public/Differentiation/ArrayDifferentiation.swift b/stdlib/public/Differentiation/ArrayDifferentiation.swift index cdd85bd6c1d7e..265746ce58448 100644 --- a/stdlib/public/Differentiation/ArrayDifferentiation.swift +++ b/stdlib/public/Differentiation/ArrayDifferentiation.swift @@ -242,6 +242,33 @@ extension Array where Element: Differentiable { } } +extension Array where Element: Differentiable { + @usableFromInline + @derivative(of: +=) + static func _vjpAppend(_ lhs: inout Self, _ rhs: Self) -> ( + value: Void, pullback: (inout TangentVector) -> TangentVector + ) { + let lhsCount = lhs.count + lhs += rhs + return ((), { v in + let drhs = + TangentVector(.init(v.base.dropFirst(lhsCount))) + let rhsCount = drhs.base.count + v.base.removeLast(rhsCount) + return drhs + }) + } + + @usableFromInline + @derivative(of: +=) + static func _jvpAppend(_ lhs: inout Self, _ rhs: Self) -> ( + value: Void, differential: (inout TangentVector, TangentVector) -> Void + ) { + lhs += rhs + return ((), { $0.base += $1.base }) + } +} + extension Array where Element: Differentiable { @usableFromInline @derivative(of: init(repeating:count:)) diff --git a/test/AutoDiff/stdlib/array.swift b/test/AutoDiff/stdlib/array.swift index 06807cc9c32ee..bce8fcfd14922 100644 --- a/test/AutoDiff/stdlib/array.swift +++ b/test/AutoDiff/stdlib/array.swift @@ -317,42 +317,53 @@ ArrayAutoDiffTests.test("ExpressibleByArrayLiteralIndirect") { } ArrayAutoDiffTests.test("Array.+") { - struct TwoArrays : Differentiable { - var a: [Float] - var b: [Float] + func sumFirstThreeConcatenating(_ a: [Float], _ b: [Float]) -> Float { + let c = a + b + return c[0] + c[1] + c[2] } - func sumFirstThreeConcatenated(_ arrs: TwoArrays) -> Float { - let c = arrs.a + arrs.b + expectEqual( + (.init([1, 1]), .init([1, 0])), + gradient(at: [0, 0], [0, 0], in: sumFirstThreeConcatenating)) + expectEqual( + (.init([1, 1, 1, 0]), .init([0, 0])), + gradient(at: [0, 0, 0, 0], [0, 0], in: sumFirstThreeConcatenating)) + expectEqual( + (.init([]), .init([1, 1, 1, 0])), + gradient(at: [], [0, 0, 0, 0], in: sumFirstThreeConcatenating)) + + func identity(_ array: [Float]) -> [Float] { + var results: [Float] = [] + for i in withoutDerivative(at: array.indices) { + results = results + [array[i]] + } + return results + } + let v = FloatArrayTan([4, -5, 6]) + expectEqual(v, pullback(at: [1, 2, 3], in: identity)(v)) +} + +ArrayAutoDiffTests.test("Array.+=") { + func sumFirstThreeConcatenating(_ a: [Float], _ b: [Float]) -> Float { + var c = a + c += b return c[0] + c[1] + c[2] } expectEqual( - TwoArrays.TangentVector( - a: FloatArrayTan([1, 1]), - b: FloatArrayTan([1, 0])), - gradient( - at: TwoArrays(a: [0, 0], b: [0, 0]), - in: sumFirstThreeConcatenated)) + (.init([1, 1]), .init([1, 0])), + gradient(at: [0, 0], [0, 0], in: sumFirstThreeConcatenating)) expectEqual( - TwoArrays.TangentVector( - a: FloatArrayTan([1, 1, 1, 0]), - b: FloatArrayTan([0, 0])), - gradient( - at: TwoArrays(a: [0, 0, 0, 0], b: [0, 0]), - in: sumFirstThreeConcatenated)) + (.init([1, 1, 1, 0]), .init([0, 0])), + gradient(at: [0, 0, 0, 0], [0, 0], in: sumFirstThreeConcatenating)) expectEqual( - TwoArrays.TangentVector( - a: FloatArrayTan([]), - b: FloatArrayTan([1, 1, 1, 0])), - gradient( - at: TwoArrays(a: [], b: [0, 0, 0, 0]), - in: sumFirstThreeConcatenated)) + (.init([]), .init([1, 1, 1, 0])), + gradient(at: [], [0, 0, 0, 0], in: sumFirstThreeConcatenating)) func identity(_ array: [Float]) -> [Float] { var results: [Float] = [] for i in withoutDerivative(at: array.indices) { - results = results + [array[i]] + results += [array[i]] } return results }