diff --git a/lib/SILOptimizer/SILCombiner/SILCombinerApplyVisitors.cpp b/lib/SILOptimizer/SILCombiner/SILCombinerApplyVisitors.cpp index e7f44ba1f3a8c..23134192c9a1a 100644 --- a/lib/SILOptimizer/SILCombiner/SILCombinerApplyVisitors.cpp +++ b/lib/SILOptimizer/SILCombiner/SILCombinerApplyVisitors.cpp @@ -1500,6 +1500,13 @@ SILInstruction *SILCombiner::visitApplyInst(ApplyInst *AI) { } } + // (apply (differentiable_function f)) to (apply f) + if (auto *DFI = dyn_cast(AI->getCallee())) { + return cloneFullApplySiteReplacingCallee(AI, DFI->getOperand(0), + Builder.getBuilderContext()) + .getInstruction(); + } + // (apply (thin_to_thick_function f)) to (apply f) if (auto *TTTFI = dyn_cast(AI->getCallee())) { // We currently don't remove any possible retain associated with the thick diff --git a/lib/SILOptimizer/SILCombiner/SILCombinerCastVisitors.cpp b/lib/SILOptimizer/SILCombiner/SILCombinerCastVisitors.cpp index 8b853d30b260f..c7fac705fdc0f 100644 --- a/lib/SILOptimizer/SILCombiner/SILCombinerCastVisitors.cpp +++ b/lib/SILOptimizer/SILCombiner/SILCombinerCastVisitors.cpp @@ -1110,39 +1110,42 @@ SILInstruction *SILCombiner::visitConvertEscapeToNoEscapeInst( // %vjp' = convert_escape_to_noescape %vjp // %y = differentiable_function(%orig', %jvp', %vjp') if (auto *DFI = dyn_cast(Cvt->getOperand())) { - auto createConvertEscapeToNoEscape = [&](NormalDifferentiableFunctionTypeComponent extractee) { - if (!DFI->hasExtractee(extractee)) - return SILValue(); + if (DFI->hasOneUse()) { + auto createConvertEscapeToNoEscape = + [&](NormalDifferentiableFunctionTypeComponent extractee) { + if (!DFI->hasExtractee(extractee)) + return SILValue(); - auto operand = DFI->getExtractee(extractee); - auto fnType = operand->getType().castTo(); - auto noEscapeFnType = - fnType->getWithExtInfo(fnType->getExtInfo().withNoEscape()); - auto noEscapeType = SILType::getPrimitiveObjectType(noEscapeFnType); - return Builder.createConvertEscapeToNoEscape( - operand.getLoc(), operand, noEscapeType, Cvt->isLifetimeGuaranteed())->getResult(0); - }; + auto operand = DFI->getExtractee(extractee); + auto fnType = operand->getType().castTo(); + auto noEscapeFnType = + fnType->getWithExtInfo(fnType->getExtInfo().withNoEscape()); + auto noEscapeType = SILType::getPrimitiveObjectType(noEscapeFnType); + return Builder.createConvertEscapeToNoEscape( + operand.getLoc(), operand, noEscapeType, Cvt->isLifetimeGuaranteed())->getResult(0); + }; - SILValue originalNoEscape = - createConvertEscapeToNoEscape(NormalDifferentiableFunctionTypeComponent::Original); - SILValue convertedJVP = createConvertEscapeToNoEscape( - NormalDifferentiableFunctionTypeComponent::JVP); - SILValue convertedVJP = createConvertEscapeToNoEscape( - NormalDifferentiableFunctionTypeComponent::VJP); - - llvm::Optional> derivativeFunctions; - if (convertedJVP && convertedVJP) - derivativeFunctions = std::make_pair(convertedJVP, convertedVJP); - - auto *newDFI = Builder.createDifferentiableFunction( - DFI->getLoc(), DFI->getParameterIndices(), DFI->getResultIndices(), - originalNoEscape, derivativeFunctions); - assert(newDFI->getType() == Cvt->getType() && - "New `@differentiable` function instruction should have same type " - "as the old `convert_escape_to_no_escape` instruction"); - return newDFI; - } + SILValue originalNoEscape = + createConvertEscapeToNoEscape(NormalDifferentiableFunctionTypeComponent::Original); + SILValue convertedJVP = createConvertEscapeToNoEscape( + NormalDifferentiableFunctionTypeComponent::JVP); + SILValue convertedVJP = createConvertEscapeToNoEscape( + NormalDifferentiableFunctionTypeComponent::VJP); + + llvm::Optional> derivativeFunctions; + if (convertedJVP && convertedVJP) + derivativeFunctions = std::make_pair(convertedJVP, convertedVJP); + auto *newDFI = Builder.createDifferentiableFunction( + DFI->getLoc(), DFI->getParameterIndices(), DFI->getResultIndices(), + originalNoEscape, derivativeFunctions); + assert(newDFI->getType() == Cvt->getType() && + "New `@differentiable` function instruction should have same type " + "as the old `convert_escape_to_no_escape` instruction"); + return newDFI; + } + } + return nullptr; } diff --git a/test/AutoDiff/SILOptimizer/differential_apply.swift b/test/AutoDiff/SILOptimizer/differential_apply.swift new file mode 100644 index 0000000000000..3474564ebb9f5 --- /dev/null +++ b/test/AutoDiff/SILOptimizer/differential_apply.swift @@ -0,0 +1,48 @@ +// RUN: %target-swift-frontend -emit-sil -O %s | %FileCheck %s +// REQUIRES: swift_in_compiler + +import _Differentiation + +@differentiable(reverse) +@_silgen_name("test_f") +// Check that (differentiable) closure apply is optimized out +// CHECK-LABEL: test_f : $@convention(thin) (@guaranteed Array) -> Double +// CHECK-NOT: differentiable_function [parameters 0] [results 0] +func f(array: [Double]) -> Double { + var array = array + array.update(at: 1, + byCalling: { + (element: inout Double) in + let initialElement = element; + element *= initialElement + } + ) + + return 0.0 +} + +public func valueWithPullback(at x: T, of f: @differentiable(reverse) (inout T) -> Void) -> (value: Void, pullback: (inout T.TangentVector) -> Void) {fatalError()} +public func pullback(at x: T, of f: @differentiable(reverse) (inout T) -> Void) -> (inout T.TangentVector) -> Void {return valueWithPullback(at: x, of: f).pullback} + +public extension Array { + @differentiable(reverse) + mutating func update(at index: Int, + byCalling closure: @differentiable(reverse) (inout Element) -> Void) where Element: Differentiable { + closure(&self[index]) + } +} + +public extension Array where Element: Differentiable { + @derivative(of: update(at:byCalling:)) + mutating func vjpUpdate(at index: Int, byCalling closure: @differentiable(reverse) (inout Element) -> Void) -> (value: Void, pullback: (inout Self.TangentVector) -> Void) { + let closurePullback = pullback(at: self[index], of: closure) + return (value: (), pullback: { closurePullback(&$0.base[index]) }) + } +} + +public struct D { + public subscript(_ index: I) -> D? { + get {fatalError()} + set {fatalError()} + } +}