Skip to content

[AutoDiff] devirtualize diff witnesses #28480

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Nov 26, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions include/swift/SIL/SILDifferentiabilityWitness.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@ class SILDifferentiabilityWitness
GenericSignature derivativeGenSig, SILFunction *jvp, SILFunction *vjp,
bool isSerialized, DeclAttribute *attribute = nullptr);

void convertToDefinition(SILFunction *jvp, SILFunction *vjp,
bool isSerialized);

SILDifferentiabilityWitnessKey getKey() const;
SILModule &getModule() const { return Module; }
SILLinkage getLinkage() const { return Linkage; }
Expand Down
4 changes: 4 additions & 0 deletions include/swift/SIL/SILModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -613,6 +613,10 @@ class SILModule {
/// Look up the differentiability witness corresponding to the given key.
SILDifferentiabilityWitness *
lookUpDifferentiabilityWitness(SILDifferentiabilityWitnessKey key);

/// Attempt to deserialize the SILDifferentiabilityWitness. Returns true if
/// deserialization succeeded, false otherwise.
bool loadDifferentiabilityWitness(SILDifferentiabilityWitness *W);
// SWIFT_ENABLE_TENSORFLOW_END

// Given a protocol, attempt to create a default witness table declaration
Expand Down
3 changes: 3 additions & 0 deletions include/swift/SILOptimizer/PassManager/Passes.def
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,9 @@ PASS(DiagnoseUnreachable, "diagnose-unreachable",
"Diagnose Unreachable Code")
PASS(DiagnosticConstantPropagation, "diagnostic-constant-propagation",
"Constants Propagation for Diagnostics")
PASS(DifferentiabilityWitnessDevirtualizer,
"differentiability-witness-devirtualizer",
"Inlines Differentiability Witnesses")
PASS(EagerSpecializer, "eager-specializer",
"Eager Specialization via @_specialize")
PASS(EarlyCodeMotion, "early-codemotion",
Expand Down
10 changes: 10 additions & 0 deletions lib/SIL/SILDifferentiabilityWitness.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,16 @@ SILDifferentiabilityWitness *SILDifferentiabilityWitness::createDefinition(
return diffWitness;
}

void SILDifferentiabilityWitness::convertToDefinition(SILFunction *jvp,
SILFunction *vjp,
bool isSerialized) {
assert(IsDeclaration);
IsDeclaration = false;
JVP = jvp;
VJP = vjp;
IsSerialized = isSerialized;
}

SILDifferentiabilityWitnessKey SILDifferentiabilityWitness::getKey() const {
return std::make_pair(getOriginalFunction()->getName(), getConfig());
}
9 changes: 9 additions & 0 deletions lib/SIL/SILModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -599,6 +599,15 @@ SILModule::lookUpDifferentiabilityWitness(SILDifferentiabilityWitnessKey key) {
mangler.mangleSILDifferentiabilityWitnessKey(key));
}

bool SILModule::loadDifferentiabilityWitness(SILDifferentiabilityWitness *W) {
auto *NewW = getSILLoader()->lookupDifferentiabilityWitness(W->getKey());
if (!NewW)
return false;

assert(W == NewW);
return true;
}

void SILModule::registerDeserializationNotificationHandler(
std::unique_ptr<DeserializationNotificationHandler> &&handler) {
deserializationNotificationHandlers.add(std::move(handler));
Expand Down
9 changes: 3 additions & 6 deletions lib/SILGen/SILGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -824,13 +824,10 @@ void SILGenModule::emitDifferentiabilityWitness(
derivativeCanGenSig = derivativeGenSig->getCanonicalSignature();
// Create new SIL differentiability witness.
// Witness JVP and VJP are set below.
// TODO(TF-919): Explore creating serialized differentiability witnesses.
// Currently, differentiability witnesses are never serialized to avoid
// deserialization issues where JVP/VJP functions cannot be found.
auto *diffWitness = SILDifferentiabilityWitness::createDefinition(
M, originalFunction->getLinkage(), originalFunction,
loweredParamIndices, config.resultIndices, derivativeCanGenSig,
/*jvp*/ nullptr, /*vjp*/ nullptr, /*isSerialized*/ false);
M, originalFunction->getLinkage(), originalFunction, loweredParamIndices,
config.resultIndices, derivativeCanGenSig,
/*jvp*/ nullptr, /*vjp*/ nullptr, originalFunction->isSerialized());

// Set derivative function in differentiability witness.
auto setDerivativeInDifferentiabilityWitness =
Expand Down
5 changes: 5 additions & 0 deletions lib/SILOptimizer/PassManager/PassPipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,11 @@ static void addPerfEarlyModulePassPipeline(SILPassPipelinePlan &P) {
// we do not spend time optimizing them.
P.addDeadFunctionElimination();

// SWIFT_ENABLE_TENSORFLOW
// This unblocks many other passes' optimizations (e.g. inlining) and this is
// not blocked by any other passes' optimizations, so do it early.
P.addDifferentiabilityWitnessDevirtualizer();

// Strip ownership from non-transparent functions.
if (P.getOptions().StripOwnershipAfterSerialization)
P.addNonTransparentFunctionOwnershipModelEliminator();
Expand Down
3 changes: 3 additions & 0 deletions lib/SILOptimizer/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ silopt_register_sources(
DeadStoreElimination.cpp
DestroyHoisting.cpp
Devirtualizer.cpp
# SWIFT_ENABLE_TENSORFLOW
DifferentiabilityWitnessDevirtualizer.cpp
# SWIFT_ENABLE_TENSORFLOW_END
GenericSpecializer.cpp
MergeCondFail.cpp
Outliner.cpp
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
//===--- DifferentiabilityWitnessDevirtualizer.cpp ------------------------===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2019 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//
//
// Devirtualized differentiability witnesses whose bodies are availabe, by
// turning "differentiability_witness_function" instructions into "function_ref"
// instructions referencing the appropriate function.
//
//===----------------------------------------------------------------------===//

#include "swift/SIL/SILBuilder.h"
#include "swift/SIL/SILFunction.h"
#include "swift/SIL/SILInstruction.h"
#include "swift/SILOptimizer/PassManager/Transforms.h"

using namespace swift;

namespace {
class DifferentiabilityWitnessDevirtualizer : public SILFunctionTransform {

/// Returns true if and changes were made.
bool devirtualizeDifferentiabilityWitnessesInFunction(SILFunction &f);

/// The entry point to the transformation.
void run() override {
if (devirtualizeDifferentiabilityWitnessesInFunction(*getFunction()))
invalidateAnalysis(SILAnalysis::InvalidationKind::CallsAndInstructions);
}
};
} // end anonymous namespace

bool DifferentiabilityWitnessDevirtualizer::
devirtualizeDifferentiabilityWitnessesInFunction(SILFunction &f) {
bool changed = false;
llvm::SmallVector<DifferentiabilityWitnessFunctionInst *, 8> insts;
for (auto &bb : f) {
for (auto &inst : bb) {
auto *dfwi = dyn_cast<DifferentiabilityWitnessFunctionInst>(&inst);
if (!dfwi)
continue;
insts.push_back(dfwi);
}
}
for (auto *inst : insts) {
auto *wit = inst->getWitness();
if (wit->isDeclaration())
f.getModule().loadDifferentiabilityWitness(wit);
if (wit->isDeclaration())
continue;
changed = true;
SILBuilderWithScope builder(inst);
auto kind = inst->getWitnessKind().getAsDerivativeFunctionKind();
assert(kind.hasValue());
auto *newInst = builder.createFunctionRefFor(inst->getLoc(),
wit->getDerivative(*kind));
inst->replaceAllUsesWith(newInst);
inst->getParent()->erase(inst);
}
return changed;
}

SILTransform *swift::createDifferentiabilityWitnessDevirtualizer() {
return new DifferentiabilityWitnessDevirtualizer();
}
29 changes: 20 additions & 9 deletions lib/Serialization/DeserializeSIL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3459,18 +3459,29 @@ SILDeserializer::readDifferentiabilityWitness(DeclID DId) {
ArrayRef<unsigned>(parameterAndResultIndices)
.take_back(numResultIndices));

if (isDeclaration) {
auto *diffWitness = SILDifferentiabilityWitness::createDeclaration(
AutoDiffConfig config(parameterIndices, resultIndices, derivativeGenSig);
auto *diffWitness =
SILMod.lookUpDifferentiabilityWitness({originalName, config});

// If there is no existing differentiability witness, create one.
if (!diffWitness)
diffWitness = SILDifferentiabilityWitness::createDeclaration(
SILMod, *linkage, original, parameterIndices, resultIndices,
derivativeGenSig);
diffWitnessOrOffset.set(diffWitness, /*isFullyDeserialized*/ true);
return diffWitness;
}

auto *diffWitness = SILDifferentiabilityWitness::createDefinition(
SILMod, *linkage, original, parameterIndices, resultIndices,
derivativeGenSig, jvp, vjp, isSerialized);
diffWitnessOrOffset.set(diffWitness, /*isFullyDeserialized*/ true);
// If the current differentiability witness is merely a declaration, and the
// deserialized witness is a definition, upgrade the current differentiability
// witness to a definition. This can happen in the following situations:
// 1. The witness was just created above.
// 2. The witness started out as a declaration (e.g. the differentiation
// pass emitted a witness for an external function) and now we're loading
// the definition (e.g. an optimization pass asked for the definition and
// we found the definition serialized in this module).
if (diffWitness->isDeclaration() && !isDeclaration)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please explain when this condition is true and convertToDefinition is called?
It doesn't seem wholly obvious, perhaps an explanatory comment would be good.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

diffWitness->convertToDefinition(jvp, vjp, isSerialized);

diffWitnessOrOffset.set(diffWitness,
/*isFullyDeserialized*/ diffWitness->isDefinition());
return diffWitness;
}

Expand Down
13 changes: 9 additions & 4 deletions lib/Serialization/SerializedSILLoader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,10 +133,15 @@ SerializedSILLoader::lookupDifferentiabilityWitness(
SILDifferentiabilityWitnessKey key) {
Mangle::ASTMangler mangler;
std::string mangledKey = mangler.mangleSILDifferentiabilityWitnessKey(key);
for (auto &Des : LoadedSILSections)
if (auto *diffWitness = Des->lookupDifferentiabilityWitness(mangledKey))
return diffWitness;
return nullptr;
// It is possible that one module has a declaration of a
// SILDifferentiabilityWitness, while another has the full definition.
SILDifferentiabilityWitness *wit = nullptr;
for (auto &Des : LoadedSILSections) {
wit = Des->lookupDifferentiabilityWitness(mangledKey);
if (wit && wit->isDefinition())
return wit;
}
return wit;
}
// SWIFT_ENABLE_TENSORFLOW END

Expand Down
42 changes: 42 additions & 0 deletions test/AutoDiff/differentiability_witness_inlining.sil
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// RUN: %target-sil-opt -differentiability-witness-devirtualizer %s -enable-sil-verify-all | %FileCheck %s

sil_stage raw

import Swift
import Builtin

sil_differentiability_witness [parameters 0] [results 0] @witness_defined_in_module : $@convention(thin) (Float) -> Float {
jvp: @witness_defined_in_module_jvp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
vjp: @witness_defined_in_module_vjp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
}

sil_differentiability_witness [parameters 0] [results 0] @witness_definition_not_available : $@convention(thin) (Float) -> Float

// This is an example of a witness that is available (via deserialization)
// even though it is not defined in the current module.
// witness for static Swift.Float.+ infix(Swift.Float, Swift.Float) -> Swift.Float
sil_differentiability_witness [parameters 0 1] [results 0] @$sSf1poiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float

sil @witness_defined_in_module : $@convention(thin) (Float) -> Float

sil @witness_defined_in_module_jvp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)

sil @witness_defined_in_module_vjp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)

sil @witness_definition_not_available : $@convention(thin) (Float) -> Float

sil public_external [transparent] [serialized] @$sSf1poiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float

sil @test : $@convention(thin) (Float) -> () {
bb0(%0 : $Float):
%1 = differentiability_witness_function [vjp] [parameters 0] [results 0] @witness_defined_in_module : $@convention(thin) (Float) -> Float
// CHECK: %1 = function_ref @witness_defined_in_module_vjp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)

%2 = differentiability_witness_function [vjp] [parameters 0] [results 0] @witness_definition_not_available : $@convention(thin) (Float) -> Float
// CHECK: %2 = differentiability_witness_function [vjp] [parameters 0] [results 0] @witness_definition_not_available : $@convention(thin) (Float) -> Float

%3 = differentiability_witness_function [vjp] [parameters 0 1] [results 0] @$sSf1poiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float
// CHECK: %3 = function_ref @AD__$sSf1poiyS2f_SftFZ__vjp_src_0_wrt_0_1 : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float) -> (Float, Float))

return undef : $()
}
2 changes: 1 addition & 1 deletion test/AutoDiff/sil_differentiability_witness_silgen.swift
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ public struct Foo: Differentiable {
public var x: Float

// CHECK-LABEL: // differentiability witness for Foo.x.getter
// CHECK-NEXT: sil_differentiability_witness [parameters 0] [results 0] @$s36sil_differentiability_witness_silgen3FooV1xSfvg : $@convention(method) (Foo) -> Float {
// CHECK-NEXT: sil_differentiability_witness [serialized] [parameters 0] [results 0] @$s36sil_differentiability_witness_silgen3FooV1xSfvg : $@convention(method) (Foo) -> Float {
// CHECK-NEXT: }

@differentiable
Expand Down