From 2fad0259b66316c7c59795bc5188f77f8dc90a70 Mon Sep 17 00:00:00 2001 From: Saleem Abdulrasool Date: Sat, 4 Jul 2020 09:31:13 -0700 Subject: [PATCH 01/16] test: make `test_util` more Python 3 friendly This uses `io.open` to allow `test_util` to open with the encoding and newline handling across python 2 and python 3. With this change, the test suite failures are within the single digits locally. --- utils/incrparse/test_util.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/utils/incrparse/test_util.py b/utils/incrparse/test_util.py index 970ab52312ab1..d099d9c6b3d68 100755 --- a/utils/incrparse/test_util.py +++ b/utils/incrparse/test_util.py @@ -3,6 +3,7 @@ from __future__ import print_function import argparse +import io import os import re import subprocess @@ -21,6 +22,7 @@ def escapeCmdArg(arg): def run_command(cmd): + cmd = list(map(lambda s: s.encode('utf-8'), cmd)) print(' '.join([escapeCmdArg(arg) for arg in cmd])) return subprocess.check_output(cmd, stderr=subprocess.STDOUT) @@ -56,7 +58,7 @@ def parseLine(line, line_no, test_case, incremental_edit_args, reparse_args, # Compute the -incremental-edit argument for swift-syntax-test column = len(pre_edit_line) + len(prefix) + 1 edit_arg = '%d:%d-%d:%d=%s' % \ - (line_no, column, line_no, column + len(pre_edit), + (line_no, column, line_no, column + len(pre_edit.encode('utf-8')), post_edit) incremental_edit_args.append('-incremental-edit') incremental_edit_args.append(edit_arg) @@ -102,14 +104,14 @@ def parseLine(line, line_no, test_case, incremental_edit_args, reparse_args, # Nothing more to do line = '' - return (pre_edit_line, post_edit_line, current_reparse_start) + return (pre_edit_line.encode('utf-8'), post_edit_line.encode('utf-8'), current_reparse_start) def prepareForIncrParse(test_file, test_case, pre_edit_file, post_edit_file, incremental_edit_args, reparse_args): - with open(test_file, mode='r') as test_file_handle, \ - open(pre_edit_file, mode='w+b') as pre_edit_file_handle, \ - open(post_edit_file, mode='w+b') as post_edit_file_handle: + with io.open(test_file, mode='r', encoding='utf-8', newline='\n') as test_file_handle, \ + io.open(pre_edit_file, mode='w+', encoding='utf-8', newline='\n') as pre_edit_file_handle, \ + io.open(post_edit_file, mode='w+', encoding='utf-8', newline='\n') as post_edit_file_handle: current_reparse_start = None @@ -121,8 +123,8 @@ def prepareForIncrParse(test_file, test_case, pre_edit_file, post_edit_file, (pre_edit_line, post_edit_line, current_reparse_start) = \ parseLineRes - pre_edit_file_handle.write(pre_edit_line) - post_edit_file_handle.write(post_edit_line) + pre_edit_file_handle.write(pre_edit_line.decode('utf-8')) + post_edit_file_handle.write(post_edit_line.decode('utf-8')) line_no += 1 @@ -231,7 +233,7 @@ def serializeIncrParseMarkupFile(test_file, test_case, mode, if print_visual_reuse_info: print(output) except subprocess.CalledProcessError as e: - raise TestFailedError(e.output) + raise TestFailedError(e.output.decode('utf-8')) def main(): From dcd33ca7336db8651df865db7afd242e643fb9b2 Mon Sep 17 00:00:00 2001 From: Saleem Abdulrasool Date: Mon, 6 Jul 2020 14:27:00 -0700 Subject: [PATCH 02/16] test: setup `PYTHONIOENCODING` on Windows This is required to ensure that we encode the strings correctly when redirected. This allows the printing to succeed where it would fail previously due to encoding failures. --- test/lit.cfg | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/lit.cfg b/test/lit.cfg index 08a38a66606f2..8b0c1b82625f4 100644 --- a/test/lit.cfg +++ b/test/lit.cfg @@ -183,6 +183,9 @@ def append_to_env_path(directory): config.environment['PATH'] = \ os.path.pathsep.join((directory, config.environment['PATH'])) +if kIsWindows: + config.environment['PYTHONIOENCODING'] = 'UTF8' + # Tweak the PATH to include the tools dir and the scripts dir. if swift_obj_root is not None: llvm_tools_dir = getattr(config, 'llvm_tools_dir', None) From 7bf676d335b80b2934f0a6dc75b9b055acf02adc Mon Sep 17 00:00:00 2001 From: Doug Gregor Date: Tue, 7 Jul 2020 22:11:19 -0700 Subject: [PATCH 03/16] [Frontend] Add compiler version information to -print-target-info output. Clients that use -print-target-info can avoid an extra frontend invocation by using this information. --- lib/FrontendTool/FrontendTool.cpp | 6 ++++++ test/Driver/print_target_info.swift | 4 ++++ 2 files changed, 10 insertions(+) diff --git a/lib/FrontendTool/FrontendTool.cpp b/lib/FrontendTool/FrontendTool.cpp index 369f92268751a..81de7bbd8a8a3 100644 --- a/lib/FrontendTool/FrontendTool.cpp +++ b/lib/FrontendTool/FrontendTool.cpp @@ -1985,6 +1985,12 @@ static void printTargetInfo(const CompilerInvocation &invocation, llvm::raw_ostream &out) { out << "{\n"; + // Compiler version, as produced by --version. + out << " \"compilerVersion\": \""; + out.write_escaped(version::getSwiftFullVersion( + version::Version::getCurrentLanguageVersion())); + out << "\",\n"; + // Target triple and target variant triple. auto &langOpts = invocation.getLangOptions(); out << " \"target\": "; diff --git a/test/Driver/print_target_info.swift b/test/Driver/print_target_info.swift index fdb636dcf1f5d..2bbafb07be7d0 100644 --- a/test/Driver/print_target_info.swift +++ b/test/Driver/print_target_info.swift @@ -9,6 +9,8 @@ // RUN: %swift_driver -print-target-info -target x86_64-apple-ios12.0 | %FileCheck -check-prefix CHECK-IOS-SIM %s +// CHECK-IOS: "compilerVersion": "Swift version + // CHECK-IOS: "target": { // CHECK-IOS: "triple": "arm64-apple-ios12.0", // CHECK-IOS: "unversionedTriple": "arm64-apple-ios", @@ -28,6 +30,8 @@ // CHECK-IOS: } +// CHECK-LINUX: "compilerVersion": "Swift version + // CHECK-LINUX: "target": { // CHECK-LINUX: "triple": "x86_64-unknown-linux", // CHECK-LINUX: "moduleTriple": "x86_64-unknown-linux", From f818a1e1b0a299dbf3c5c0d6973c7baae9c39f18 Mon Sep 17 00:00:00 2001 From: Doug Gregor Date: Tue, 7 Jul 2020 22:55:18 -0700 Subject: [PATCH 04/16] [Driver/IRGen] Put backward-deployment libraries into a table. Describe the backward-deployment libraries via a preprocessor-driven table. Macro-metaprogramming the two places in the code base---the driver and IRGen---to use this tabble to determine which backward-compatibility libraries to link against. --- include/swift/Frontend/BackDeploymentLibs.def | 31 +++++++++ lib/Driver/DarwinToolChains.cpp | 66 ++++++++++--------- lib/IRGen/GenDecl.cpp | 43 ++++++------ 3 files changed, 89 insertions(+), 51 deletions(-) create mode 100644 include/swift/Frontend/BackDeploymentLibs.def diff --git a/include/swift/Frontend/BackDeploymentLibs.def b/include/swift/Frontend/BackDeploymentLibs.def new file mode 100644 index 0000000000000..b2d9f8a84acd9 --- /dev/null +++ b/include/swift/Frontend/BackDeploymentLibs.def @@ -0,0 +1,31 @@ +//===------ BackDeploymentLibs.def - Backward Deployment Libraries --------===// +// +// This source file is part of the Swift.org open source project +// +// Copyright (c) 2014 - 2018 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors +// +//===----------------------------------------------------------------------===// +// Enumerates the backward deployment libraries that need to be linked +// into Swift targets. Clients of this file must define the macro +// +// BACK_DEPLOYMENT_LIB(Version, Filter, LibraryName) +// +// where: +// Version is a maximum Swift version written like a tuple, e.g., (5, 1) +// Filter is one of executable or all. +// LibraryName is the name of the library, e.g., "swiftCompatibility51" +//===----------------------------------------------------------------------===// + +#ifndef BACK_DEPLOYMENT_LIB +# error "Must define BACK_DEPLOYMENT_LIB(Version, Filter, Library)" +#endif + +BACK_DEPLOYMENT_LIB((5, 0), all, "swiftCompatibility50") +BACK_DEPLOYMENT_LIB((5, 1), all, "swiftCompatibility51") +BACK_DEPLOYMENT_LIB((5, 0), executable, "swiftCompatibilityDynamicReplacements") + +#undef BACK_DEPLOYMENT_LIB diff --git a/lib/Driver/DarwinToolChains.cpp b/lib/Driver/DarwinToolChains.cpp index a7153dc064e52..9ac64c30c5206 100644 --- a/lib/Driver/DarwinToolChains.cpp +++ b/lib/Driver/DarwinToolChains.cpp @@ -330,6 +330,26 @@ toolchains::Darwin::addSanitizerArgs(ArgStringList &Arguments, /*shared=*/false); } +namespace { + +enum class BackDeployLibFilter { + executable, + all +}; + +// Whether the given job matches the backward-deployment library filter. +bool jobMatchesFilter(LinkKind jobKind, BackDeployLibFilter filter) { + switch (filter) { + case BackDeployLibFilter::executable: + return jobKind == LinkKind::Executable; + + case BackDeployLibFilter::all: + return true; + } +} + +} + void toolchains::Darwin::addArgsToLinkStdlib(ArgStringList &Arguments, const DynamicLinkJobAction &job, @@ -359,47 +379,31 @@ toolchains::Darwin::addArgsToLinkStdlib(ArgStringList &Arguments, } if (runtimeCompatibilityVersion) { - if (*runtimeCompatibilityVersion <= llvm::VersionTuple(5, 0)) { - // Swift 5.0 compatibility library - SmallString<128> BackDeployLib; - BackDeployLib.append(SharedResourceDirPath); - llvm::sys::path::append(BackDeployLib, "libswiftCompatibility50.a"); - - if (llvm::sys::fs::exists(BackDeployLib)) { - Arguments.push_back("-force_load"); - Arguments.push_back(context.Args.MakeArgString(BackDeployLib)); - } - } + auto addBackDeployLib = [&](llvm::VersionTuple version, + BackDeployLibFilter filter, + StringRef libraryName) { + if (*runtimeCompatibilityVersion > version) + return; - if (*runtimeCompatibilityVersion <= llvm::VersionTuple(5, 1)) { - // Swift 5.1 compatibility library + if (!jobMatchesFilter(job.getKind(), filter)) + return; + SmallString<128> BackDeployLib; BackDeployLib.append(SharedResourceDirPath); - llvm::sys::path::append(BackDeployLib, "libswiftCompatibility51.a"); + llvm::sys::path::append(BackDeployLib, "lib" + libraryName + ".a"); if (llvm::sys::fs::exists(BackDeployLib)) { Arguments.push_back("-force_load"); Arguments.push_back(context.Args.MakeArgString(BackDeployLib)); } - } + }; + + #define BACK_DEPLOYMENT_LIB(Version, Filter, LibraryName) \ + addBackDeployLib( \ + llvm::VersionTuple Version, BackDeployLibFilter::Filter, LibraryName); + #include "swift/Frontend/BackDeploymentLibs.def" } - if (job.getKind() == LinkKind::Executable) { - if (runtimeCompatibilityVersion) - if (*runtimeCompatibilityVersion <= llvm::VersionTuple(5, 0)) { - // Swift 5.0 dynamic replacement compatibility library. - SmallString<128> BackDeployLib; - BackDeployLib.append(SharedResourceDirPath); - llvm::sys::path::append(BackDeployLib, - "libswiftCompatibilityDynamicReplacements.a"); - - if (llvm::sys::fs::exists(BackDeployLib)) { - Arguments.push_back("-force_load"); - Arguments.push_back(context.Args.MakeArgString(BackDeployLib)); - } - } - } - // Add the runtime library link path, which is platform-specific and found // relative to the compiler. SmallVector RuntimeLibPaths; diff --git a/lib/IRGen/GenDecl.cpp b/lib/IRGen/GenDecl.cpp index c8c22d832d097..ac7b9b10e2b70 100644 --- a/lib/IRGen/GenDecl.cpp +++ b/lib/IRGen/GenDecl.cpp @@ -471,28 +471,31 @@ void IRGenModule::emitSourceFile(SourceFile &SF) { // situations where it isn't useful, such as for dylibs, though this is // harmless aside from code size. if (!IRGen.Opts.UseJIT) { - if (auto compatibilityVersion - = IRGen.Opts.AutolinkRuntimeCompatibilityLibraryVersion) { - if (*compatibilityVersion <= llvm::VersionTuple(5, 0)) { - this->addLinkLibrary(LinkLibrary("swiftCompatibility50", - LibraryKind::Library, - /*forceLoad*/ true)); - } - if (*compatibilityVersion <= llvm::VersionTuple(5, 1)) { - this->addLinkLibrary(LinkLibrary("swiftCompatibility51", - LibraryKind::Library, - /*forceLoad*/ true)); + auto addBackDeployLib = [&](llvm::VersionTuple version, + StringRef libraryName) { + Optional compatibilityVersion; + if (libraryName == "swiftCompatibilityDynamicReplacements") { + compatibilityVersion = IRGen.Opts. + AutolinkRuntimeCompatibilityDynamicReplacementLibraryVersion; + } else { + compatibilityVersion = IRGen.Opts. + AutolinkRuntimeCompatibilityLibraryVersion; } - } - if (auto compatibilityVersion = - IRGen.Opts.AutolinkRuntimeCompatibilityDynamicReplacementLibraryVersion) { - if (*compatibilityVersion <= llvm::VersionTuple(5, 0)) { - this->addLinkLibrary(LinkLibrary("swiftCompatibilityDynamicReplacements", - LibraryKind::Library, - /*forceLoad*/ true)); - } - } + if (!compatibilityVersion) + return; + + if (*compatibilityVersion > version) + return; + + this->addLinkLibrary(LinkLibrary(libraryName, + LibraryKind::Library, + /*forceLoad*/ true)); + }; + + #define BACK_DEPLOYMENT_LIB(Version, Filter, LibraryName) \ + addBackDeployLib(llvm::VersionTuple Version, LibraryName); + #include "swift/Frontend/BackDeploymentLibs.def" } } From 618af0420ffd57b6d955b5c42a91f8b41a912fbb Mon Sep 17 00:00:00 2001 From: Doug Gregor Date: Tue, 7 Jul 2020 23:24:38 -0700 Subject: [PATCH 05/16] [Frontend] Add compatibility libraries to -print-target-info. The driver and any other client that attempts to properly link Swift code need to know which compatibility libraries should be linked on a per-target basis. Vend that information as part of -print-target-info. --- lib/FrontendTool/FrontendTool.cpp | 53 ++++++++++++++++++++++++++--- test/Driver/print_target_info.swift | 6 ++++ 2 files changed, 55 insertions(+), 4 deletions(-) diff --git a/lib/FrontendTool/FrontendTool.cpp b/lib/FrontendTool/FrontendTool.cpp index 81de7bbd8a8a3..da9d2c7bc740f 100644 --- a/lib/FrontendTool/FrontendTool.cpp +++ b/lib/FrontendTool/FrontendTool.cpp @@ -1948,8 +1948,36 @@ createJSONFixItDiagnosticConsumerIfNeeded( }); } +/// Print information about a +static void printCompatibilityLibrary( + llvm::VersionTuple runtimeVersion, llvm::VersionTuple maxVersion, + StringRef filter, StringRef libraryName, bool &printedAny, + llvm::raw_ostream &out) { + if (runtimeVersion > maxVersion) + return; + + if (printedAny) { + out << ","; + } + + out << "\n"; + out << " {\n"; + + out << " \"libraryName\": \""; + out.write_escaped(libraryName); + out << "\",\n"; + + out << " \"filter\": \""; + out.write_escaped(filter); + out << "\"\n"; + out << " }"; + + printedAny = true; +} + /// Print information about the target triple in JSON. static void printTripleInfo(const llvm::Triple &triple, + llvm::Optional runtimeVersion, llvm::raw_ostream &out) { out << "{\n"; @@ -1965,11 +1993,26 @@ static void printTripleInfo(const llvm::Triple &triple, out.write_escaped(getTargetSpecificModuleTriple(triple).getTriple()); out << "\",\n"; - if (auto runtimeVersion = getSwiftRuntimeCompatibilityVersionForTarget( - triple)) { + if (runtimeVersion) { out << " \"swiftRuntimeCompatibilityVersion\": \""; out.write_escaped(runtimeVersion->getAsString()); out << "\",\n"; + + // Compatibility libraries that need to be linked. + out << " \"compatibilityLibraries\": ["; + bool printedAnyCompatibilityLibrary = false; + #define BACK_DEPLOYMENT_LIB(Version, Filter, LibraryName) \ + printCompatibilityLibrary( \ + *runtimeVersion, llvm::VersionTuple Version, #Filter, LibraryName, \ + printedAnyCompatibilityLibrary, out); + #include "swift/Frontend/BackDeploymentLibs.def" + + if (printedAnyCompatibilityLibrary) { + out << "\n "; + } + out << " ],\n"; + } else { + out << " \"compatibilityLibraries\": [ ],\n"; } out << " \"librariesRequireRPath\": " @@ -1992,14 +2035,16 @@ static void printTargetInfo(const CompilerInvocation &invocation, out << "\",\n"; // Target triple and target variant triple. + auto runtimeVersion = + invocation.getIRGenOptions().AutolinkRuntimeCompatibilityLibraryVersion; auto &langOpts = invocation.getLangOptions(); out << " \"target\": "; - printTripleInfo(langOpts.Target, out); + printTripleInfo(langOpts.Target, runtimeVersion, out); out << ",\n"; if (auto &variant = langOpts.TargetVariant) { out << " \"targetVariant\": "; - printTripleInfo(*variant, out); + printTripleInfo(*variant, runtimeVersion, out); out << ",\n"; } diff --git a/test/Driver/print_target_info.swift b/test/Driver/print_target_info.swift index 2bbafb07be7d0..60cfcef332105 100644 --- a/test/Driver/print_target_info.swift +++ b/test/Driver/print_target_info.swift @@ -16,6 +16,12 @@ // CHECK-IOS: "unversionedTriple": "arm64-apple-ios", // CHECK-IOS: "moduleTriple": "arm64-apple-ios", // CHECK-IOS: "swiftRuntimeCompatibilityVersion": "5.0", +// CHECK-IOS: "compatibilityLibraries": [ +// CHECK-IOS: "libraryName": "swiftCompatibility50", +// CHECK-IOS: "libraryName": "swiftCompatibility51", +// CHECK-IOS: "libraryName": "swiftCompatibilityDynamicReplacements" +// CHECK-IOS: "filter": "executable" +// CHECK-IOS: ], // CHECK-IOS: "librariesRequireRPath": true // CHECK-IOS: } From be79d341e6d1bbb78e7bbc7a12d9adabfda4a1a1 Mon Sep 17 00:00:00 2001 From: Suyash Srijan Date: Wed, 8 Jul 2020 14:37:48 +0100 Subject: [PATCH 06/16] [AssociatedTypeInference] Strip 'self' parameter from function type of an enum case witness --- lib/Sema/TypeCheckProtocolInference.cpp | 5 ++--- .../req/associated_type_inference.swift | 19 +++++++++++++++++++ 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/lib/Sema/TypeCheckProtocolInference.cpp b/lib/Sema/TypeCheckProtocolInference.cpp index 81b4aad879314..d430c2e920cd1 100644 --- a/lib/Sema/TypeCheckProtocolInference.cpp +++ b/lib/Sema/TypeCheckProtocolInference.cpp @@ -546,9 +546,8 @@ static Type getWitnessTypeForMatching(NormalProtocolConformance *conformance, /// Remove the 'self' type from the given type, if it's a method type. static Type removeSelfParam(ValueDecl *value, Type type) { - if (auto func = dyn_cast(value)) { - if (func->getDeclContext()->isTypeContext()) - return type->castTo()->getResult(); + if (value->hasCurriedSelf()) { + return type->castTo()->getResult(); } return type; diff --git a/test/decl/protocol/req/associated_type_inference.swift b/test/decl/protocol/req/associated_type_inference.swift index c2250684c3437..97c9d1136787c 100644 --- a/test/decl/protocol/req/associated_type_inference.swift +++ b/test/decl/protocol/req/associated_type_inference.swift @@ -588,3 +588,22 @@ extension SR_12707_P2 { struct SR_12707_Conform_P2: SR_12707_P2 { typealias A = Never } + +// SR-13172: Inference when witness is an enum case +protocol SR_13172_P1 { + associatedtype Bar + static func bar(_ value: Bar) -> Self +} + +enum SR_13172_E1: SR_13172_P1 { + case bar(String) // Okay +} + +protocol SR_13172_P2 { + associatedtype Bar + static var bar: Bar { get } +} + +enum SR_13172_E2: SR_13172_P2 { + case bar // Okay +} From e547182c9191b7475964b10c9915c07f3f384945 Mon Sep 17 00:00:00 2001 From: Doug Gregor Date: Wed, 8 Jul 2020 07:53:12 -0700 Subject: [PATCH 07/16] Loosen a test to deal with vendor nane --- test/Driver/print_target_info.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/Driver/print_target_info.swift b/test/Driver/print_target_info.swift index 60cfcef332105..5f8bd50b2bae5 100644 --- a/test/Driver/print_target_info.swift +++ b/test/Driver/print_target_info.swift @@ -9,7 +9,7 @@ // RUN: %swift_driver -print-target-info -target x86_64-apple-ios12.0 | %FileCheck -check-prefix CHECK-IOS-SIM %s -// CHECK-IOS: "compilerVersion": "Swift version +// CHECK-IOS: "compilerVersion": "{{.*}}Swift version // CHECK-IOS: "target": { // CHECK-IOS: "triple": "arm64-apple-ios12.0", From c1ce8b6a67b32ba3fa07ccc9d915f1e69717f0dd Mon Sep 17 00:00:00 2001 From: Saleem Abdulrasool Date: Wed, 8 Jul 2020 12:51:24 -0700 Subject: [PATCH 08/16] Update test_util.py Appease the python linter --- utils/incrparse/test_util.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/utils/incrparse/test_util.py b/utils/incrparse/test_util.py index d099d9c6b3d68..35c965276c17e 100755 --- a/utils/incrparse/test_util.py +++ b/utils/incrparse/test_util.py @@ -104,14 +104,19 @@ def parseLine(line, line_no, test_case, incremental_edit_args, reparse_args, # Nothing more to do line = '' - return (pre_edit_line.encode('utf-8'), post_edit_line.encode('utf-8'), current_reparse_start) + return (pre_edit_line.encode('utf-8'), + post_edit_line.encode('utf-8'), + current_reparse_start) def prepareForIncrParse(test_file, test_case, pre_edit_file, post_edit_file, incremental_edit_args, reparse_args): - with io.open(test_file, mode='r', encoding='utf-8', newline='\n') as test_file_handle, \ - io.open(pre_edit_file, mode='w+', encoding='utf-8', newline='\n') as pre_edit_file_handle, \ - io.open(post_edit_file, mode='w+', encoding='utf-8', newline='\n') as post_edit_file_handle: + with io.open(test_file, mode='r', encoding='utf-8', + newline='\n') as test_file_handle, \ + io.open(pre_edit_file, mode='w+', encoding='utf-8', + newline='\n') as pre_edit_file_handle, \ + io.open(post_edit_file, mode='w+', encoding='utf-8', + newline='\n') as post_edit_file_handle: current_reparse_start = None From 6c89aa052689444e27e725d077e58da0e4edce16 Mon Sep 17 00:00:00 2001 From: Mishal Shah Date: Wed, 8 Jul 2020 14:12:51 -0700 Subject: [PATCH 09/16] [Build System] Support host target prefix in symbols package --- utils/build-script | 1 + 1 file changed, 1 insertion(+) diff --git a/utils/build-script b/utils/build-script index 043b0df37bc0e..b5f140ed1921f 100755 --- a/utils/build-script +++ b/utils/build-script @@ -1334,6 +1334,7 @@ def main_normal(): if platform.system() == 'Darwin': prefix = targets.darwin_toolchain_prefix(args.install_prefix) + prefix = os.path.join(args.host_target, prefix.lstrip('/')) else: prefix = args.install_prefix From a778f517d1b6e3789b80c9f59f07913eee06308e Mon Sep 17 00:00:00 2001 From: Rintaro Ishizaki Date: Wed, 8 Jul 2020 10:57:24 -0700 Subject: [PATCH 10/16] [CodeCompletion] Remove redundant entries from possible callee analysis This used to cause duplicated results in call signature completions. i.e.: AlertViewController(#^HERE^# // 2 x (coder: NSCoder) rdar://problem/65081358 --- lib/IDE/ExprContextAnalysis.cpp | 49 ++++++++----- test/IDE/complete_expr_after_paren.swift | 91 ++++++++++++++++++++++++ 2 files changed, 124 insertions(+), 16 deletions(-) create mode 100644 test/IDE/complete_expr_after_paren.swift diff --git a/lib/IDE/ExprContextAnalysis.cpp b/lib/IDE/ExprContextAnalysis.cpp index 602f5f32b19c3..14bb5e1b5ef56 100644 --- a/lib/IDE/ExprContextAnalysis.cpp +++ b/lib/IDE/ExprContextAnalysis.cpp @@ -388,6 +388,7 @@ static void collectPossibleCalleesByQualifiedLookup( decls)) return; + llvm::DenseMap, size_t> known; auto *baseNominal = baseInstanceTy->getAnyNominal(); for (auto *VD : decls) { if ((!isa(VD) && !isa(VD)) || @@ -420,27 +421,43 @@ static void collectPossibleCalleesByQualifiedLookup( DC.getParentModule(), VD, VD->getInnermostDeclContext()->getGenericEnvironmentOfContext()); auto fnType = declaredMemberType.subst(subs); - if (!fnType) + if (!fnType || !fnType->is()) continue; - if (fnType->is()) { - // If we are calling to typealias type, - if (isa(baseInstanceTy.getPointer())) { - auto canBaseTy = baseInstanceTy->getCanonicalType(); - fnType = fnType.transform([&](Type t) -> Type { - if (t->getCanonicalType()->isEqual(canBaseTy)) - return baseInstanceTy; - return t; - }); - } - auto semanticContext = SemanticContextKind::CurrentNominal; - if (baseNominal && - VD->getDeclContext()->getSelfNominalTypeDecl() != baseNominal) - semanticContext = SemanticContextKind::Super; + // If we are calling on a type alias type, replace the canonicalized type + // in the function type with the type alias. + if (isa(baseInstanceTy.getPointer())) { + auto canBaseTy = baseInstanceTy->getCanonicalType(); + fnType = fnType.transform([&](Type t) -> Type { + if (t->getCanonicalType()->isEqual(canBaseTy)) + return baseInstanceTy; + return t; + }); + } - candidates.emplace_back(fnType->castTo(), VD, + auto semanticContext = SemanticContextKind::CurrentNominal; + if (baseNominal && + VD->getDeclContext()->getSelfNominalTypeDecl() != baseNominal) + semanticContext = SemanticContextKind::Super; + + FunctionTypeAndDecl entry(fnType->castTo(), VD, semanticContext); + // Remember the index of the entry. + auto knownResult = known.insert( + {{VD->isStatic(), fnType->getCanonicalType()}, candidates.size()}); + if (knownResult.second) { + candidates.push_back(entry); + continue; } + + auto idx = knownResult.first->second; + if (AvailableAttr::isUnavailable(candidates[idx].Decl) && + !AvailableAttr::isUnavailable(VD)) { + // Replace the previously found "unavailable" with the "available" one. + candidates[idx] = entry; + } + + // Otherwise, skip redundant results. } } diff --git a/test/IDE/complete_expr_after_paren.swift b/test/IDE/complete_expr_after_paren.swift new file mode 100644 index 0000000000000..e927dcbfb418e --- /dev/null +++ b/test/IDE/complete_expr_after_paren.swift @@ -0,0 +1,91 @@ +// RUN: %swift-ide-test -code-completion -source-filename %s -code-completion-token=INITIALIZER | %FileCheck %s --check-prefix=INITIALIZER +// RUN: %swift-ide-test -code-completion -source-filename %s -code-completion-token=METHOD | %FileCheck %s --check-prefix=METHOD +// RUN: %swift-ide-test -code-completion -source-filename %s -code-completion-token=AVAILABILITY | %FileCheck %s --check-prefix=AVAILABILITY +// RUN: %swift-ide-test -code-completion -source-filename %s -code-completion-token=STATIC | %FileCheck %s --check-prefix=STATIC + +protocol MyProtocol { + init(init1: Int) + init(init2: Int) + + func method(method1: Int) + func method(method2: Int) +} + +extension MyProtocol { + init(init2: Int) { self.init(init1: init2) } + init(init3: Int) { self.init(init1: init3) } + + func method(method2: Int) {} + func method(method3: Int) {} +} + +class Base { + init(init4: Int) { } + func method(method4: Int) {} +} + +class MyClass: Base, MyProtocol { + + required init(init1: Int) { super.init(init4: init1) } + required init(init2: Int) { super.init(init4: init1) } + init(init3: Int) { super.init(init4: init1) } + override init(init4: Int) { super.init(init4: init1) } + + func method(method1: Int) + func method(method2: Int) {} + func method(method3: Int) {} + override func method(method4: Int) {} +} + +func testConstructer() { + MyClass(#^INITIALIZER^#) +// INITIALIZER: Begin completions, 4 items +// INITIALIZER-DAG: Decl[Constructor]/CurrNominal: ['(']{#init1: Int#}[')'][#MyClass#]; +// INITIALIZER-DAG: Decl[Constructor]/CurrNominal: ['(']{#init2: Int#}[')'][#MyClass#]; +// INITIALIZER-DAG: Decl[Constructor]/CurrNominal: ['(']{#init3: Int#}[')'][#MyClass#]; +// INITIALIZER-DAG: Decl[Constructor]/CurrNominal: ['(']{#init4: Int#}[')'][#MyClass#]; +// INITIALIZER: End completions +} + +func testMethod(obj: MyClass) { + obj.method(#^METHOD^#) +// METHOD: Begin completions, 4 items +// METHOD-DAG: Decl[InstanceMethod]/CurrNominal: ['(']{#method1: Int#}[')'][#Void#]; +// METHOD-DAG: Decl[InstanceMethod]/CurrNominal: ['(']{#method2: Int#}[')'][#Void#]; +// METHOD-DAG: Decl[InstanceMethod]/CurrNominal: ['(']{#method3: Int#}[')'][#Void#]; +// METHOD-DAG: Decl[InstanceMethod]/CurrNominal: ['(']{#method4: Int#}[')'][#Void#]; +// METHOD: End completions +} + +protocol HasUnavailable {} +extension HasUnavailable { + func method(method1: Int) {} + + @available(*, unavailable) + func method(method2: Int) {} +} +struct MyStruct: HasUnavailable { + @available(*, unavailable) + func method(method1: Int) {} + + func method(method2: Int) {} +} +func testUnavailable(val: MyStruct) { + val.method(#^AVAILABILITY^#) +// AVAILABILITY: Begin completions, 2 items +// AVAILABILITY-DAG: Decl[InstanceMethod]/CurrNominal: ['(']{#method2: Int#}[')'][#Void#]; +// AVAILABILITY-DAG: Decl[InstanceMethod]/Super: ['(']{#method1: Int#}[')'][#Void#]; +// AVAILABILITY: End completions +} + +struct TestStatic { + static func method(_ self: TestStatic) -> () -> Void {} + func method() -> Void {} +} +func testStaticFunc() { + TestStatic.method(#^STATIC^#) +// STATIC: Begin completions +// STATIC-DAG: Decl[StaticMethod]/CurrNominal: ['(']{#(self): TestStatic#}[')'][#() -> Void#]; +// STATIC-DAG: Decl[InstanceMethod]/CurrNominal: ['(']{#(self): TestStatic#}[')'][#() -> Void#]; +// STATIC: End completions +} From f8d8091c98276857d44095d51ef94dee29f18a21 Mon Sep 17 00:00:00 2001 From: Meghana Gupta Date: Wed, 8 Jul 2020 14:54:59 -0700 Subject: [PATCH 11/16] [ownership] Move ome after GlobalOpt (#32742) --- lib/SILOptimizer/PassManager/PassPipeline.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/lib/SILOptimizer/PassManager/PassPipeline.cpp b/lib/SILOptimizer/PassManager/PassPipeline.cpp index d2285a1620244..3f586c5cc2734 100644 --- a/lib/SILOptimizer/PassManager/PassPipeline.cpp +++ b/lib/SILOptimizer/PassManager/PassPipeline.cpp @@ -453,14 +453,14 @@ static void addPerfEarlyModulePassPipeline(SILPassPipelinePlan &P) { // is linked in from the stdlib. P.addTempRValueOpt(); - // We earlier eliminated ownership if we are not compiling the stdlib. Now - // handle the stdlib functions. - P.addNonTransparentFunctionOwnershipModelEliminator(); - // Needed to serialize static initializers of globals for cross-module // optimization. P.addGlobalOpt(); + // We earlier eliminated ownership if we are not compiling the stdlib. Now + // handle the stdlib functions. + P.addNonTransparentFunctionOwnershipModelEliminator(); + // Add the outliner pass (Osize). P.addOutliner(); From caca83e9e1b48cf7f643b2bfe02211158bbe761d Mon Sep 17 00:00:00 2001 From: Nate Chandler Date: Wed, 8 Jul 2020 18:01:46 -0700 Subject: [PATCH 12/16] [Test] Un-XFAIL ParseableInterface/verify_all_overlays.py. The test is passing now--un-XFAIL it. rdar://problem/50648519 --- validation-test/ParseableInterface/verify_all_overlays.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/validation-test/ParseableInterface/verify_all_overlays.py b/validation-test/ParseableInterface/verify_all_overlays.py index f17a690c9b46e..631ddd0666a92 100755 --- a/validation-test/ParseableInterface/verify_all_overlays.py +++ b/validation-test/ParseableInterface/verify_all_overlays.py @@ -14,9 +14,6 @@ # REQUIRES: nonexecutable_test -# rdar://problem/50648519 -# XFAIL: asan - # Expected failures by platform # ----------------------------- # macosx: XCTest From 068581cfb17fb2e5ae9e28fdc8fbb6de3590f841 Mon Sep 17 00:00:00 2001 From: Michael Gottesman Date: Wed, 8 Jul 2020 19:11:18 -0700 Subject: [PATCH 13/16] [gardening] Extract out a lambda out of a loop and reduce some indentationwithin it by inverting if statements. NFCI. --- lib/SILOptimizer/Utils/Generics.cpp | 81 +++++++++++++++-------------- 1 file changed, 43 insertions(+), 38 deletions(-) diff --git a/lib/SILOptimizer/Utils/Generics.cpp b/lib/SILOptimizer/Utils/Generics.cpp index 1cfe889d04301..2bde614efd66c 100644 --- a/lib/SILOptimizer/Utils/Generics.cpp +++ b/lib/SILOptimizer/Utils/Generics.cpp @@ -1922,48 +1922,53 @@ prepareCallArguments(ApplySite AI, SILBuilder &Builder, SILLocation Loc = AI.getLoc(); auto substConv = AI.getSubstCalleeConv(); unsigned ArgIdx = AI.getCalleeArgIndexOfFirstAppliedArg(); - for (auto &Op : AI.getArgumentOperands()) { - auto handleConversion = [&]() { - // Rewriting SIL arguments is only for lowered addresses. - if (!substConv.useLoweredAddresses()) - return false; - if (ArgIdx < substConv.getSILArgIndexOfFirstParam()) { - // Handle result arguments. - unsigned formalIdx = - substConv.getIndirectFormalResultIndexForSILArg(ArgIdx); - if (ReInfo.isFormalResultConverted(formalIdx)) { - // The result is converted from indirect to direct. We need to insert - // a store later. - assert(!StoreResultTo); - StoreResultTo = Op.get(); - return true; - } - } else { - // Handle arguments for formal parameters. - unsigned paramIdx = ArgIdx - substConv.getSILArgIndexOfFirstParam(); - if (ReInfo.isParamConverted(paramIdx)) { - // An argument is converted from indirect to direct. Instead of the - // address we pass the loaded value. - auto argConv = substConv.getSILArgumentConvention(ArgIdx); - SILValue Val; - if (!argConv.isGuaranteedConvention() || isa(AI)) { - Val = Builder.emitLoadValueOperation(Loc, Op.get(), - LoadOwnershipQualifier::Take); - } else { - Val = Builder.emitLoadBorrowOperation(Loc, Op.get()); - if (Val.getOwnershipKind() == ValueOwnershipKind::Guaranteed) - ArgAtIndexNeedsEndBorrow.push_back(Arguments.size()); - } - Arguments.push_back(Val); - return true; - } + auto handleConversion = [&](SILValue InputValue) { + // Rewriting SIL arguments is only for lowered addresses. + if (!substConv.useLoweredAddresses()) + return false; + + if (ArgIdx < substConv.getSILArgIndexOfFirstParam()) { + // Handle result arguments. + unsigned formalIdx = + substConv.getIndirectFormalResultIndexForSILArg(ArgIdx); + if (!ReInfo.isFormalResultConverted(formalIdx)) { + return false; } + + // The result is converted from indirect to direct. We need to insert + // a store later. + assert(!StoreResultTo); + StoreResultTo = InputValue; + return true; + } + + // Handle arguments for formal parameters. + unsigned paramIdx = ArgIdx - substConv.getSILArgIndexOfFirstParam(); + if (!ReInfo.isParamConverted(paramIdx)) { return false; - }; - if (!handleConversion()) - Arguments.push_back(Op.get()); + } + + // An argument is converted from indirect to direct. Instead of the + // address we pass the loaded value. + auto argConv = substConv.getSILArgumentConvention(ArgIdx); + SILValue Val; + if (!argConv.isGuaranteedConvention() || isa(AI)) { + Val = Builder.emitLoadValueOperation(Loc, InputValue, + LoadOwnershipQualifier::Take); + } else { + Val = Builder.emitLoadBorrowOperation(Loc, InputValue); + if (Val.getOwnershipKind() == ValueOwnershipKind::Guaranteed) + ArgAtIndexNeedsEndBorrow.push_back(Arguments.size()); + } + + Arguments.push_back(Val); + return true; + }; + for (auto &Op : AI.getArgumentOperands()) { + if (!handleConversion(Op.get())) + Arguments.push_back(Op.get()); ++ArgIdx; } } From ec5da7427b3c346ca67e8d6145704bfb59d9abf5 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Wed, 8 Jul 2020 19:26:36 -0700 Subject: [PATCH 14/16] [AutoDiff] NFC: garden differentiation transform. (#32770) Rename "emitters" to "cloners", for consistency: - `JVPEmitter` -> `JVPCloner` - `VJPEmitter` -> `VJPCloner` - `PullbackEmitter` -> `PullbackCloner` Improve `PullbackCloner` documentation. - Document previously undocumented methods. - Update outdated documentation. - For adjoint value accumulation helpers: rename "buffer access" occurrences to "address". Pullback generation no logner uses buffer accesses (`begin_apply`). --- .../{JVPEmitter.h => JVPCloner.h} | 18 +- .../{PullbackEmitter.h => PullbackCloner.h} | 233 +++++----- .../{VJPEmitter.h => VJPCloner.h} | 22 +- .../Differentiation/CMakeLists.txt | 6 +- .../{JVPEmitter.cpp => JVPCloner.cpp} | 115 +++-- ...PullbackEmitter.cpp => PullbackCloner.cpp} | 425 +++++++++--------- .../{VJPEmitter.cpp => VJPCloner.cpp} | 78 ++-- .../Mandatory/Differentiation.cpp | 12 +- 8 files changed, 454 insertions(+), 455 deletions(-) rename include/swift/SILOptimizer/Differentiation/{JVPEmitter.h => JVPCloner.h} (96%) rename include/swift/SILOptimizer/Differentiation/{PullbackEmitter.h => PullbackCloner.h} (80%) rename include/swift/SILOptimizer/Differentiation/{VJPEmitter.h => VJPCloner.h} (89%) rename lib/SILOptimizer/Differentiation/{JVPEmitter.cpp => JVPCloner.cpp} (94%) rename lib/SILOptimizer/Differentiation/{PullbackEmitter.cpp => PullbackCloner.cpp} (89%) rename lib/SILOptimizer/Differentiation/{VJPEmitter.cpp => VJPCloner.cpp} (93%) diff --git a/include/swift/SILOptimizer/Differentiation/JVPEmitter.h b/include/swift/SILOptimizer/Differentiation/JVPCloner.h similarity index 96% rename from include/swift/SILOptimizer/Differentiation/JVPEmitter.h rename to include/swift/SILOptimizer/Differentiation/JVPCloner.h index 935f4dca145de..3e0ba11e5beb6 100644 --- a/include/swift/SILOptimizer/Differentiation/JVPEmitter.h +++ b/include/swift/SILOptimizer/Differentiation/JVPCloner.h @@ -1,4 +1,4 @@ -//===--- JVPEmitter.h - JVP Generation in Differentiation -----*- C++ -*---===// +//===--- JVPCloner.h - JVP function generation ----------------*- C++ -*---===// // // This source file is part of the Swift.org open source project // @@ -15,8 +15,8 @@ // //===----------------------------------------------------------------------===// -#ifndef SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_JVPEMITTER_H -#define SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_JVPEMITTER_H +#ifndef SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_JVPCLONER_H +#define SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_JVPCLONER_H #include "swift/SILOptimizer/Differentiation/AdjointValue.h" #include "swift/SILOptimizer/Differentiation/DifferentiationInvoker.h" @@ -40,8 +40,8 @@ namespace autodiff { class ADContext; -class JVPEmitter final - : public TypeSubstCloner { +class JVPCloner final + : public TypeSubstCloner { private: /// The global context. ADContext &context; @@ -368,9 +368,9 @@ class JVPEmitter final void prepareForDifferentialGeneration(); public: - explicit JVPEmitter(ADContext &context, SILFunction *original, - SILDifferentiabilityWitness *witness, SILFunction *jvp, - DifferentiationInvoker invoker); + explicit JVPCloner(ADContext &context, SILFunction *original, + SILDifferentiabilityWitness *witness, SILFunction *jvp, + DifferentiationInvoker invoker); static SILFunction * createEmptyDifferential(ADContext &context, @@ -411,4 +411,4 @@ class JVPEmitter final } // end namespace autodiff } // end namespace swift -#endif // SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_VJPEMITTER_H +#endif // SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_JVPCLONER_H diff --git a/include/swift/SILOptimizer/Differentiation/PullbackEmitter.h b/include/swift/SILOptimizer/Differentiation/PullbackCloner.h similarity index 80% rename from include/swift/SILOptimizer/Differentiation/PullbackEmitter.h rename to include/swift/SILOptimizer/Differentiation/PullbackCloner.h index ccdfb99768de1..514557c0f03e9 100644 --- a/include/swift/SILOptimizer/Differentiation/PullbackEmitter.h +++ b/include/swift/SILOptimizer/Differentiation/PullbackCloner.h @@ -1,4 +1,4 @@ -//===--- PullbackEmitter.h - Pullback in differentiation ------*- C++ -*---===// +//===--- PullbackCloner.h - Pullback function generation -----*- C++ -*----===// // // This source file is part of the Swift.org open source project // @@ -15,8 +15,8 @@ // //===----------------------------------------------------------------------===// -#ifndef SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_PULLBACKEMITTER_H -#define SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_PULLBACKEMITTER_H +#ifndef SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_PULLBACKCLONER_H +#define SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_PULLBACKCLONER_H #include "swift/SILOptimizer/Analysis/DifferentiableActivityAnalysis.h" #include "swift/SILOptimizer/Differentiation/AdjointValue.h" @@ -36,12 +36,12 @@ class SILInstruction; namespace autodiff { class ADContext; -class VJPEmitter; +class VJPCloner; -class PullbackEmitter final : public SILInstructionVisitor { +class PullbackCloner final : public SILInstructionVisitor { private: - /// The parent VJP emitter. - VJPEmitter &vjpEmitter; + /// The parent VJP cloner. + VJPCloner &vjpCloner; /// Dominance info for the original function. DominanceInfo *domInfo = nullptr; @@ -60,7 +60,7 @@ class PullbackEmitter final : public SILInstructionVisitor { /// adjoint values. llvm::DenseMap, AdjointValue> valueMap; - /// Mapping from original basic blocks and original buffers to corresponding + /// Mapping from original basic blocks and original values to corresponding /// adjoint buffers. llvm::DenseMap, SILValue> bufferMap; @@ -124,7 +124,7 @@ class PullbackEmitter final : public SILInstructionVisitor { const DifferentiableActivityInfo &getActivityInfo() const; public: - explicit PullbackEmitter(VJPEmitter &vjpEmitter); + explicit PullbackCloner(VJPCloner &vjpCloner); private: //--------------------------------------------------------------------------// @@ -134,18 +134,32 @@ class PullbackEmitter final : public SILInstructionVisitor { void initializePullbackStructElements(SILBasicBlock *origBB, SILInstructionResultArray values); + /// Returns the pullback struct element value corresponding to the given + /// original block and pullback struct field. SILValue getPullbackStructElement(SILBasicBlock *origBB, VarDecl *field); //--------------------------------------------------------------------------// - // Adjoint value factory methods + // Type transformer //--------------------------------------------------------------------------// - AdjointValue makeZeroAdjointValue(SILType type); + /// Get the type lowering for the given AST type. + const Lowering::TypeLowering &getTypeLowering(Type type); - AdjointValue makeConcreteAdjointValue(SILValue value); + /// Remap any archetypes into the current function's context. + SILType remapType(SILType ty); - template - AdjointValue makeAggregateAdjointValue(SILType type, EltRange elements); + Optional getTangentSpace(CanType type); + + /// Returns the tangent value category of the given value. + SILValueCategory getTangentValueCategory(SILValue v); + + /// Assuming the given type conforms to `Differentiable` after remapping, + /// returns the associated tangent space type. + SILType getRemappedTangentType(SILType type); + + /// Substitutes all replacement types of the given substitution map using the + /// pullback function's substitution map. + SubstitutionMap remapSubstitutionMap(SubstitutionMap substMap); //--------------------------------------------------------------------------// // Temporary value management @@ -158,97 +172,62 @@ class PullbackEmitter final : public SILInstructionVisitor { void cleanUpTemporariesForBlock(SILBasicBlock *bb, SILLocation loc); //--------------------------------------------------------------------------// - // Symbolic value materializers + // Adjoint value factory methods //--------------------------------------------------------------------------// - /// Materialize an adjoint value. The type of the given adjoint value must be - /// loadable. - SILValue materializeAdjointDirect(AdjointValue val, SILLocation loc); - - /// Materialize an adjoint value indirectly to a SIL buffer. - void materializeAdjointIndirect(AdjointValue val, SILValue destBuffer, - SILLocation loc); - - //--------------------------------------------------------------------------// - // Helpers for symbolic value materializers - //--------------------------------------------------------------------------// + AdjointValue makeZeroAdjointValue(SILType type); - /// Emit a zero value by calling `AdditiveArithmetic.zero`. The given type - /// must conform to `AdditiveArithmetic`. - void emitZeroIndirect(CanType type, SILValue bufferAccess, SILLocation loc); + AdjointValue makeConcreteAdjointValue(SILValue value); - /// Emit a zero value by calling `AdditiveArithmetic.zero`. The given type - /// must conform to `AdditiveArithmetic` and be loadable in SIL. - SILValue emitZeroDirect(CanType type, SILLocation loc); + template + AdjointValue makeAggregateAdjointValue(SILType type, EltRange elements); //--------------------------------------------------------------------------// - // Accumulator + // Adjoint value materialization //--------------------------------------------------------------------------// - /// Materialize an adjoint value in the most efficient way. - SILValue materializeAdjoint(AdjointValue val, SILLocation loc); - - /// Given two adjoint values, accumulate them. - AdjointValue accumulateAdjointsDirect(AdjointValue lhs, AdjointValue rhs, - SILLocation loc); - - /// Given two materialized adjoint values, accumulate them. These two - /// adjoints must be objects of loadable type. - SILValue accumulateDirect(SILValue lhs, SILValue rhs, SILLocation loc); - - /// Given two materialized adjoint values, accumulate them using - /// `AdditiveArithmetic.+`, depending on the differentiation mode. - void accumulateIndirect(SILValue resultBufAccess, SILValue lhsBufAccess, - SILValue rhsBufAccess, SILLocation loc); + /// Materializes an adjoint value. The type of the given adjoint value must be + /// loadable. + SILValue materializeAdjointDirect(AdjointValue val, SILLocation loc); - /// Given two buffers of an `AdditiveArithmetic` type, accumulate the right - /// hand side into the left hand side using `+=`. - void accumulateIndirect(SILValue lhsDestAccess, SILValue rhsAccess, - SILLocation loc); + /// Materializes an adjoint value indirectly to a SIL buffer. + void materializeAdjointIndirect(AdjointValue val, SILValue destBuffer, + SILLocation loc); //--------------------------------------------------------------------------// - // Type transformer + // Helpers for adjoint value materialization //--------------------------------------------------------------------------// - /// Get the type lowering for the given AST type. - const Lowering::TypeLowering &getTypeLowering(Type type); + /// Emits a zero value into the given address by calling + /// `AdditiveArithmetic.zero`. The given type must conform to + /// `AdditiveArithmetic`. + void emitZeroIndirect(CanType type, SILValue address, SILLocation loc); - /// Remap any archetypes into the current function's context. - SILType remapType(SILType ty); - - Optional getTangentSpace(CanType type); - - /// Returns the tangent value category of the given value. - SILValueCategory getTangentValueCategory(SILValue v); - - /// Assuming the given type conforms to `Differentiable` after remapping, - /// returns the associated tangent space type. - SILType getRemappedTangentType(SILType type); - - /// Substitutes all replacement types of the given substitution map using the - /// pullback function's substitution map. - SubstitutionMap remapSubstitutionMap(SubstitutionMap substMap); + /// Emits a zero value by calling `AdditiveArithmetic.zero`. The given type + /// must conform to `AdditiveArithmetic` and be loadable in SIL. + SILValue emitZeroDirect(CanType type, SILLocation loc); //--------------------------------------------------------------------------// - // Managed value mapping + // Adjoint value mapping //--------------------------------------------------------------------------// - /// Returns true if the original value has a corresponding adjoint value. + /// Returns true if the given value in the original function has a + /// corresponding adjoint value. bool hasAdjointValue(SILBasicBlock *origBB, SILValue originalValue) const; - /// Initializes an original value's corresponding adjoint value. It must not - /// have an adjoint value before this function is called. + /// Initializes the adjoint value for the original value. Asserts that the + /// original value does not already have an adjoint value. void setAdjointValue(SILBasicBlock *origBB, SILValue originalValue, AdjointValue adjointValue); - /// Get the adjoint for an original value. The given value must be in the - /// original function. + /// Returns the adjoint value for a value in the original function. /// - /// This method first tries to find an entry in `adjointMap`. If an adjoint - /// doesn't exist, create a zero adjoint. + /// This method first tries to find an existing entry in the adjoint value + /// mapping. If no entry exists, creates a zero adjoint value. AdjointValue getAdjointValue(SILBasicBlock *origBB, SILValue originalValue); - /// Add an adjoint value for the given original value. + /// Adds `newAdjointValue` to the adjoint value for `originalValue` and sets + /// the sum as the new adjoint value. void addAdjointValue(SILBasicBlock *origBB, SILValue originalValue, AdjointValue newAdjointValue, SILLocation loc); @@ -258,30 +237,58 @@ class PullbackEmitter final : public SILInstructionVisitor { SILValue activeValue); //--------------------------------------------------------------------------// - // Buffer mapping + // Adjoint value accumulation //--------------------------------------------------------------------------// - void setAdjointBuffer(SILBasicBlock *origBB, SILValue originalBuffer, - SILValue adjointBuffer); + /// Given two adjoint values, accumulates them and returns their sum. + AdjointValue accumulateAdjointsDirect(AdjointValue lhs, AdjointValue rhs, + SILLocation loc); - SILValue getAdjointProjection(SILBasicBlock *origBB, - SILValue originalProjection); + /// Generates code returning `result = lhs + rhs`. + /// + /// Given two materialized adjoint values, accumulates them and returns their + /// sum. The adjoint values must have a loadable type. + SILValue accumulateDirect(SILValue lhs, SILValue rhs, SILLocation loc); - SILValue &getAdjointBuffer(SILBasicBlock *origBB, SILValue originalBuffer); + /// Generates code for `resultAddress = lhsAddress + rhsAddress`. + /// + /// Given two addresses with the same `AdditiveArithmetic`-conforming type, + /// accumulates them into a result address using `AdditiveArithmetic.+`. + void accumulateIndirect(SILValue resultAddress, SILValue lhsAddress, + SILValue rhsAddress, SILLocation loc); - SILBasicBlock::iterator getNextFunctionLocalAllocationInsertionPoint(); + /// Generates code for `lhsDestAddress += rhsAddress`. + /// + /// Given two addresses with the same `AdditiveArithmetic`-conforming type, + /// accumulates the rhs into the lhs using `AdditiveArithmetic.+=`. + void accumulateIndirect(SILValue lhsDestAddress, SILValue rhsAddress, + SILLocation loc); - /// Creates and returns a local allocation with the given type. + //--------------------------------------------------------------------------// + // Adjoint buffer mapping + //--------------------------------------------------------------------------// + + /// If the given original value is an address projection, returns a + /// corresponding adjoint projection to be used as its adjoint buffer. /// - /// Local allocations are created uninitialized in the pullback entry and - /// deallocated in the pullback exit. All local allocations not in - /// `destroyedLocalAllocations` are also destroyed in the pullback exit. - AllocStackInst *createFunctionLocalAllocation(SILType type, SILLocation loc); + /// Helper function for `getAdjointBuffer`. + SILValue getAdjointProjection(SILBasicBlock *origBB, SILValue originalValue); + + /// Returns the adjoint buffer for the original value. + /// + /// This method first tries to find an existing entry in the adjoint buffer + /// mapping. If no entry exists, creates a zero adjoint buffer. + SILValue &getAdjointBuffer(SILBasicBlock *origBB, SILValue originalValue); - /// Accumulates `rhsBufferAccess` into the adjoint buffer corresponding to - /// `originalBuffer`. - void addToAdjointBuffer(SILBasicBlock *origBB, SILValue originalBuffer, - SILValue rhsBufferAccess, SILLocation loc); + /// Initializes the adjoint buffer for the original value. Asserts that the + /// original value does not already have an adjoint buffer. + void setAdjointBuffer(SILBasicBlock *origBB, SILValue originalValue, + SILValue adjointBuffer); + + /// Accumulates `rhsAddress` into the adjoint buffer corresponding to the + /// original value. + void addToAdjointBuffer(SILBasicBlock *origBB, SILValue originalValue, + SILValue rhsAddress, SILLocation loc); /// Given the adjoint value of an array initialized from an /// `array.uninitialized_intrinsic` application and an array element index, @@ -291,12 +298,28 @@ class PullbackEmitter final : public SILInstructionVisitor { int eltIndex, SILLocation loc); /// Given the adjoint value of an array initialized from an - /// `array.uninitialized_intrinsic` application, accumulate the adjoint + /// `array.uninitialized_intrinsic` application, accumulates the adjoint /// value's elements into the adjoint buffers of its element addresses. void accumulateArrayLiteralElementAddressAdjoints( SILBasicBlock *origBB, SILValue originalValue, AdjointValue arrayAdjointValue, SILLocation loc); + /// Returns a next insertion point for creating a local allocation: either + /// before the previous local allocation, or at the start of the pullback + /// entry if no local allocations exist. + /// + /// Helper for `createFunctionLocalAllocation`. + SILBasicBlock::iterator getNextFunctionLocalAllocationInsertionPoint(); + + /// Creates and returns a local allocation with the given type. + /// + /// Local allocations are created uninitialized in the pullback entry and + /// deallocated in the pullback exit. All local allocations not in + /// `destroyedLocalAllocations` are also destroyed in the pullback exit. + /// + /// Helper for `getAdjointBuffer`. + AllocStackInst *createFunctionLocalAllocation(SILType type, SILLocation loc); + //--------------------------------------------------------------------------// // CFG mapping //--------------------------------------------------------------------------// @@ -346,19 +369,21 @@ class PullbackEmitter final : public SILInstructionVisitor { using TrampolineBlockSet = SmallPtrSet; - /// Determine the pullback successor block for a given original block and one - /// of its predecessors. When a trampoline block is necessary, emit code into + /// Determines the pullback successor block for a given original block and one + /// of its predecessors. When a trampoline block is necessary, emits code into /// the trampoline block to trampoline the original block's active value's - /// adjoint values. A dense map `trampolineArgs` will be populated to keep - /// track of which pullback successor blocks each active value's adjoint value - /// is used, so that we can release those values in pullback successor blocks - /// that are not using them. + /// adjoint values. + /// + /// Populates `pullbackTrampolineBlockMap`, which maps active values' adjoint + /// values to the pullback successor blocks in which they are used. This + /// allows us to release those values in pullback successor blocks that do not + /// use them. SILBasicBlock * buildPullbackSuccessor(SILBasicBlock *origBB, SILBasicBlock *origPredBB, llvm::SmallDenseMap &pullbackTrampolineBlockMap); - /// Emit pullback code in the corresponding pullback block. + /// Emits pullback code in the corresponding pullback block. void visitSILBasicBlock(SILBasicBlock *bb); void visit(SILInstruction *inst); @@ -513,4 +538,4 @@ class PullbackEmitter final : public SILInstructionVisitor { } // end namespace autodiff } // end namespace swift -#endif // SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_PULLBACKEMITTER_H +#endif // SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_PULLBACKCLONER_H diff --git a/include/swift/SILOptimizer/Differentiation/VJPEmitter.h b/include/swift/SILOptimizer/Differentiation/VJPCloner.h similarity index 89% rename from include/swift/SILOptimizer/Differentiation/VJPEmitter.h rename to include/swift/SILOptimizer/Differentiation/VJPCloner.h index db196e7374b89..346bf2f490c80 100644 --- a/include/swift/SILOptimizer/Differentiation/VJPEmitter.h +++ b/include/swift/SILOptimizer/Differentiation/VJPCloner.h @@ -1,4 +1,4 @@ -//===--- VJPEmitter.h - VJP Generation in differentiation -----*- C++ -*---===// +//===--- VJPCloner.h - VJP function generation ----------------*- C++ -*---===// // // This source file is part of the Swift.org open source project // @@ -15,8 +15,8 @@ // //===----------------------------------------------------------------------===// -#ifndef SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_VJPEMITTER_H -#define SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_VJPEMITTER_H +#ifndef SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_VJPCLONER_H +#define SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_VJPCLONER_H #include "swift/SILOptimizer/Analysis/DifferentiableActivityAnalysis.h" #include "swift/SILOptimizer/Differentiation/DifferentiationInvoker.h" @@ -35,11 +35,11 @@ class SILInstruction; namespace autodiff { class ADContext; -class PullbackEmitter; +class PullbackCloner; -class VJPEmitter final - : public TypeSubstCloner { - friend class PullbackEmitter; +class VJPCloner final + : public TypeSubstCloner { + friend class PullbackCloner; private: /// The global context. @@ -90,9 +90,9 @@ class VJPEmitter final SILAutoDiffIndices indices, SILFunction *vjp); public: - explicit VJPEmitter(ADContext &context, SILFunction *original, - SILDifferentiabilityWitness *witness, SILFunction *vjp, - DifferentiationInvoker invoker); + explicit VJPCloner(ADContext &context, SILFunction *original, + SILDifferentiabilityWitness *witness, SILFunction *vjp, + DifferentiationInvoker invoker); SILFunction *createEmptyPullback(); @@ -170,4 +170,4 @@ class VJPEmitter final } // end namespace autodiff } // end namespace swift -#endif // SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_VJPEMITTER_H +#endif // SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_VJPCLONER_H diff --git a/lib/SILOptimizer/Differentiation/CMakeLists.txt b/lib/SILOptimizer/Differentiation/CMakeLists.txt index ce2d9571b95c6..1dc732cd95997 100644 --- a/lib/SILOptimizer/Differentiation/CMakeLists.txt +++ b/lib/SILOptimizer/Differentiation/CMakeLists.txt @@ -2,8 +2,8 @@ target_sources(swiftSILOptimizer PRIVATE ADContext.cpp Common.cpp DifferentiationInvoker.cpp - JVPEmitter.cpp + JVPCloner.cpp LinearMapInfo.cpp - PullbackEmitter.cpp + PullbackCloner.cpp Thunk.cpp - VJPEmitter.cpp) + VJPCloner.cpp) diff --git a/lib/SILOptimizer/Differentiation/JVPEmitter.cpp b/lib/SILOptimizer/Differentiation/JVPCloner.cpp similarity index 94% rename from lib/SILOptimizer/Differentiation/JVPEmitter.cpp rename to lib/SILOptimizer/Differentiation/JVPCloner.cpp index 0855d2d57a988..0b33bbf439b02 100644 --- a/lib/SILOptimizer/Differentiation/JVPEmitter.cpp +++ b/lib/SILOptimizer/Differentiation/JVPCloner.cpp @@ -1,4 +1,4 @@ -//===--- JVPEmitter.cpp - JVP generation in differentiation ---*- C++ -*---===// +//===--- JVPCloner.cpp - JVP function generation --------------*- C++ -*---===// // // This source file is part of the Swift.org open source project // @@ -17,7 +17,7 @@ #define DEBUG_TYPE "differentiation" -#include "swift/SILOptimizer/Differentiation/JVPEmitter.h" +#include "swift/SILOptimizer/Differentiation/JVPCloner.h" #include "swift/SILOptimizer/Differentiation/ADContext.h" #include "swift/SILOptimizer/Differentiation/Thunk.h" @@ -32,8 +32,8 @@ namespace autodiff { //--------------------------------------------------------------------------// /*static*/ -SubstitutionMap JVPEmitter::getSubstitutionMap(SILFunction *original, - SILFunction *jvp) { +SubstitutionMap JVPCloner::getSubstitutionMap(SILFunction *original, + SILFunction *jvp) { auto substMap = original->getForwardingSubstitutionMap(); if (auto *jvpGenEnv = jvp->getGenericEnvironment()) { auto jvpSubstMap = jvpGenEnv->getForwardingSubstitutionMap(); @@ -46,8 +46,8 @@ SubstitutionMap JVPEmitter::getSubstitutionMap(SILFunction *original, /*static*/ const DifferentiableActivityInfo & -JVPEmitter::getActivityInfo(ADContext &context, SILFunction *original, - SILAutoDiffIndices indices, SILFunction *jvp) { +JVPCloner::getActivityInfo(ADContext &context, SILFunction *original, + SILAutoDiffIndices indices, SILFunction *jvp) { // Get activity info of the original function. auto &passManager = context.getPassManager(); auto *activityAnalysis = @@ -60,9 +60,9 @@ JVPEmitter::getActivityInfo(ADContext &context, SILFunction *original, return activityInfo; } -JVPEmitter::JVPEmitter(ADContext &context, SILFunction *original, - SILDifferentiabilityWitness *witness, SILFunction *jvp, - DifferentiationInvoker invoker) +JVPCloner::JVPCloner(ADContext &context, SILFunction *original, + SILDifferentiabilityWitness *witness, SILFunction *jvp, + DifferentiationInvoker invoker) : TypeSubstCloner(*jvp, *original, getSubstitutionMap(original, jvp)), context(context), original(original), witness(witness), jvp(jvp), invoker(invoker), @@ -81,7 +81,7 @@ JVPEmitter::JVPEmitter(ADContext &context, SILFunction *original, // Differential struct mapping //--------------------------------------------------------------------------// -void JVPEmitter::initializeDifferentialStructElements( +void JVPCloner::initializeDifferentialStructElements( SILBasicBlock *origBB, SILInstructionResultArray values) { auto *diffStructDecl = differentialInfo.getLinearMapStruct(origBB); assert(diffStructDecl->getStoredProperties().size() == values.size() && @@ -99,8 +99,8 @@ void JVPEmitter::initializeDifferentialStructElements( } } -SILValue JVPEmitter::getDifferentialStructElement(SILBasicBlock *origBB, - VarDecl *field) { +SILValue JVPCloner::getDifferentialStructElement(SILBasicBlock *origBB, + VarDecl *field) { assert(differentialInfo.getLinearMapStruct(origBB) == cast(field->getDeclContext())); assert(differentialStructElements.count(field) && @@ -113,7 +113,7 @@ SILValue JVPEmitter::getDifferentialStructElement(SILBasicBlock *origBB, //--------------------------------------------------------------------------// SILBasicBlock::iterator -JVPEmitter::getNextDifferentialLocalAllocationInsertionPoint() { +JVPCloner::getNextDifferentialLocalAllocationInsertionPoint() { // If there are no local allocations, insert at the beginning of the tangent // entry. if (differentialLocalAllocations.empty()) @@ -126,20 +126,20 @@ JVPEmitter::getNextDifferentialLocalAllocationInsertionPoint() { return it; } -SILType JVPEmitter::getLoweredType(Type type) { +SILType JVPCloner::getLoweredType(Type type) { auto jvpGenSig = jvp->getLoweredFunctionType()->getSubstGenericSignature(); Lowering::AbstractionPattern pattern(jvpGenSig, type->getCanonicalType(jvpGenSig)); return jvp->getLoweredType(pattern, type); } -SILType JVPEmitter::getNominalDeclLoweredType(NominalTypeDecl *nominal) { +SILType JVPCloner::getNominalDeclLoweredType(NominalTypeDecl *nominal) { auto nominalType = getOpASTType(nominal->getDeclaredInterfaceType()->getCanonicalType()); return getLoweredType(nominalType); } -StructInst *JVPEmitter::buildDifferentialValueStructValue(TermInst *termInst) { +StructInst *JVPCloner::buildDifferentialValueStructValue(TermInst *termInst) { assert(termInst->getFunction() == original); auto loc = termInst->getFunction()->getLocation(); auto *origBB = termInst->getParent(); @@ -160,11 +160,11 @@ StructInst *JVPEmitter::buildDifferentialValueStructValue(TermInst *termInst) { // Tangent value factory methods //--------------------------------------------------------------------------// -AdjointValue JVPEmitter::makeZeroTangentValue(SILType type) { +AdjointValue JVPCloner::makeZeroTangentValue(SILType type) { return AdjointValue::createZero(allocator, remapSILTypeInDifferential(type)); } -AdjointValue JVPEmitter::makeConcreteTangentValue(SILValue value) { +AdjointValue JVPCloner::makeConcreteTangentValue(SILValue value) { return AdjointValue::createConcrete(allocator, value); } @@ -172,8 +172,8 @@ AdjointValue JVPEmitter::makeConcreteTangentValue(SILValue value) { // Tangent materialization //--------------------------------------------------------------------------// -void JVPEmitter::emitZeroIndirect(CanType type, SILValue bufferAccess, - SILLocation loc) { +void JVPCloner::emitZeroIndirect(CanType type, SILValue bufferAccess, + SILLocation loc) { auto builder = getDifferentialBuilder(); auto tangentSpace = getTangentSpace(type); assert(tangentSpace && "No tangent space for this type"); @@ -194,7 +194,7 @@ void JVPEmitter::emitZeroIndirect(CanType type, SILValue bufferAccess, } } -SILValue JVPEmitter::emitZeroDirect(CanType type, SILLocation loc) { +SILValue JVPCloner::emitZeroDirect(CanType type, SILLocation loc) { auto diffBuilder = getDifferentialBuilder(); auto silType = getModule().Types.getLoweredLoadableType( type, TypeExpansionContext::minimal(), getModule()); @@ -206,8 +206,8 @@ SILValue JVPEmitter::emitZeroDirect(CanType type, SILLocation loc) { return loaded; } -SILValue JVPEmitter::materializeTangentDirect(AdjointValue val, - SILLocation loc) { +SILValue JVPCloner::materializeTangentDirect(AdjointValue val, + SILLocation loc) { assert(val.getType().isObject()); LLVM_DEBUG(getADDebugStream() << "Materializing tangents for " << val << '\n'); @@ -225,7 +225,7 @@ SILValue JVPEmitter::materializeTangentDirect(AdjointValue val, llvm_unreachable("invalid value kind"); } -SILValue JVPEmitter::materializeTangent(AdjointValue val, SILLocation loc) { +SILValue JVPCloner::materializeTangent(AdjointValue val, SILLocation loc) { if (val.isConcrete()) { LLVM_DEBUG(getADDebugStream() << "Materializing tangent: Value is concrete.\n"); @@ -240,9 +240,8 @@ SILValue JVPEmitter::materializeTangent(AdjointValue val, SILLocation loc) { // Tangent buffer mapping //--------------------------------------------------------------------------// -void JVPEmitter::setTangentBuffer(SILBasicBlock *origBB, - SILValue originalBuffer, - SILValue tangentBuffer) { +void JVPCloner::setTangentBuffer(SILBasicBlock *origBB, SILValue originalBuffer, + SILValue tangentBuffer) { assert(originalBuffer->getType().isAddress()); auto insertion = bufferMap.try_emplace({origBB, originalBuffer}, tangentBuffer); @@ -250,8 +249,8 @@ void JVPEmitter::setTangentBuffer(SILBasicBlock *origBB, (void)insertion; } -SILValue &JVPEmitter::getTangentBuffer(SILBasicBlock *origBB, - SILValue originalBuffer) { +SILValue &JVPCloner::getTangentBuffer(SILBasicBlock *origBB, + SILValue originalBuffer) { assert(originalBuffer->getType().isAddress()); assert(originalBuffer->getFunction() == original); auto insertion = bufferMap.try_emplace({origBB, originalBuffer}, SILValue()); @@ -266,23 +265,23 @@ SILValue &JVPEmitter::getTangentBuffer(SILBasicBlock *origBB, /// Substitutes all replacement types of the given substitution map using the /// tangent function's substitution map. SubstitutionMap -JVPEmitter::remapSubstitutionMapInDifferential(SubstitutionMap substMap) { +JVPCloner::remapSubstitutionMapInDifferential(SubstitutionMap substMap) { return substMap.subst(getDifferential().getForwardingSubstitutionMap()); } -Type JVPEmitter::remapTypeInDifferential(Type ty) { +Type JVPCloner::remapTypeInDifferential(Type ty) { if (ty->hasArchetype()) return getDifferential().mapTypeIntoContext(ty->mapTypeOutOfContext()); return getDifferential().mapTypeIntoContext(ty); } -SILType JVPEmitter::remapSILTypeInDifferential(SILType ty) { +SILType JVPCloner::remapSILTypeInDifferential(SILType ty) { if (ty.hasArchetype()) return getDifferential().mapTypeIntoContext(ty.mapTypeOutOfContext()); return getDifferential().mapTypeIntoContext(ty); } -Optional JVPEmitter::getTangentSpace(CanType type) { +Optional JVPCloner::getTangentSpace(CanType type) { // Use witness generic signature to remap types. if (auto witnessGenSig = witness->getDerivativeGenericSignature()) type = witnessGenSig->getCanonicalTypeInContext(type); @@ -290,7 +289,7 @@ Optional JVPEmitter::getTangentSpace(CanType type) { LookUpConformanceInModule(getModule().getSwiftModule())); } -SILType JVPEmitter::getRemappedTangentType(SILType type) { +SILType JVPCloner::getRemappedTangentType(SILType type) { return SILType::getPrimitiveType( getTangentSpace(remapSILTypeInDifferential(type).getASTType()) ->getCanonicalType(), @@ -301,7 +300,7 @@ SILType JVPEmitter::getRemappedTangentType(SILType type) { // Tangent value mapping //--------------------------------------------------------------------------// -AdjointValue JVPEmitter::getTangentValue(SILValue originalValue) { +AdjointValue JVPCloner::getTangentValue(SILValue originalValue) { assert(originalValue->getType().isObject()); assert(originalValue->getFunction() == original); auto insertion = tangentValueMap.try_emplace( @@ -310,8 +309,8 @@ AdjointValue JVPEmitter::getTangentValue(SILValue originalValue) { return insertion.first->getSecond(); } -void JVPEmitter::setTangentValue(SILBasicBlock *origBB, SILValue originalValue, - AdjointValue newTangentValue) { +void JVPCloner::setTangentValue(SILBasicBlock *origBB, SILValue originalValue, + AdjointValue newTangentValue) { #ifndef NDEBUG if (auto *defInst = originalValue->getDefiningInstruction()) { bool isTupleTypedApplyResult = @@ -339,12 +338,12 @@ void JVPEmitter::setTangentValue(SILBasicBlock *origBB, SILValue originalValue, //--------------------------------------------------------------------------// #define CLONE_AND_EMIT_TANGENT(INST, ID) \ - void JVPEmitter::visit##INST##Inst(INST##Inst *inst) { \ + void JVPCloner::visit##INST##Inst(INST##Inst *inst) { \ TypeSubstCloner::visit##INST##Inst(inst); \ if (differentialInfo.shouldDifferentiateInstruction(inst)) \ emitTangentFor##INST##Inst(inst); \ } \ - void JVPEmitter::emitTangentFor##INST##Inst(INST##Inst *(ID)) + void JVPCloner::emitTangentFor##INST##Inst(INST##Inst *(ID)) CLONE_AND_EMIT_TANGENT(BeginBorrow, bbi) { auto &diffBuilder = getDifferentialBuilder(); @@ -711,7 +710,7 @@ CLONE_AND_EMIT_TANGENT(DestructureTuple, dti) { /// Handle `apply` instruction. /// Original: y = apply f(x0, x1, ...) /// Tangent: tan[y] = apply diff_f(tan[x0], tan[x1], ...) -void JVPEmitter::emitTangentForApplyInst( +void JVPCloner::emitTangentForApplyInst( ApplyInst *ai, SILAutoDiffIndices applyIndices, CanSILFunctionType originalDifferentialType) { assert(differentialInfo.shouldDifferentiateApplySite(ai)); @@ -837,7 +836,7 @@ void JVPEmitter::emitTangentForApplyInst( } /// Generate a `return` instruction in the current differential basic block. -void JVPEmitter::emitReturnInstForDifferential() { +void JVPCloner::emitReturnInstForDifferential() { auto &differential = getDifferential(); auto diffLoc = differential.getLocation(); auto &diffBuilder = getDifferentialBuilder(); @@ -860,7 +859,7 @@ void JVPEmitter::emitReturnInstForDifferential() { joinElements(retElts, diffBuilder, diffLoc)); } -void JVPEmitter::prepareForDifferentialGeneration() { +void JVPCloner::prepareForDifferentialGeneration() { // Create differential blocks and arguments. auto &differential = getDifferential(); auto *origEntry = original->getEntryBlock(); @@ -959,9 +958,9 @@ void JVPEmitter::prepareForDifferentialGeneration() { } /*static*/ SILFunction * -JVPEmitter::createEmptyDifferential(ADContext &context, - SILDifferentiabilityWitness *witness, - LinearMapInfo *linearMapInfo) { +JVPCloner::createEmptyDifferential(ADContext &context, + SILDifferentiabilityWitness *witness, + LinearMapInfo *linearMapInfo) { auto &module = context.getModule(); auto *original = witness->getOriginalFunction(); auto *jvp = witness->getJVP(); @@ -1069,7 +1068,7 @@ JVPEmitter::createEmptyDifferential(ADContext &context, } /// Run JVP generation. Returns true on error. -bool JVPEmitter::run() { +bool JVPCloner::run() { PrettyStackTraceSILFunction trace("generating JVP and differential for", original); LLVM_DEBUG(getADDebugStream() << "Cloning original @" << original->getName() @@ -1095,23 +1094,23 @@ bool JVPEmitter::run() { return errorOccurred; } -void JVPEmitter::postProcess(SILInstruction *orig, SILInstruction *cloned) { +void JVPCloner::postProcess(SILInstruction *orig, SILInstruction *cloned) { if (errorOccurred) return; SILClonerWithScopes::postProcess(orig, cloned); } /// Remap original basic blocks. -SILBasicBlock *JVPEmitter::remapBasicBlock(SILBasicBlock *bb) { +SILBasicBlock *JVPCloner::remapBasicBlock(SILBasicBlock *bb) { auto *jvpBB = BBMap[bb]; return jvpBB; } -void JVPEmitter::visit(SILInstruction *inst) { +void JVPCloner::visit(SILInstruction *inst) { if (errorOccurred) return; if (differentialInfo.shouldDifferentiateInstruction(inst)) { - LLVM_DEBUG(getADDebugStream() << "JVPEmitter visited:\n[ORIG]" << *inst); + LLVM_DEBUG(getADDebugStream() << "JVPCloner visited:\n[ORIG]" << *inst); #ifndef NDEBUG auto diffBuilder = getDifferentialBuilder(); auto beforeInsertion = std::prev(diffBuilder.getInsertionPoint()); @@ -1128,13 +1127,13 @@ void JVPEmitter::visit(SILInstruction *inst) { } } -void JVPEmitter::visitSILInstruction(SILInstruction *inst) { +void JVPCloner::visitSILInstruction(SILInstruction *inst) { context.emitNondifferentiabilityError( inst, invoker, diag::autodiff_expression_not_differentiable_note); errorOccurred = true; } -void JVPEmitter::visitInstructionsInBlock(SILBasicBlock *bb) { +void JVPCloner::visitInstructionsInBlock(SILBasicBlock *bb) { // Destructure the differential struct to get the elements. auto &diffBuilder = getDifferentialBuilder(); auto diffLoc = getDifferential().getLocation(); @@ -1149,7 +1148,7 @@ void JVPEmitter::visitInstructionsInBlock(SILBasicBlock *bb) { // If an `apply` has active results or active inout parameters, replace it // with an `apply` of its JVP. -void JVPEmitter::visitApplyInst(ApplyInst *ai) { +void JVPCloner::visitApplyInst(ApplyInst *ai) { // If the function should not be differentiated or its the array literal // initialization intrinsic, just do standard cloning. if (!differentialInfo.shouldDifferentiateApplySite(ai) || @@ -1375,7 +1374,7 @@ void JVPEmitter::visitApplyInst(ApplyInst *ai) { emitTangentForApplyInst(ai, indices, originalDifferentialType); } -void JVPEmitter::visitReturnInst(ReturnInst *ri) { +void JVPCloner::visitReturnInst(ReturnInst *ri) { auto loc = ri->getOperand().getLoc(); auto *origExit = ri->getParent(); auto &builder = getBuilder(); @@ -1432,19 +1431,19 @@ void JVPEmitter::visitReturnInst(ReturnInst *ri) { builder.createReturn(ri->getLoc(), joinElements(directResults, builder, loc)); } -void JVPEmitter::visitBranchInst(BranchInst *bi) { +void JVPCloner::visitBranchInst(BranchInst *bi) { llvm_unreachable("Unsupported SIL instruction."); } -void JVPEmitter::visitCondBranchInst(CondBranchInst *cbi) { +void JVPCloner::visitCondBranchInst(CondBranchInst *cbi) { llvm_unreachable("Unsupported SIL instruction."); } -void JVPEmitter::visitSwitchEnumInst(SwitchEnumInst *sei) { +void JVPCloner::visitSwitchEnumInst(SwitchEnumInst *sei) { llvm_unreachable("Unsupported SIL instruction."); } -void JVPEmitter::visitDifferentiableFunctionInst( +void JVPCloner::visitDifferentiableFunctionInst( DifferentiableFunctionInst *dfi) { // Clone `differentiable_function` from original to JVP, then add the cloned // instruction to the `differentiable_function` worklist. diff --git a/lib/SILOptimizer/Differentiation/PullbackEmitter.cpp b/lib/SILOptimizer/Differentiation/PullbackCloner.cpp similarity index 89% rename from lib/SILOptimizer/Differentiation/PullbackEmitter.cpp rename to lib/SILOptimizer/Differentiation/PullbackCloner.cpp index da0da45b261c9..81e999aabc514 100644 --- a/lib/SILOptimizer/Differentiation/PullbackEmitter.cpp +++ b/lib/SILOptimizer/Differentiation/PullbackCloner.cpp @@ -1,4 +1,4 @@ -//===--- PullbackEmitter.cpp - Pullback in differentiation ----*- C++ -*---===// +//===--- PullbackCloner.cpp - Pullback function generation ---*- C++ -*----===// // // This source file is part of the Swift.org open source project // @@ -17,10 +17,10 @@ #define DEBUG_TYPE "differentiation" -#include "swift/SILOptimizer/Differentiation/PullbackEmitter.h" +#include "swift/SILOptimizer/Differentiation/PullbackCloner.h" #include "swift/SILOptimizer/Differentiation/ADContext.h" #include "swift/SILOptimizer/Differentiation/Thunk.h" -#include "swift/SILOptimizer/Differentiation/VJPEmitter.h" +#include "swift/SILOptimizer/Differentiation/VJPCloner.h" #include "swift/AST/Expr.h" #include "swift/AST/PropertyWrappers.h" @@ -40,64 +40,60 @@ class SILInstruction; namespace autodiff { class ADContext; -class VJPEmitter; +class VJPCloner; -PullbackEmitter::PullbackEmitter(VJPEmitter &vjpEmitter) - : vjpEmitter(vjpEmitter), builder(getPullback()), +PullbackCloner::PullbackCloner(VJPCloner &vjpCloner) + : vjpCloner(vjpCloner), builder(getPullback()), localAllocBuilder(getPullback()) { // Get dominance and post-order info for the original function. auto &passManager = getContext().getPassManager(); auto *domAnalysis = passManager.getAnalysis(); auto *postDomAnalysis = passManager.getAnalysis(); auto *postOrderAnalysis = passManager.getAnalysis(); - domInfo = domAnalysis->get(vjpEmitter.original); - postDomInfo = postDomAnalysis->get(vjpEmitter.original); - postOrderInfo = postOrderAnalysis->get(vjpEmitter.original); + domInfo = domAnalysis->get(vjpCloner.original); + postDomInfo = postDomAnalysis->get(vjpCloner.original); + postOrderInfo = postOrderAnalysis->get(vjpCloner.original); } -ADContext &PullbackEmitter::getContext() const { return vjpEmitter.context; } +ADContext &PullbackCloner::getContext() const { return vjpCloner.context; } -SILModule &PullbackEmitter::getModule() const { +SILModule &PullbackCloner::getModule() const { return getContext().getModule(); } -ASTContext &PullbackEmitter::getASTContext() const { +ASTContext &PullbackCloner::getASTContext() const { return getPullback().getASTContext(); } -SILFunction &PullbackEmitter::getOriginal() const { - return *vjpEmitter.original; -} +SILFunction &PullbackCloner::getOriginal() const { return *vjpCloner.original; } -SILFunction &PullbackEmitter::getPullback() const { - return *vjpEmitter.pullback; -} +SILFunction &PullbackCloner::getPullback() const { return *vjpCloner.pullback; } -SILDifferentiabilityWitness *PullbackEmitter::getWitness() const { - return vjpEmitter.witness; +SILDifferentiabilityWitness *PullbackCloner::getWitness() const { + return vjpCloner.witness; } -DifferentiationInvoker PullbackEmitter::getInvoker() const { - return vjpEmitter.invoker; +DifferentiationInvoker PullbackCloner::getInvoker() const { + return vjpCloner.invoker; } -LinearMapInfo &PullbackEmitter::getPullbackInfo() { - return vjpEmitter.pullbackInfo; +LinearMapInfo &PullbackCloner::getPullbackInfo() { + return vjpCloner.pullbackInfo; } -const SILAutoDiffIndices PullbackEmitter::getIndices() const { - return vjpEmitter.getIndices(); +const SILAutoDiffIndices PullbackCloner::getIndices() const { + return vjpCloner.getIndices(); } -const DifferentiableActivityInfo &PullbackEmitter::getActivityInfo() const { - return vjpEmitter.activityInfo; +const DifferentiableActivityInfo &PullbackCloner::getActivityInfo() const { + return vjpCloner.activityInfo; } //--------------------------------------------------------------------------// // Pullback struct mapping //--------------------------------------------------------------------------// -void PullbackEmitter::initializePullbackStructElements( +void PullbackCloner::initializePullbackStructElements( SILBasicBlock *origBB, SILInstructionResultArray values) { auto *pbStructDecl = getPullbackInfo().getLinearMapStruct(origBB); assert(pbStructDecl->getStoredProperties().size() == values.size() && @@ -114,8 +110,8 @@ void PullbackEmitter::initializePullbackStructElements( } } -SILValue PullbackEmitter::getPullbackStructElement(SILBasicBlock *origBB, - VarDecl *field) { +SILValue PullbackCloner::getPullbackStructElement(SILBasicBlock *origBB, + VarDecl *field) { assert(getPullbackInfo().getLinearMapStruct(origBB) == cast(field->getDeclContext())); assert(pullbackStructElements.count(field) && @@ -127,7 +123,7 @@ SILValue PullbackEmitter::getPullbackStructElement(SILBasicBlock *origBB, // Temporary value management //--------------------------------------------------------------------------// -SILValue PullbackEmitter::recordTemporary(SILValue value) { +SILValue PullbackCloner::recordTemporary(SILValue value) { assert(value->getType().isObject()); assert(value->getFunction() == &getPullback()); auto inserted = blockTemporaries[value->getParentBlock()].insert(value); @@ -137,8 +133,8 @@ SILValue PullbackEmitter::recordTemporary(SILValue value) { return value; } -void PullbackEmitter::cleanUpTemporariesForBlock(SILBasicBlock *bb, - SILLocation loc) { +void PullbackCloner::cleanUpTemporariesForBlock(SILBasicBlock *bb, + SILLocation loc) { assert(bb->getParent() == &getPullback()); LLVM_DEBUG(getADDebugStream() << "Cleaning up temporaries for pullback bb" << bb->getDebugID() << '\n'); @@ -151,7 +147,7 @@ void PullbackEmitter::cleanUpTemporariesForBlock(SILBasicBlock *bb, // Type transformer //--------------------------------------------------------------------------// -const Lowering::TypeLowering &PullbackEmitter::getTypeLowering(Type type) { +const Lowering::TypeLowering &PullbackCloner::getTypeLowering(Type type) { auto pbGenSig = getPullback().getLoweredFunctionType()->getSubstGenericSignature(); Lowering::AbstractionPattern pattern(pbGenSig, @@ -160,7 +156,7 @@ const Lowering::TypeLowering &PullbackEmitter::getTypeLowering(Type type) { } /// Remap any archetypes into the current function's context. -SILType PullbackEmitter::remapType(SILType ty) { +SILType PullbackCloner::remapType(SILType ty) { if (ty.hasArchetype()) ty = ty.mapTypeOutOfContext(); auto remappedType = ty.getASTType()->getCanonicalType( @@ -170,7 +166,7 @@ SILType PullbackEmitter::remapType(SILType ty) { return getPullback().mapTypeIntoContext(remappedSILType); } -Optional PullbackEmitter::getTangentSpace(CanType type) { +Optional PullbackCloner::getTangentSpace(CanType type) { // Use witness generic signature to remap types. if (auto witnessGenSig = getWitness()->getDerivativeGenericSignature()) type = witnessGenSig->getCanonicalTypeInContext(type); @@ -178,7 +174,7 @@ Optional PullbackEmitter::getTangentSpace(CanType type) { LookUpConformanceInModule(getModule().getSwiftModule())); } -SILValueCategory PullbackEmitter::getTangentValueCategory(SILValue v) { +SILValueCategory PullbackCloner::getTangentValueCategory(SILValue v) { // Tangent value category table: // // Let $L be a loadable type and $*A be an address-only type. @@ -210,31 +206,30 @@ SILValueCategory PullbackEmitter::getTangentValueCategory(SILValue v) { return SILValueCategory::Address; } -SILType PullbackEmitter::getRemappedTangentType(SILType type) { +SILType PullbackCloner::getRemappedTangentType(SILType type) { return SILType::getPrimitiveType( getTangentSpace(remapType(type).getASTType())->getCanonicalType(), type.getCategory()); } -SubstitutionMap -PullbackEmitter::remapSubstitutionMap(SubstitutionMap substMap) { +SubstitutionMap PullbackCloner::remapSubstitutionMap(SubstitutionMap substMap) { return substMap.subst(getPullback().getForwardingSubstitutionMap()); } //--------------------------------------------------------------------------// -// Managed value mapping +// Adjoint value mapping //--------------------------------------------------------------------------// -bool PullbackEmitter::hasAdjointValue(SILBasicBlock *origBB, - SILValue originalValue) const { +bool PullbackCloner::hasAdjointValue(SILBasicBlock *origBB, + SILValue originalValue) const { assert(origBB->getParent() == &getOriginal()); assert(originalValue->getType().isObject()); return valueMap.count({origBB, originalValue}); } -void PullbackEmitter::setAdjointValue(SILBasicBlock *origBB, - SILValue originalValue, - AdjointValue adjointValue) { +void PullbackCloner::setAdjointValue(SILBasicBlock *origBB, + SILValue originalValue, + AdjointValue adjointValue) { LLVM_DEBUG(getADDebugStream() << "Setting adjoint value for " << originalValue); assert(origBB->getParent() == &getOriginal()); @@ -253,8 +248,8 @@ void PullbackEmitter::setAdjointValue(SILBasicBlock *origBB, insertion.first->getSecond() = adjointValue; } -AdjointValue PullbackEmitter::getAdjointValue(SILBasicBlock *origBB, - SILValue originalValue) { +AdjointValue PullbackCloner::getAdjointValue(SILBasicBlock *origBB, + SILValue originalValue) { assert(origBB->getParent() == &getOriginal()); assert(originalValue->getType().isObject()); assert(getTangentValueCategory(originalValue) == SILValueCategory::Object); @@ -266,10 +261,10 @@ AdjointValue PullbackEmitter::getAdjointValue(SILBasicBlock *origBB, return it->getSecond(); } -void PullbackEmitter::addAdjointValue(SILBasicBlock *origBB, - SILValue originalValue, - AdjointValue newAdjointValue, - SILLocation loc) { +void PullbackCloner::addAdjointValue(SILBasicBlock *origBB, + SILValue originalValue, + AdjointValue newAdjointValue, + SILLocation loc) { assert(origBB->getParent() == &getOriginal()); assert(originalValue->getType().isObject()); assert(newAdjointValue.getType().isObject()); @@ -298,7 +293,7 @@ void PullbackEmitter::addAdjointValue(SILBasicBlock *origBB, setAdjointValue(origBB, originalValue, adjVal); } -void PullbackEmitter::accumulateArrayLiteralElementAddressAdjoints( +void PullbackCloner::accumulateArrayLiteralElementAddressAdjoints( SILBasicBlock *origBB, SILValue originalValue, AdjointValue arrayAdjointValue, SILLocation loc) { // Return if the original value is not the `Array` result of an @@ -339,8 +334,8 @@ void PullbackEmitter::accumulateArrayLiteralElementAddressAdjoints( } SILArgument * -PullbackEmitter::getActiveValuePullbackBlockArgument(SILBasicBlock *origBB, - SILValue activeValue) { +PullbackCloner::getActiveValuePullbackBlockArgument(SILBasicBlock *origBB, + SILValue activeValue) { assert(getTangentValueCategory(activeValue) == SILValueCategory::Object); assert(origBB->getParent() == &getOriginal()); auto pullbackBBArg = activeValuePullbackBBArgumentMap[{origBB, activeValue}]; @@ -350,21 +345,11 @@ PullbackEmitter::getActiveValuePullbackBlockArgument(SILBasicBlock *origBB, } //--------------------------------------------------------------------------// -// Buffer mapping +// Adjoint buffer mapping //--------------------------------------------------------------------------// -void PullbackEmitter::setAdjointBuffer(SILBasicBlock *origBB, - SILValue originalValue, - SILValue adjointBuffer) { - assert(getTangentValueCategory(originalValue) == SILValueCategory::Address); - auto insertion = - bufferMap.try_emplace({origBB, originalValue}, adjointBuffer); - assert(insertion.second); - (void)insertion; -} - -SILValue PullbackEmitter::getAdjointProjection(SILBasicBlock *origBB, - SILValue originalProjection) { +SILValue PullbackCloner::getAdjointProjection(SILBasicBlock *origBB, + SILValue originalProjection) { // Handle `struct_element_addr`. // Adjoint projection: a `struct_element_addr` into the base adjoint buffer. if (auto *seai = dyn_cast(originalProjection)) { @@ -492,8 +477,58 @@ SILValue PullbackEmitter::getAdjointProjection(SILBasicBlock *origBB, return SILValue(); } +SILValue &PullbackCloner::getAdjointBuffer(SILBasicBlock *origBB, + SILValue originalValue) { + assert(getTangentValueCategory(originalValue) == SILValueCategory::Address); + assert(originalValue->getFunction() == &getOriginal()); + auto insertion = bufferMap.try_emplace({origBB, originalValue}, SILValue()); + if (!insertion.second) // not inserted + return insertion.first->getSecond(); + + // If the original buffer is a projection, return a corresponding projection + // into the adjoint buffer. + if (auto adjProj = getAdjointProjection(origBB, originalValue)) + return (bufferMap[{origBB, originalValue}] = adjProj); + + auto bufType = getRemappedTangentType(originalValue->getType()); + // Set insertion point for local allocation builder: before the last local + // allocation, or at the start of the pullback function's entry if no local + // allocations exist yet. + auto *newBuf = createFunctionLocalAllocation( + bufType, RegularLocation::getAutoGeneratedLocation()); + // Temporarily change global builder insertion point and emit zero into the + // local allocation. + auto insertionPoint = builder.getInsertionBB(); + builder.setInsertionPoint(localAllocBuilder.getInsertionBB(), + localAllocBuilder.getInsertionPoint()); + emitZeroIndirect(bufType.getASTType(), newBuf, newBuf->getLoc()); + builder.setInsertionPoint(insertionPoint); + return (insertion.first->getSecond() = newBuf); +} + +void PullbackCloner::setAdjointBuffer(SILBasicBlock *origBB, + SILValue originalValue, + SILValue adjointBuffer) { + assert(getTangentValueCategory(originalValue) == SILValueCategory::Address); + auto insertion = + bufferMap.try_emplace({origBB, originalValue}, adjointBuffer); + assert(insertion.second && "Adjoint buffer already exists"); + (void)insertion; +} + +void PullbackCloner::addToAdjointBuffer(SILBasicBlock *origBB, + SILValue originalValue, + SILValue rhsAddress, SILLocation loc) { + assert(getTangentValueCategory(originalValue) == SILValueCategory::Address && + rhsAddress->getType().isAddress()); + assert(originalValue->getFunction() == &getOriginal()); + assert(rhsAddress->getFunction() == &getPullback()); + auto adjointBuffer = getAdjointBuffer(origBB, originalValue); + accumulateIndirect(adjointBuffer, rhsAddress, loc); +} + SILBasicBlock::iterator -PullbackEmitter::getNextFunctionLocalAllocationInsertionPoint() { +PullbackCloner::getNextFunctionLocalAllocationInsertionPoint() { // If there are no local allocations, insert at the pullback entry start. if (functionLocalAllocations.empty()) return getPullback().getEntryBlock()->begin(); @@ -504,8 +539,8 @@ PullbackEmitter::getNextFunctionLocalAllocationInsertionPoint() { return lastLocalAlloc->getDefiningInstruction()->getIterator(); } -AllocStackInst * -PullbackEmitter::createFunctionLocalAllocation(SILType type, SILLocation loc) { +AllocStackInst *PullbackCloner::createFunctionLocalAllocation(SILType type, + SILLocation loc) { // Set insertion point for local allocation builder: before the last local // allocation, or at the start of the pullback function's entry if no local // allocations exist yet. @@ -518,52 +553,11 @@ PullbackEmitter::createFunctionLocalAllocation(SILType type, SILLocation loc) { return alloc; } -SILValue &PullbackEmitter::getAdjointBuffer(SILBasicBlock *origBB, - SILValue originalBuffer) { - assert(getTangentValueCategory(originalBuffer) == SILValueCategory::Address); - assert(originalBuffer->getFunction() == &getOriginal()); - auto insertion = bufferMap.try_emplace({origBB, originalBuffer}, SILValue()); - if (!insertion.second) // not inserted - return insertion.first->getSecond(); - - // If the original buffer is a projection, return a corresponding projection - // into the adjoint buffer. - if (auto adjProj = getAdjointProjection(origBB, originalBuffer)) - return (bufferMap[{origBB, originalBuffer}] = adjProj); - - auto bufObjectType = getRemappedTangentType(originalBuffer->getType()); - // Set insertion point for local allocation builder: before the last local - // allocation, or at the start of the pullback function's entry if no local - // allocations exist yet. - auto *newBuf = createFunctionLocalAllocation( - bufObjectType, RegularLocation::getAutoGeneratedLocation()); - // Temporarily change global builder insertion point and emit zero into the - // local allocation. - auto insertionPoint = builder.getInsertionBB(); - builder.setInsertionPoint(localAllocBuilder.getInsertionBB(), - localAllocBuilder.getInsertionPoint()); - emitZeroIndirect(bufObjectType.getASTType(), newBuf, newBuf->getLoc()); - builder.setInsertionPoint(insertionPoint); - return (insertion.first->getSecond() = newBuf); -} - -void PullbackEmitter::addToAdjointBuffer(SILBasicBlock *origBB, - SILValue originalBuffer, - SILValue rhsBufferAccess, - SILLocation loc) { - assert(getTangentValueCategory(originalBuffer) == SILValueCategory::Address && - rhsBufferAccess->getType().isAddress()); - assert(originalBuffer->getFunction() == &getOriginal()); - assert(rhsBufferAccess->getFunction() == &getPullback()); - auto adjointBuffer = getAdjointBuffer(origBB, originalBuffer); - accumulateIndirect(adjointBuffer, rhsBufferAccess, loc); -} - //--------------------------------------------------------------------------// // Debugging utilities //--------------------------------------------------------------------------// -void PullbackEmitter::printAdjointValueMapping() { +void PullbackCloner::printAdjointValueMapping() { // Group original/adjoint values by basic block. llvm::DenseMap> tmp; for (auto pair : valueMap) { @@ -591,7 +585,7 @@ void PullbackEmitter::printAdjointValueMapping() { } } -void PullbackEmitter::printAdjointBufferMapping() { +void PullbackCloner::printAdjointBufferMapping() { // Group original/adjoint buffers by basic block. llvm::DenseMap> tmp; for (auto pair : bufferMap) { @@ -623,7 +617,7 @@ void PullbackEmitter::printAdjointBufferMapping() { // Member accessor pullback generation //--------------------------------------------------------------------------// -bool PullbackEmitter::runForSemanticMemberAccessor() { +bool PullbackCloner::runForSemanticMemberAccessor() { auto &original = getOriginal(); auto *accessor = cast(original.getDeclContext()->getAsDecl()); switch (accessor->getAccessorKind()) { @@ -638,7 +632,7 @@ bool PullbackEmitter::runForSemanticMemberAccessor() { } } -bool PullbackEmitter::runForSemanticMemberGetter() { +bool PullbackCloner::runForSemanticMemberGetter() { auto &original = getOriginal(); auto &pullback = getPullback(); auto pbLoc = getPullback().getLocation(); @@ -752,7 +746,7 @@ bool PullbackEmitter::runForSemanticMemberGetter() { return false; } -bool PullbackEmitter::runForSemanticMemberSetter() { +bool PullbackCloner::runForSemanticMemberSetter() { auto &original = getOriginal(); auto &pullback = getPullback(); auto pbLoc = getPullback().getLocation(); @@ -813,12 +807,12 @@ bool PullbackEmitter::runForSemanticMemberSetter() { // Entry point //--------------------------------------------------------------------------// -bool PullbackEmitter::run() { +bool PullbackCloner::run() { PrettyStackTraceSILFunction trace("generating pullback for", &getOriginal()); auto &original = getOriginal(); auto &pullback = getPullback(); auto pbLoc = getPullback().getLocation(); - LLVM_DEBUG(getADDebugStream() << "Running PullbackEmitter on\n" << original); + LLVM_DEBUG(getADDebugStream() << "Running PullbackCloner on\n" << original); auto origExitIt = original.findReturnBB(); assert(origExitIt != original.end() && @@ -1171,7 +1165,7 @@ bool PullbackEmitter::run() { return errorOccurred; } -void PullbackEmitter::emitZeroDerivativesForNonvariedResult( +void PullbackCloner::emitZeroDerivativesForNonvariedResult( SILValue origNonvariedResult) { auto &pullback = getPullback(); auto pbLoc = getPullback().getLocation(); @@ -1214,7 +1208,7 @@ void PullbackEmitter::emitZeroDerivativesForNonvariedResult( << pullback); } -SILBasicBlock *PullbackEmitter::buildPullbackSuccessor( +SILBasicBlock *PullbackCloner::buildPullbackSuccessor( SILBasicBlock *origBB, SILBasicBlock *origPredBB, SmallDenseMap &pullbackTrampolineBlockMap) { // Get the pullback block and optional pullback trampoline block of the @@ -1300,7 +1294,7 @@ SILBasicBlock *PullbackEmitter::buildPullbackSuccessor( return pullbackTrampolineBB; } -void PullbackEmitter::visitSILBasicBlock(SILBasicBlock *bb) { +void PullbackCloner::visitSILBasicBlock(SILBasicBlock *bb) { auto pbLoc = getPullback().getLocation(); // Get the corresponding pullback basic block. auto *pbBB = getPullbackBlock(bb); @@ -1329,7 +1323,7 @@ void PullbackEmitter::visitSILBasicBlock(SILBasicBlock *bb) { // Emit a branching terminator for the block. // If the original block is the original entry, then the pullback block is - // the pullback exit. This is handled specially in `PullbackEmitter::run()`, + // the pullback exit. This is handled specially in `PullbackCloner::run()`, // so we leave the block non-terminated. if (bb->isEntry()) return; @@ -1354,7 +1348,7 @@ void PullbackEmitter::visitSILBasicBlock(SILBasicBlock *bb) { // Materialize adjoint value of active basic block argument, create a // copy, and set copy as adjoint value of incoming values. auto bbArgAdj = getAdjointValue(bb, bbArg); - auto concreteBBArgAdj = materializeAdjoint(bbArgAdj, pbLoc); + auto concreteBBArgAdj = materializeAdjointDirect(bbArgAdj, pbLoc); auto concreteBBArgAdjCopy = builder.emitCopyValueOperation(pbLoc, concreteBBArgAdj); for (auto pair : incomingValues) { @@ -1407,11 +1401,11 @@ void PullbackEmitter::visitSILBasicBlock(SILBasicBlock *bb) { pullbackSuccessorCases); } -void PullbackEmitter::visit(SILInstruction *inst) { +void PullbackCloner::visit(SILInstruction *inst) { if (errorOccurred) return; - LLVM_DEBUG(getADDebugStream() << "PullbackEmitter visited:\n[ORIG]" << *inst); + LLVM_DEBUG(getADDebugStream() << "PullbackCloner visited:\n[ORIG]" << *inst); #ifndef NDEBUG auto beforeInsertion = std::prev(builder.getInsertionPoint()); #endif @@ -1424,17 +1418,17 @@ void PullbackEmitter::visit(SILInstruction *inst) { }); } -void PullbackEmitter::visitSILInstruction(SILInstruction *inst) { +void PullbackCloner::visitSILInstruction(SILInstruction *inst) { LLVM_DEBUG(getADDebugStream() - << "Unhandled instruction in PullbackEmitter: " << *inst); + << "Unhandled instruction in PullbackCloner: " << *inst); getContext().emitNondifferentiabilityError( inst, getInvoker(), diag::autodiff_expression_not_differentiable_note); errorOccurred = true; } AllocStackInst * -PullbackEmitter::getArrayAdjointElementBuffer(SILValue arrayAdjoint, - int eltIndex, SILLocation loc) { +PullbackCloner::getArrayAdjointElementBuffer(SILValue arrayAdjoint, + int eltIndex, SILLocation loc) { auto &ctx = builder.getASTContext(); auto arrayTanType = cast(arrayAdjoint->getType().getASTType()); auto arrayType = arrayTanType->getParent()->castTo(); @@ -1458,7 +1452,8 @@ PullbackEmitter::getArrayAdjointElementBuffer(SILValue arrayAdjoint, } } assert(subscriptDecl && "No `Array.TangentVector.subscript`"); - auto *subscriptGetterDecl = subscriptDecl->getOpaqueAccessor(AccessorKind::Get); + auto *subscriptGetterDecl = + subscriptDecl->getOpaqueAccessor(AccessorKind::Get); assert(subscriptGetterDecl && "No `Array.TangentVector.subscript` getter"); SILOptFunctionBuilder fb(getContext().getTransform()); auto *subscriptGetterFn = fb.getOrCreateFunction( @@ -1516,7 +1511,7 @@ PullbackEmitter::getArrayAdjointElementBuffer(SILValue arrayAdjoint, return eltAdjBuffer; } -void PullbackEmitter::visitApplyInst(ApplyInst *ai) { +void PullbackCloner::visitApplyInst(ApplyInst *ai) { assert(getPullbackInfo().shouldDifferentiateApplySite(ai)); // Skip `array.uninitialized_intrinsic` applications, which have special // `store` and `copy_addr` support. @@ -1595,7 +1590,7 @@ void PullbackEmitter::visitApplyInst(ApplyInst *ai) { SILValue seed; switch (getTangentValueCategory(origResult)) { case SILValueCategory::Object: - seed = materializeAdjoint(getAdjointValue(bb, origResult), loc); + seed = materializeAdjointDirect(getAdjointValue(bb, origResult), loc); break; case SILValueCategory::Address: seed = getAdjointBuffer(bb, origResult); @@ -1675,7 +1670,7 @@ void PullbackEmitter::visitApplyInst(ApplyInst *ai) { } } -void PullbackEmitter::visitStructInst(StructInst *si) { +void PullbackCloner::visitStructInst(StructInst *si) { auto *bb = si->getParent(); auto loc = si->getLoc(); auto *structDecl = si->getStructDecl(); @@ -1731,7 +1726,7 @@ void PullbackEmitter::visitStructInst(StructInst *si) { } } -void PullbackEmitter::visitBeginApplyInst(BeginApplyInst *bai) { +void PullbackCloner::visitBeginApplyInst(BeginApplyInst *bai) { // Diagnose `begin_apply` instructions. // Coroutine differentiation is not yet supported. getContext().emitNondifferentiabilityError( @@ -1740,7 +1735,7 @@ void PullbackEmitter::visitBeginApplyInst(BeginApplyInst *bai) { return; } -void PullbackEmitter::visitStructExtractInst(StructExtractInst *sei) { +void PullbackCloner::visitStructExtractInst(StructExtractInst *sei) { auto *bb = sei->getParent(); auto loc = getValidLocation(sei); auto structTy = remapType(sei->getOperand()->getType()).getASTType(); @@ -1782,7 +1777,7 @@ void PullbackEmitter::visitStructExtractInst(StructExtractInst *sei) { } } -void PullbackEmitter::visitRefElementAddrInst(RefElementAddrInst *reai) { +void PullbackCloner::visitRefElementAddrInst(RefElementAddrInst *reai) { auto *bb = reai->getParent(); auto loc = reai->getLoc(); auto adjBuf = getAdjointBuffer(bb, reai); @@ -1829,7 +1824,7 @@ void PullbackEmitter::visitRefElementAddrInst(RefElementAddrInst *reai) { } } -void PullbackEmitter::visitTupleInst(TupleInst *ti) { +void PullbackCloner::visitTupleInst(TupleInst *ti) { auto *bb = ti->getParent(); auto av = getAdjointValue(bb, ti); switch (av.getKind()) { @@ -1879,7 +1874,7 @@ void PullbackEmitter::visitTupleInst(TupleInst *ti) { } } -void PullbackEmitter::visitTupleExtractInst(TupleExtractInst *tei) { +void PullbackCloner::visitTupleExtractInst(TupleExtractInst *tei) { auto *bb = tei->getParent(); auto tupleTanTy = getRemappedTangentType(tei->getOperand()->getType()); auto av = getAdjointValue(bb, tei); @@ -1922,7 +1917,7 @@ void PullbackEmitter::visitTupleExtractInst(TupleExtractInst *tei) { } } -void PullbackEmitter::visitDestructureTupleInst(DestructureTupleInst *dti) { +void PullbackCloner::visitDestructureTupleInst(DestructureTupleInst *dti) { auto *bb = dti->getParent(); auto tupleTanTy = getRemappedTangentType(dti->getOperand()->getType()); SmallVector adjValues; @@ -1947,7 +1942,7 @@ void PullbackEmitter::visitDestructureTupleInst(DestructureTupleInst *dti) { } } -void PullbackEmitter::visitLoadOperation(SingleValueInstruction *inst) { +void PullbackCloner::visitLoadOperation(SingleValueInstruction *inst) { assert(isa(inst) || isa(inst)); auto *bb = inst->getParent(); auto loc = inst->getLoc(); @@ -1974,8 +1969,8 @@ void PullbackEmitter::visitLoadOperation(SingleValueInstruction *inst) { } } -void PullbackEmitter::visitStoreOperation(SILBasicBlock *bb, SILLocation loc, - SILValue origSrc, SILValue origDest) { +void PullbackCloner::visitStoreOperation(SILBasicBlock *bb, SILLocation loc, + SILValue origSrc, SILValue origDest) { auto &adjBuf = getAdjointBuffer(bb, origDest); switch (getTangentValueCategory(origSrc)) { case SILValueCategory::Object: { @@ -1995,12 +1990,12 @@ void PullbackEmitter::visitStoreOperation(SILBasicBlock *bb, SILLocation loc, } } -void PullbackEmitter::visitStoreInst(StoreInst *si) { +void PullbackCloner::visitStoreInst(StoreInst *si) { visitStoreOperation(si->getParent(), si->getLoc(), si->getSrc(), si->getDest()); } -void PullbackEmitter::visitCopyAddrInst(CopyAddrInst *cai) { +void PullbackCloner::visitCopyAddrInst(CopyAddrInst *cai) { auto *bb = cai->getParent(); auto &adjDest = getAdjointBuffer(bb, cai->getDest()); auto destType = remapType(adjDest->getType()); @@ -2009,7 +2004,7 @@ void PullbackEmitter::visitCopyAddrInst(CopyAddrInst *cai) { emitZeroIndirect(destType.getASTType(), adjDest, cai->getLoc()); } -void PullbackEmitter::visitCopyValueInst(CopyValueInst *cvi) { +void PullbackCloner::visitCopyValueInst(CopyValueInst *cvi) { auto *bb = cvi->getParent(); switch (getTangentValueCategory(cvi)) { case SILValueCategory::Object: { @@ -2028,7 +2023,7 @@ void PullbackEmitter::visitCopyValueInst(CopyValueInst *cvi) { } } -void PullbackEmitter::visitBeginBorrowInst(BeginBorrowInst *bbi) { +void PullbackCloner::visitBeginBorrowInst(BeginBorrowInst *bbi) { auto *bb = bbi->getParent(); switch (getTangentValueCategory(bbi)) { case SILValueCategory::Object: { @@ -2047,7 +2042,7 @@ void PullbackEmitter::visitBeginBorrowInst(BeginBorrowInst *bbi) { } } -void PullbackEmitter::visitBeginAccessInst(BeginAccessInst *bai) { +void PullbackCloner::visitBeginAccessInst(BeginAccessInst *bai) { // Check for non-differentiable writes. if (bai->getAccessKind() == SILAccessKind::Modify) { if (isa(bai->getSource())) { @@ -2067,7 +2062,7 @@ void PullbackEmitter::visitBeginAccessInst(BeginAccessInst *bai) { } } -void PullbackEmitter::visitUnconditionalCheckedCastAddrInst( +void PullbackCloner::visitUnconditionalCheckedCastAddrInst( UnconditionalCheckedCastAddrInst *uccai) { auto *bb = uccai->getParent(); auto &adjDest = getAdjointBuffer(bb, uccai->getDest()); @@ -2083,7 +2078,7 @@ void PullbackEmitter::visitUnconditionalCheckedCastAddrInst( emitZeroIndirect(destType.getASTType(), adjDest, uccai->getLoc()); } -void PullbackEmitter::visitUncheckedRefCastInst(UncheckedRefCastInst *urci) { +void PullbackCloner::visitUncheckedRefCastInst(UncheckedRefCastInst *urci) { auto *bb = urci->getParent(); assert(urci->getOperand()->getType().isObject()); assert(getRemappedTangentType(urci->getOperand()->getType()) == @@ -2093,7 +2088,7 @@ void PullbackEmitter::visitUncheckedRefCastInst(UncheckedRefCastInst *urci) { addAdjointValue(bb, urci->getOperand(), adj, urci->getLoc()); } -void PullbackEmitter::visitUpcastInst(UpcastInst *ui) { +void PullbackCloner::visitUpcastInst(UpcastInst *ui) { auto *bb = ui->getParent(); assert(ui->getOperand()->getType().isObject()); assert(getRemappedTangentType(ui->getOperand()->getType()) == @@ -2104,7 +2099,7 @@ void PullbackEmitter::visitUpcastInst(UpcastInst *ui) { } #define NOT_DIFFERENTIABLE(INST, DIAG) \ - void PullbackEmitter::visit##INST##Inst(INST##Inst *inst) { \ + void PullbackCloner::visit##INST##Inst(INST##Inst *inst) { \ getContext().emitNondifferentiabilityError(inst, getInvoker(), \ diag::DIAG); \ errorOccurred = true; \ @@ -2112,17 +2107,17 @@ void PullbackEmitter::visitUpcastInst(UpcastInst *ui) { } #undef NOT_DIFFERENTIABLE -AdjointValue PullbackEmitter::makeZeroAdjointValue(SILType type) { +AdjointValue PullbackCloner::makeZeroAdjointValue(SILType type) { return AdjointValue::createZero(allocator, remapType(type)); } -AdjointValue PullbackEmitter::makeConcreteAdjointValue(SILValue value) { +AdjointValue PullbackCloner::makeConcreteAdjointValue(SILValue value) { return AdjointValue::createConcrete(allocator, value); } template -AdjointValue PullbackEmitter::makeAggregateAdjointValue(SILType type, - EltRange elements) { +AdjointValue PullbackCloner::makeAggregateAdjointValue(SILType type, + EltRange elements) { AdjointValue *buf = reinterpret_cast(allocator.Allocate( elements.size() * sizeof(AdjointValue), alignof(AdjointValue))); MutableArrayRef elementsCopy(buf, elements.size()); @@ -2132,11 +2127,10 @@ AdjointValue PullbackEmitter::makeAggregateAdjointValue(SILType type, elementsCopy); } -SILValue PullbackEmitter::materializeAdjointDirect(AdjointValue val, - SILLocation loc) { +SILValue PullbackCloner::materializeAdjointDirect(AdjointValue val, + SILLocation loc) { assert(val.getType().isObject()); - LLVM_DEBUG(getADDebugStream() - << "Materializing adjoints for " << val << '\n'); + LLVM_DEBUG(getADDebugStream() << "Materializing adjoint for " << val << '\n'); switch (val.getKind()) { case AdjointValueKind::Zero: return recordTemporary(emitZeroDirect(val.getType().getASTType(), loc)); @@ -2158,37 +2152,18 @@ SILValue PullbackEmitter::materializeAdjointDirect(AdjointValue val, llvm_unreachable("invalid value kind"); } -SILValue PullbackEmitter::materializeAdjoint(AdjointValue val, - SILLocation loc) { - if (val.isConcrete()) { - LLVM_DEBUG(getADDebugStream() - << "Materializing adjoint: Value is concrete.\n"); - return val.getConcreteValue(); - } - LLVM_DEBUG(getADDebugStream() << "Materializing adjoint: Value is " - "non-concrete. Materializing directly.\n"); - return materializeAdjointDirect(val, loc); -} - -void PullbackEmitter::materializeAdjointIndirect(AdjointValue val, - SILValue destBufferAccess, - SILLocation loc) { +void PullbackCloner::materializeAdjointIndirect(AdjointValue val, + SILValue destAddress, + SILLocation loc) { + assert(destAddress->getType().isAddress()); switch (val.getKind()) { - /// Given a `%buf : *T, emit instructions that produce a zero or an aggregate - /// of zeros of the expected type. When `T` conforms to - /// `AdditiveArithmetic`, we emit a call to `AdditiveArithmetic.zero`. When - /// `T` is a builtin float, we emit a `float_literal` instruction. - /// Otherwise, we assert that `T` must be an aggregate where each element - /// conforms to `AdditiveArithmetic` or is a builtin float. We expect to emit - /// a zero for each element and use the appropriate aggregate constructor - /// instruction (in this case, `tuple`) to produce a tuple. But currently, - /// since we need indirect passing for aggregate instruction, we just use - /// `tuple_element_addr` to get element buffers and write elements to them. + /// If adjoint value is a symbolic zero, emit a call to + /// `AdditiveArithmetic.zero`. case AdjointValueKind::Zero: - emitZeroIndirect(val.getSwiftType(), destBufferAccess, loc); + emitZeroIndirect(val.getSwiftType(), destAddress, loc); break; - /// Given a `%buf : *(T0, T1, T2, ...)` or `%buf : *Struct` recursively emit - /// instructions to materialize the symbolic tuple or struct, filling the + /// If adjoint value is a symbolic aggregate (tuple or struct), recursively + /// materialize materialize the symbolic tuple or struct, filling the /// buffer. case AdjointValueKind::Aggregate: { if (auto *tupTy = val.getSwiftType()->getAs()) { @@ -2196,7 +2171,7 @@ void PullbackEmitter::materializeAdjointIndirect(AdjointValue val, auto eltTy = SILType::getPrimitiveAddressType( tupTy->getElementType(idx)->getCanonicalType()); auto *eltBuf = - builder.createTupleElementAddr(loc, destBufferAccess, idx, eltTy); + builder.createTupleElementAddr(loc, destAddress, idx, eltTy); materializeAdjointIndirect(val.getAggregateElement(idx), eltBuf, loc); } } else if (auto *structDecl = @@ -2205,7 +2180,7 @@ void PullbackEmitter::materializeAdjointIndirect(AdjointValue val, for (unsigned i = 0; fieldIt != structDecl->getStoredProperties().end(); ++fieldIt, ++i) { auto eltBuf = - builder.createStructElementAddr(loc, destBufferAccess, *fieldIt); + builder.createStructElementAddr(loc, destAddress, *fieldIt); materializeAdjointIndirect(val.getAggregateElement(i), eltBuf, loc); } } else { @@ -2213,28 +2188,29 @@ void PullbackEmitter::materializeAdjointIndirect(AdjointValue val, } break; } - /// Value is already materialized! + /// If adjoint value is concrete, it is already materialized. Store it in the + /// destination address. case AdjointValueKind::Concrete: auto concreteVal = val.getConcreteValue(); - builder.emitStoreValueOperation(loc, concreteVal, destBufferAccess, + builder.emitStoreValueOperation(loc, concreteVal, destAddress, StoreOwnershipQualifier::Init); break; } } -void PullbackEmitter::emitZeroIndirect(CanType type, SILValue bufferAccess, - SILLocation loc) { +void PullbackCloner::emitZeroIndirect(CanType type, SILValue address, + SILLocation loc) { auto tangentSpace = getTangentSpace(type); assert(tangentSpace && "No tangent space for this type"); switch (tangentSpace->getKind()) { case TangentSpace::Kind::TangentVector: - emitZeroIntoBuffer(builder, type, bufferAccess, loc); + emitZeroIntoBuffer(builder, type, address, loc); return; case TangentSpace::Kind::Tuple: { auto tupleType = tangentSpace->getTuple(); SmallVector zeroElements; for (unsigned i : range(tupleType->getNumElements())) { - auto eltAddr = builder.createTupleElementAddr(loc, bufferAccess, i); + auto eltAddr = builder.createTupleElementAddr(loc, address, i); emitZeroIndirect(tupleType->getElementType(i)->getCanonicalType(), eltAddr, loc); } @@ -2243,7 +2219,7 @@ void PullbackEmitter::emitZeroIndirect(CanType type, SILValue bufferAccess, } } -SILValue PullbackEmitter::emitZeroDirect(CanType type, SILLocation loc) { +SILValue PullbackCloner::emitZeroDirect(CanType type, SILLocation loc) { auto silType = getModule().Types.getLoweredLoadableType( type, TypeExpansionContext::minimal(), getModule()); auto *buffer = builder.createAllocStack(loc, silType); @@ -2254,9 +2230,9 @@ SILValue PullbackEmitter::emitZeroDirect(CanType type, SILLocation loc) { return loaded; } -AdjointValue PullbackEmitter::accumulateAdjointsDirect(AdjointValue lhs, - AdjointValue rhs, - SILLocation loc) { +AdjointValue PullbackCloner::accumulateAdjointsDirect(AdjointValue lhs, + AdjointValue rhs, + SILLocation loc) { LLVM_DEBUG(getADDebugStream() << "Materializing adjoint directly.\nLHS: " << lhs << "\nRHS: " << rhs << '\n'); @@ -2330,8 +2306,8 @@ AdjointValue PullbackEmitter::accumulateAdjointsDirect(AdjointValue lhs, llvm_unreachable("invalid LHS kind"); } -SILValue PullbackEmitter::accumulateDirect(SILValue lhs, SILValue rhs, - SILLocation loc) { +SILValue PullbackCloner::accumulateDirect(SILValue lhs, SILValue rhs, + SILLocation loc) { // TODO: Optimize for the case when lhs == rhs. LLVM_DEBUG(getADDebugStream() << "Emitting adjoint accumulation for lhs: " << lhs << " and rhs: " << rhs); @@ -2380,17 +2356,16 @@ SILValue PullbackEmitter::accumulateDirect(SILValue lhs, SILValue rhs, llvm_unreachable("invalid tangent space"); } -void PullbackEmitter::accumulateIndirect(SILValue resultBufAccess, - SILValue lhsBufAccess, - SILValue rhsBufAccess, - SILLocation loc) { +void PullbackCloner::accumulateIndirect(SILValue resultAddress, + SILValue lhsAddress, + SILValue rhsAddress, SILLocation loc) { // TODO: Optimize for the case when lhs == rhs. - assert(lhsBufAccess->getType() == rhsBufAccess->getType() && + assert(lhsAddress->getType() == rhsAddress->getType() && "Adjoint values must have same type!"); - assert(lhsBufAccess->getType().isAddress() && - rhsBufAccess->getType().isAddress() && + assert(lhsAddress->getType().isAddress() && + rhsAddress->getType().isAddress() && "Adjoint values must both have address types!"); - auto adjointTy = lhsBufAccess->getType(); + auto adjointTy = lhsAddress->getType(); auto adjointASTTy = adjointTy.getASTType(); auto *swiftMod = getModule().getSwiftModule(); auto tangentSpace = adjointASTTy->getAutoDiffTangentSpace( @@ -2423,7 +2398,7 @@ void PullbackEmitter::accumulateIndirect(SILValue resultBufAccess, auto metatype = builder.createMetatype(loc, metatypeSILType); // %2 = apply $0(%result, %new, %old, %1) builder.createApply(loc, witnessMethod, subMap, - {resultBufAccess, rhsBufAccess, lhsBufAccess, metatype}, + {resultAddress, rhsAddress, lhsAddress, metatype}, /*isNonThrowing*/ false); builder.emitDestroyValueOperation(loc, witnessMethod); return; @@ -2431,9 +2406,9 @@ void PullbackEmitter::accumulateIndirect(SILValue resultBufAccess, case TangentSpace::Kind::Tuple: { auto tupleType = tangentSpace->getTuple(); for (unsigned i : range(tupleType->getNumElements())) { - auto *destAddr = builder.createTupleElementAddr(loc, resultBufAccess, i); - auto *eltAddrLHS = builder.createTupleElementAddr(loc, lhsBufAccess, i); - auto *eltAddrRHS = builder.createTupleElementAddr(loc, rhsBufAccess, i); + auto *destAddr = builder.createTupleElementAddr(loc, resultAddress, i); + auto *eltAddrLHS = builder.createTupleElementAddr(loc, lhsAddress, i); + auto *eltAddrRHS = builder.createTupleElementAddr(loc, rhsAddress, i); accumulateIndirect(destAddr, eltAddrLHS, eltAddrRHS, loc); } return; @@ -2441,13 +2416,13 @@ void PullbackEmitter::accumulateIndirect(SILValue resultBufAccess, } } -void PullbackEmitter::accumulateIndirect(SILValue lhsDestAccess, - SILValue rhsAccess, SILLocation loc) { - assert(lhsDestAccess->getType().isAddress() && - rhsAccess->getType().isAddress()); - assert(lhsDestAccess->getFunction() == &getPullback()); - assert(rhsAccess->getFunction() == &getPullback()); - auto type = lhsDestAccess->getType(); +void PullbackCloner::accumulateIndirect(SILValue lhsDestAddress, + SILValue rhsAddress, SILLocation loc) { + assert(lhsDestAddress->getType().isAddress() && + rhsAddress->getType().isAddress()); + assert(lhsDestAddress->getFunction() == &getPullback()); + assert(rhsAddress->getFunction() == &getPullback()); + auto type = lhsDestAddress->getType(); auto astType = type.getASTType(); auto *swiftMod = getModule().getSwiftModule(); auto tangentSpace = @@ -2476,7 +2451,7 @@ void PullbackEmitter::accumulateIndirect(SILValue lhsDestAccess, auto metatype = builder.createMetatype(loc, metatypeSILType); // %2 = apply $0(%lhs, %rhs, %1) builder.createApply(loc, witnessMethod, subMap, - {lhsDestAccess, rhsAccess, metatype}, + {lhsDestAddress, rhsAddress, metatype}, /*isNonThrowing*/ false); builder.emitDestroyValueOperation(loc, witnessMethod); return; @@ -2484,8 +2459,8 @@ void PullbackEmitter::accumulateIndirect(SILValue lhsDestAccess, case TangentSpace::Kind::Tuple: { auto tupleType = tangentSpace->getTuple(); for (unsigned i : range(tupleType->getNumElements())) { - auto *destAddr = builder.createTupleElementAddr(loc, lhsDestAccess, i); - auto *eltAddrRHS = builder.createTupleElementAddr(loc, rhsAccess, i); + auto *destAddr = builder.createTupleElementAddr(loc, lhsDestAddress, i); + auto *eltAddrRHS = builder.createTupleElementAddr(loc, rhsAddress, i); accumulateIndirect(destAddr, eltAddrRHS, loc); } return; diff --git a/lib/SILOptimizer/Differentiation/VJPEmitter.cpp b/lib/SILOptimizer/Differentiation/VJPCloner.cpp similarity index 93% rename from lib/SILOptimizer/Differentiation/VJPEmitter.cpp rename to lib/SILOptimizer/Differentiation/VJPCloner.cpp index c9eab1f3b39e7..cdaa189bf382f 100644 --- a/lib/SILOptimizer/Differentiation/VJPEmitter.cpp +++ b/lib/SILOptimizer/Differentiation/VJPCloner.cpp @@ -1,4 +1,4 @@ -//===--- VJPEmitter.cpp - VJP generation in differentiation ---*- C++ -*---===// +//===--- VJPCloner.cpp - VJP function generation --------------*- C++ -*---===// // // This source file is part of the Swift.org open source project // @@ -17,9 +17,9 @@ #define DEBUG_TYPE "differentiation" -#include "swift/SILOptimizer/Differentiation/VJPEmitter.h" +#include "swift/SILOptimizer/Differentiation/VJPCloner.h" #include "swift/SILOptimizer/Differentiation/ADContext.h" -#include "swift/SILOptimizer/Differentiation/PullbackEmitter.h" +#include "swift/SILOptimizer/Differentiation/PullbackCloner.h" #include "swift/SILOptimizer/Differentiation/Thunk.h" #include "swift/SILOptimizer/PassManager/PrettyStackTrace.h" @@ -30,8 +30,8 @@ namespace swift { namespace autodiff { /*static*/ -SubstitutionMap VJPEmitter::getSubstitutionMap(SILFunction *original, - SILFunction *vjp) { +SubstitutionMap VJPCloner::getSubstitutionMap(SILFunction *original, + SILFunction *vjp) { auto substMap = original->getForwardingSubstitutionMap(); if (auto *vjpGenEnv = vjp->getGenericEnvironment()) { auto vjpSubstMap = vjpGenEnv->getForwardingSubstitutionMap(); @@ -44,8 +44,8 @@ SubstitutionMap VJPEmitter::getSubstitutionMap(SILFunction *original, /*static*/ const DifferentiableActivityInfo & -VJPEmitter::getActivityInfo(ADContext &context, SILFunction *original, - SILAutoDiffIndices indices, SILFunction *vjp) { +VJPCloner::getActivityInfo(ADContext &context, SILFunction *original, + SILAutoDiffIndices indices, SILFunction *vjp) { // Get activity info of the original function. auto &passManager = context.getPassManager(); auto *activityAnalysis = @@ -58,9 +58,9 @@ VJPEmitter::getActivityInfo(ADContext &context, SILFunction *original, return activityInfo; } -VJPEmitter::VJPEmitter(ADContext &context, SILFunction *original, - SILDifferentiabilityWitness *witness, SILFunction *vjp, - DifferentiationInvoker invoker) +VJPCloner::VJPCloner(ADContext &context, SILFunction *original, + SILDifferentiabilityWitness *witness, SILFunction *vjp, + DifferentiationInvoker invoker) : TypeSubstCloner(*vjp, *original, getSubstitutionMap(original, vjp)), context(context), original(original), witness(witness), vjp(vjp), invoker(invoker), @@ -73,7 +73,7 @@ VJPEmitter::VJPEmitter(ADContext &context, SILFunction *original, context.recordGeneratedFunction(pullback); } -SILFunction *VJPEmitter::createEmptyPullback() { +SILFunction *VJPCloner::createEmptyPullback() { auto &module = context.getModule(); auto origTy = original->getLoweredFunctionType(); // Get witness generic signature for remapping types. @@ -259,13 +259,13 @@ SILFunction *VJPEmitter::createEmptyPullback() { return pullback; } -void VJPEmitter::postProcess(SILInstruction *orig, SILInstruction *cloned) { +void VJPCloner::postProcess(SILInstruction *orig, SILInstruction *cloned) { if (errorOccurred) return; SILClonerWithScopes::postProcess(orig, cloned); } -SILBasicBlock *VJPEmitter::remapBasicBlock(SILBasicBlock *bb) { +SILBasicBlock *VJPCloner::remapBasicBlock(SILBasicBlock *bb) { auto *vjpBB = BBMap[bb]; // If error has occurred, or if block has already been remapped, return // remapped, return remapped block. @@ -282,9 +282,9 @@ SILBasicBlock *VJPEmitter::remapBasicBlock(SILBasicBlock *bb) { return vjpBB; } -SILBasicBlock *VJPEmitter::createTrampolineBasicBlock(TermInst *termInst, - StructInst *pbStructVal, - SILBasicBlock *succBB) { +SILBasicBlock *VJPCloner::createTrampolineBasicBlock(TermInst *termInst, + StructInst *pbStructVal, + SILBasicBlock *succBB) { assert(llvm::find(termInst->getSuccessorBlocks(), succBB) != termInst->getSuccessorBlocks().end() && "Basic block is not a successor of terminator instruction"); @@ -307,32 +307,32 @@ SILBasicBlock *VJPEmitter::createTrampolineBasicBlock(TermInst *termInst, return trampolineBB; } -void VJPEmitter::visit(SILInstruction *inst) { +void VJPCloner::visit(SILInstruction *inst) { if (errorOccurred) return; TypeSubstCloner::visit(inst); } -void VJPEmitter::visitSILInstruction(SILInstruction *inst) { +void VJPCloner::visitSILInstruction(SILInstruction *inst) { context.emitNondifferentiabilityError( inst, invoker, diag::autodiff_expression_not_differentiable_note); errorOccurred = true; } -SILType VJPEmitter::getLoweredType(Type type) { +SILType VJPCloner::getLoweredType(Type type) { auto vjpGenSig = vjp->getLoweredFunctionType()->getSubstGenericSignature(); Lowering::AbstractionPattern pattern(vjpGenSig, type->getCanonicalType(vjpGenSig)); return vjp->getLoweredType(pattern, type); } -SILType VJPEmitter::getNominalDeclLoweredType(NominalTypeDecl *nominal) { +SILType VJPCloner::getNominalDeclLoweredType(NominalTypeDecl *nominal) { auto nominalType = getOpASTType(nominal->getDeclaredInterfaceType()->getCanonicalType()); return getLoweredType(nominalType); } -StructInst *VJPEmitter::buildPullbackValueStructValue(TermInst *termInst) { +StructInst *VJPCloner::buildPullbackValueStructValue(TermInst *termInst) { assert(termInst->getFunction() == original); auto loc = RegularLocation::getAutoGeneratedLocation(); auto origBB = termInst->getParent(); @@ -348,10 +348,10 @@ StructInst *VJPEmitter::buildPullbackValueStructValue(TermInst *termInst) { return getBuilder().createStruct(loc, structLoweredTy, bbPullbackValues); } -EnumInst *VJPEmitter::buildPredecessorEnumValue(SILBuilder &builder, - SILBasicBlock *predBB, - SILBasicBlock *succBB, - SILValue pbStructVal) { +EnumInst *VJPCloner::buildPredecessorEnumValue(SILBuilder &builder, + SILBasicBlock *predBB, + SILBasicBlock *succBB, + SILValue pbStructVal) { auto loc = RegularLocation::getAutoGeneratedLocation(); auto *succEnum = pullbackInfo.getBranchingTraceDecl(succBB); auto enumLoweredTy = getNominalDeclLoweredType(succEnum); @@ -374,7 +374,7 @@ EnumInst *VJPEmitter::buildPredecessorEnumValue(SILBuilder &builder, return builder.createEnum(loc, newBox, enumEltDecl, enumLoweredTy); } -void VJPEmitter::visitReturnInst(ReturnInst *ri) { +void VJPCloner::visitReturnInst(ReturnInst *ri) { auto loc = ri->getOperand().getLoc(); auto &builder = getBuilder(); @@ -431,7 +431,7 @@ void VJPEmitter::visitReturnInst(ReturnInst *ri) { builder.createReturn(ri->getLoc(), joinElements(directResults, builder, loc)); } -void VJPEmitter::visitBranchInst(BranchInst *bi) { +void VJPCloner::visitBranchInst(BranchInst *bi) { // Build pullback struct value for original block. // Build predecessor enum value for destination block. auto *origBB = bi->getParent(); @@ -450,7 +450,7 @@ void VJPEmitter::visitBranchInst(BranchInst *bi) { args); } -void VJPEmitter::visitCondBranchInst(CondBranchInst *cbi) { +void VJPCloner::visitCondBranchInst(CondBranchInst *cbi) { // Build pullback struct value for original block. auto *pbStructVal = buildPullbackValueStructValue(cbi); // Create a new `cond_br` instruction. @@ -460,7 +460,7 @@ void VJPEmitter::visitCondBranchInst(CondBranchInst *cbi) { createTrampolineBasicBlock(cbi, pbStructVal, cbi->getFalseBB())); } -void VJPEmitter::visitSwitchEnumInstBase(SwitchEnumInstBase *sei) { +void VJPCloner::visitSwitchEnumInstBase(SwitchEnumInstBase *sei) { // Build pullback struct value for original block. auto *pbStructVal = buildPullbackValueStructValue(sei); @@ -492,15 +492,15 @@ void VJPEmitter::visitSwitchEnumInstBase(SwitchEnumInstBase *sei) { } } -void VJPEmitter::visitSwitchEnumInst(SwitchEnumInst *sei) { +void VJPCloner::visitSwitchEnumInst(SwitchEnumInst *sei) { visitSwitchEnumInstBase(sei); } -void VJPEmitter::visitSwitchEnumAddrInst(SwitchEnumAddrInst *seai) { +void VJPCloner::visitSwitchEnumAddrInst(SwitchEnumAddrInst *seai) { visitSwitchEnumInstBase(seai); } -void VJPEmitter::visitCheckedCastBranchInst(CheckedCastBranchInst *ccbi) { +void VJPCloner::visitCheckedCastBranchInst(CheckedCastBranchInst *ccbi) { // Build pullback struct value for original block. auto *pbStructVal = buildPullbackValueStructValue(ccbi); // Create a new `checked_cast_branch` instruction. @@ -513,7 +513,7 @@ void VJPEmitter::visitCheckedCastBranchInst(CheckedCastBranchInst *ccbi) { ccbi->getTrueBBCount(), ccbi->getFalseBBCount()); } -void VJPEmitter::visitCheckedCastValueBranchInst( +void VJPCloner::visitCheckedCastValueBranchInst( CheckedCastValueBranchInst *ccvbi) { // Build pullback struct value for original block. auto *pbStructVal = buildPullbackValueStructValue(ccvbi); @@ -527,7 +527,7 @@ void VJPEmitter::visitCheckedCastValueBranchInst( createTrampolineBasicBlock(ccvbi, pbStructVal, ccvbi->getFailureBB())); } -void VJPEmitter::visitCheckedCastAddrBranchInst( +void VJPCloner::visitCheckedCastAddrBranchInst( CheckedCastAddrBranchInst *ccabi) { // Build pullback struct value for original block. auto *pbStructVal = buildPullbackValueStructValue(ccabi); @@ -541,7 +541,7 @@ void VJPEmitter::visitCheckedCastAddrBranchInst( ccabi->getTrueBBCount(), ccabi->getFalseBBCount()); } -void VJPEmitter::visitApplyInst(ApplyInst *ai) { +void VJPCloner::visitApplyInst(ApplyInst *ai) { // If callee should not be differentiated, do standard cloning. if (!pullbackInfo.shouldDifferentiateApplySite(ai)) { LLVM_DEBUG(getADDebugStream() << "No active results:\n" << *ai << '\n'); @@ -815,7 +815,7 @@ void VJPEmitter::visitApplyInst(ApplyInst *ai) { getOpValue(origCallee)->getDefiningInstruction()); } -void VJPEmitter::visitDifferentiableFunctionInst( +void VJPCloner::visitDifferentiableFunctionInst( DifferentiableFunctionInst *dfi) { // Clone `differentiable_function` from original to VJP, then add the cloned // instruction to the `differentiable_function` worklist. @@ -824,7 +824,7 @@ void VJPEmitter::visitDifferentiableFunctionInst( context.addDifferentiableFunctionInstToWorklist(newDFI); } -bool VJPEmitter::run() { +bool VJPCloner::run() { PrettyStackTraceSILFunction trace("generating VJP for", original); LLVM_DEBUG(getADDebugStream() << "Cloning original @" << original->getName() << " to vjp @" << vjp->getName() << '\n'); @@ -853,8 +853,8 @@ bool VJPEmitter::run() { << *vjp); // Generate pullback code. - PullbackEmitter PullbackEmitter(*this); - if (PullbackEmitter.run()) { + PullbackCloner PullbackCloner(*this); + if (PullbackCloner.run()) { errorOccurred = true; return true; } diff --git a/lib/SILOptimizer/Mandatory/Differentiation.cpp b/lib/SILOptimizer/Mandatory/Differentiation.cpp index 48b87082aa49d..0c323052e339f 100644 --- a/lib/SILOptimizer/Mandatory/Differentiation.cpp +++ b/lib/SILOptimizer/Mandatory/Differentiation.cpp @@ -37,9 +37,9 @@ #include "swift/SIL/TypeSubstCloner.h" #include "swift/SILOptimizer/Analysis/DominanceAnalysis.h" #include "swift/SILOptimizer/Differentiation/ADContext.h" -#include "swift/SILOptimizer/Differentiation/JVPEmitter.h" +#include "swift/SILOptimizer/Differentiation/JVPCloner.h" #include "swift/SILOptimizer/Differentiation/Thunk.h" -#include "swift/SILOptimizer/Differentiation/VJPEmitter.h" +#include "swift/SILOptimizer/Differentiation/VJPCloner.h" #include "swift/SILOptimizer/PassManager/Passes.h" #include "swift/SILOptimizer/PassManager/Transforms.h" #include "swift/SILOptimizer/Utils/SILOptFunctionBuilder.h" @@ -919,8 +919,8 @@ bool DifferentiationTransformer::canonicalizeDifferentiabilityWitness( return true; } // Emit JVP function. - JVPEmitter emitter(context, original, witness, jvp, invoker); - if (emitter.run()) + JVPCloner cloner(context, original, witness, jvp, invoker); + if (cloner.run()) return true; } else { // If JVP generation is disabled or a user-defined custom VJP function @@ -947,8 +947,8 @@ bool DifferentiationTransformer::canonicalizeDifferentiabilityWitness( witness->setVJP(vjp); context.recordGeneratedFunction(vjp); // Emit VJP function. - VJPEmitter emitter(context, original, witness, vjp, invoker); - return emitter.run(); + VJPCloner cloner(context, original, witness, vjp, invoker); + return cloner.run(); } return false; } From b0cb51635c66b360ad40203bc03c70d9de3faa0a Mon Sep 17 00:00:00 2001 From: Michael Gottesman Date: Wed, 8 Jul 2020 19:13:31 -0700 Subject: [PATCH 15/16] [gardening] Add a helper to SILFunctionConventions to retrieve if a function has a noreturn result. Sometimes one just has a SILFunctionConvention instead of the underlying SILFunctionType (that the SILFunctionConvention contains). This just shims in that API onto the composition type. --- include/swift/SIL/SILFunctionConventions.h | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/include/swift/SIL/SILFunctionConventions.h b/include/swift/SIL/SILFunctionConventions.h index ff148628bf7f6..95945b3c0cb61 100644 --- a/include/swift/SIL/SILFunctionConventions.h +++ b/include/swift/SIL/SILFunctionConventions.h @@ -396,6 +396,9 @@ class SILFunctionConventions { /// Return the SIL type of the apply/entry argument at the given index. SILType getSILArgumentType(unsigned index, TypeExpansionContext context) const; + + /// Returns true if this function does not return to the caller. + bool isNoReturn(TypeExpansionContext context) const; }; struct SILFunctionConventions::SILResultTypeFunc { @@ -464,6 +467,11 @@ SILFunctionConventions::getSILArgumentType(unsigned index, funcTy->getParameters()[index - getNumIndirectSILResults()], context); } +inline bool +SILFunctionConventions::isNoReturn(TypeExpansionContext context) const { + return funcTy->isNoReturnFunction(silConv.getModule(), context); +} + inline SILFunctionConventions SILModuleConventions::getFunctionConventions(CanSILFunctionType funcTy) { return SILFunctionConventions(funcTy, *this); From 59b8c6e58fdada7e00769cc6a421c243db61bb49 Mon Sep 17 00:00:00 2001 From: Doug Gregor Date: Wed, 8 Jul 2020 19:56:04 -0700 Subject: [PATCH 16/16] Fix test again :( --- test/Driver/print_target_info.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/Driver/print_target_info.swift b/test/Driver/print_target_info.swift index 5f8bd50b2bae5..98237b321f0d0 100644 --- a/test/Driver/print_target_info.swift +++ b/test/Driver/print_target_info.swift @@ -36,7 +36,7 @@ // CHECK-IOS: } -// CHECK-LINUX: "compilerVersion": "Swift version +// CHECK-LINUX: "compilerVersion": "{{.*}}Swift version // CHECK-LINUX: "target": { // CHECK-LINUX: "triple": "x86_64-unknown-linux",