From 6bf0f26ecd7add55c31e8027be15435b1248a6de Mon Sep 17 00:00:00 2001 From: Martin Boehme Date: Wed, 20 May 2020 15:28:12 +0200 Subject: [PATCH 01/28] Emit `/DEFAULTLIB` directive for `#pragma comment(lib, ...)` in a C module. This is important, for example, for linking correctly to the C++ standard library on Windows, which uses `#pragma comment(lib, ...)` in its headers to specify the correct library to link against. --- lib/IRGen/GenClangDecl.cpp | 11 ++++++++++ .../autolink-coff-c-pragma-transitive.h | 1 + test/IRGen/Inputs/autolink-coff-c-pragma.h | 3 +++ test/IRGen/Inputs/module.modulemap | 9 +++++++++ test/IRGen/autolink-coff-c-pragma.swift | 20 +++++++++++++++++++ 5 files changed, 44 insertions(+) create mode 100644 test/IRGen/Inputs/autolink-coff-c-pragma-transitive.h create mode 100644 test/IRGen/Inputs/autolink-coff-c-pragma.h create mode 100644 test/IRGen/Inputs/module.modulemap create mode 100644 test/IRGen/autolink-coff-c-pragma.swift diff --git a/lib/IRGen/GenClangDecl.cpp b/lib/IRGen/GenClangDecl.cpp index 304ddf5bf2b92..577307ce502d9 100644 --- a/lib/IRGen/GenClangDecl.cpp +++ b/lib/IRGen/GenClangDecl.cpp @@ -113,6 +113,17 @@ IRGenModule::getAddrOfClangGlobalDecl(clang::GlobalDecl global, } void IRGenModule::finalizeClangCodeGen() { + // Ensure that code is emitted for any `PragmaCommentDecl`s. (These are + // always guaranteed to be directly below the TranslationUnitDecl.) + // In Clang, this happens automatically during the Sema phase, but here we + // need to take care of it manually because our Clang CodeGenerator is not + // attached to Clang Sema as an ASTConsumer. + for (const auto *D : ClangASTContext->getTranslationUnitDecl()->decls()) { + if (const auto *PCD = dyn_cast(D)) { + emitClangDecl(PCD); + } + } + ClangCodeGen->HandleTranslationUnit( *const_cast(ClangASTContext)); } diff --git a/test/IRGen/Inputs/autolink-coff-c-pragma-transitive.h b/test/IRGen/Inputs/autolink-coff-c-pragma-transitive.h new file mode 100644 index 0000000000000..d3644cee43cb4 --- /dev/null +++ b/test/IRGen/Inputs/autolink-coff-c-pragma-transitive.h @@ -0,0 +1 @@ +#pragma comment(lib, "transitive-module") diff --git a/test/IRGen/Inputs/autolink-coff-c-pragma.h b/test/IRGen/Inputs/autolink-coff-c-pragma.h new file mode 100644 index 0000000000000..6d2705ff4e152 --- /dev/null +++ b/test/IRGen/Inputs/autolink-coff-c-pragma.h @@ -0,0 +1,3 @@ +#include "autolink-coff-c-pragma-transitive.h" + +#pragma comment(lib, "module") diff --git a/test/IRGen/Inputs/module.modulemap b/test/IRGen/Inputs/module.modulemap new file mode 100644 index 0000000000000..d396471ea63e0 --- /dev/null +++ b/test/IRGen/Inputs/module.modulemap @@ -0,0 +1,9 @@ +module AutolinkCoffCPragma { + header "autolink-coff-c-pragma.h" + export * +} + +module AutolinkCoffCPragmaTransitive { + header "autolink-coff-c-pragma-transitive.h" + export * +} diff --git a/test/IRGen/autolink-coff-c-pragma.swift b/test/IRGen/autolink-coff-c-pragma.swift new file mode 100644 index 0000000000000..0b4d8b389d64f --- /dev/null +++ b/test/IRGen/autolink-coff-c-pragma.swift @@ -0,0 +1,20 @@ +// Tests that a `#pragma comment(lib, ...)` in a C header imported as a module +// causes a corresponding `/DEFAULTLIB` directive to be emitted. +// +// We test that this is true also for C headers included transitively from +// another C header. + +// RUN: %swift -module-name Swift -target x86_64-unknown-windows-msvc -I %S/Inputs -emit-ir %s -parse-stdlib -parse-as-library -disable-legacy-type-info | %FileCheck %s -check-prefix=CHECK-MSVC-IR +// RUN: %swift -module-name Swift -target x86_64-unknown-windows-msvc -I %S/Inputs -S %s -parse-stdlib -parse-as-library -disable-legacy-type-info | %FileCheck %s -check-prefix=CHECK-MSVC-ASM + +// REQUIRES: CODEGENERATOR=X86 + +import AutolinkCoffCPragma + +// CHECK-MSVC-IR: !llvm.linker.options = !{!{{[0-9]+}}, !{{[0-9]+}}} +// CHECK-MSVC-IR-DAG: !{{[0-9]+}} = !{!"/DEFAULTLIB:module.lib"} +// CHECK-MSVC-IR-DAG: !{{[0-9]+}} = !{!"/DEFAULTLIB:transitive-module.lib"} + +// CHECK-MSVC-ASM: .section .drectve +// CHECK-MSVC-ASM-DAG: .ascii " /DEFAULTLIB:module.lib" +// CHECK-MSVC-ASM-DAG: .ascii " /DEFAULTLIB:transitive-module.lib" From e530ef8f775879add40f6a44b1c316510b4e8c57 Mon Sep 17 00:00:00 2001 From: martinboehme Date: Mon, 25 May 2020 16:46:08 +0200 Subject: [PATCH 02/28] Update WindowsBuild.md Make the note on needing to adjust `PYTHON_EXECUTABLE` more prominent, and repeat in in the place where that argument actually needs to be adjusted. --- docs/WindowsBuild.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/docs/WindowsBuild.md b/docs/WindowsBuild.md index 61c1e6ef2c5e2..d49b0ca00462f 100644 --- a/docs/WindowsBuild.md +++ b/docs/WindowsBuild.md @@ -21,7 +21,7 @@ In the Visual Studio installation program, under *Individual Components* 1. Install *Python 2*, either the 32-bit version (C:\Python27\\) or the 64-bit version (C:\Python27amd64\\) -> If you install the 64-bit version only, you will need to adjust `PYTHON_EXECUTABLE` below to `C:\Python27amd64\python.exe` + **Note:** If you install the 64-bit version only, you will need to adjust `PYTHON_EXECUTABLE` below to `C:\Python27amd64\python.exe` 2. Install *Python 3 64 bits (3.7.x)* @@ -139,6 +139,9 @@ cmake -B "S:\b\toolchain" ^ ninja -C S:\b\toolchain ``` +**Note:** If you installed only the 64-bit version of Python, you will need to adjust `PYTHON_EXECUTABLE` argument to `C:\Python27amd64\python.exe` + + ## Running Swift tests on Windows ```cmd From fd56606d5fbd68d6d3ecb4ecdf1aa3e6cd3d0829 Mon Sep 17 00:00:00 2001 From: Varun Gandhi Date: Mon, 25 May 2020 15:29:54 -0700 Subject: [PATCH 03/28] [NFC] Remove bits and pieces referring to Ubuntu 14.04 LTS (Trusty Tahr). Ubuntu 20.04 LTS is out now; the last version to support 14.04 was Swift 5.1, so it is ok to delete this. --- README.md | 2 -- docs/Ubuntu14.md | 24 ------------------------ utils/build-presets.ini | 13 ------------- 3 files changed, 39 deletions(-) delete mode 100644 docs/Ubuntu14.md diff --git a/README.md b/README.md index d8bfb65a54a9d..1427307fa9a06 100644 --- a/README.md +++ b/README.md @@ -151,8 +151,6 @@ with version 2 shipped with Ubuntu. **Note:** For Ubuntu 20.04, use `libpython2-dev` in place of the libpython-dev package above. -Build instructions for Ubuntu 14.04 LTS can be found [here](docs/Ubuntu14.md). - ### Getting Sources for Swift and Related Projects First create a directory for all of the Swift sources: diff --git a/docs/Ubuntu14.md b/docs/Ubuntu14.md deleted file mode 100644 index 78e2982e4ef53..0000000000000 --- a/docs/Ubuntu14.md +++ /dev/null @@ -1,24 +0,0 @@ -# Getting Started with Swift on Ubuntu 14.04 - -## Upgrade Clang -You'll need to upgrade your clang compiler for C++14 support and create a symlink. The minimum required version of clang may change, and may not be available on Ubuntu 14.04 in the future. -```bash -sudo apt-get install clang-3.9 -sudo update-alternatives --install /usr/bin/clang clang /usr/bin/clang-3.9 100 -sudo update-alternatives --install /usr/bin/clang++ clang++ /usr/bin/clang++-3.9 100 -sudo update-alternatives --set clang /usr/bin/clang-3.9 -sudo update-alternatives --set clang++ /usr/bin/clang++-3.9 -``` - -## Upgrade CMake -You'll need to upgrade your CMake toolchain to a supported version by building a local copy. The minimum required version of CMake may change, and may not be available on Ubuntu 14.04 in the future. -```bash -wget https://cmake.org/files/v3.5/cmake-3.5.2.tar.gz -tar xf cmake-3.5.2.tar.gz -cd cmake-3.5.2 -./configure -make -sudo make install -sudo update-alternatives --install /usr/bin/cmake cmake /usr/local/bin/cmake 1 --force -cmake --version # This should print 3.5.2 -``` diff --git a/utils/build-presets.ini b/utils/build-presets.ini index 902493f677d45..d51ca412f28f6 100644 --- a/utils/build-presets.ini +++ b/utils/build-presets.ini @@ -968,19 +968,6 @@ mixin-preset= buildbot_linux_1510 mixin_buildbot_linux,no_test -# Ubuntu 14.04 preset for backwards compat and future customizations. -[preset: buildbot_linux_1404] -mixin-preset=buildbot_linux -indexstore-db=0 -sourcekit-lsp=0 - -# Ubuntu 14.04 preset that skips all tests except for integration testing of the -# package. -[preset: buildbot_linux_1404,notest] -mixin-preset= - buildbot_linux_1404 - mixin_buildbot_linux,no_test - [preset: buildbot_linux,smoketest] mixin-preset=mixin_linux_installation build-subdir=buildbot_linux From d867c88d2c830ff4036109ccf60e3eb099df710a Mon Sep 17 00:00:00 2001 From: martinboehme Date: Tue, 26 May 2020 16:20:15 +0200 Subject: [PATCH 04/28] Update WindowsBuild.md Added note on which edition of Visual Studio to get. --- docs/WindowsBuild.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/WindowsBuild.md b/docs/WindowsBuild.md index 61c1e6ef2c5e2..32297b692136c 100644 --- a/docs/WindowsBuild.md +++ b/docs/WindowsBuild.md @@ -1,6 +1,6 @@ # Building Swift on Windows -Visual Studio 2017 or newer is needed to build swift on Windows. +Visual Studio 2017 or newer is needed to build Swift on Windows. As of this writing, Visual Studio 2019 is the most recent release. The free Community edition is sufficient to build Swift. The following must take place in the **developer command prompt** (provided by Visual Studio). This shows up as "x64 Native Tools Command Prompt for VS2017" (or VS2019, VS2019 Preview depending on the Visual Studio that you are using) in the Start Menu. From 6ca148d65ed8ccc46afc8b9e2aabff707a805940 Mon Sep 17 00:00:00 2001 From: Slava Pestov Date: Tue, 26 May 2020 22:17:18 -0400 Subject: [PATCH 05/28] AST: Optimize removeShadowedDecls() Previously, whenever name lookup returned two declarations with the same name, we would compute the canonical type of each one as part of the shadowing check. The canonical type calculation is rather expensive for GenericFunctionTypes since it requires constructing a GenericSignatureBuilder to canonicalize type parameters that appear in the function's signature. Instead, let's first shard all declarations that have the same name by their generic signature. If two declarations have the same signature, only then do we proceed to compute their canonical type. Since computing a canonical GenericSignature is cheaper than computing a canonical GenericFunctionType, this should speed up name lookup of heavily-overloaded names, such as operators. Fixes . --- lib/AST/Decl.cpp | 2 +- lib/AST/NameLookup.cpp | 121 +++++++++++++++++++++++++++++------------ 2 files changed, 86 insertions(+), 37 deletions(-) diff --git a/lib/AST/Decl.cpp b/lib/AST/Decl.cpp index 13a4d46dcac9e..df86e154de32e 100644 --- a/lib/AST/Decl.cpp +++ b/lib/AST/Decl.cpp @@ -2810,7 +2810,7 @@ CanType ValueDecl::getOverloadSignatureType() const { // implementation of the swift::conflicting overload that deals with // overload types, in order to account for cases where the overload types // don't match, but the decls differ and therefore always conflict. - + assert(isa(this)); return CanType(); } diff --git a/lib/AST/NameLookup.cpp b/lib/AST/NameLookup.cpp index 5afc70b7d9bce..c6f36c1522a67 100644 --- a/lib/AST/NameLookup.cpp +++ b/lib/AST/NameLookup.cpp @@ -244,12 +244,13 @@ static ConstructorComparison compareConstructors(ConstructorDecl *ctor1, return ConstructorComparison::Same; } -/// Given a set of declarations whose names and signatures have matched, +/// Given a set of declarations whose names and interface types have matched, /// figure out which of these declarations have been shadowed by others. template -static void -recordShadowedDeclsAfterSignatureMatch(ArrayRef decls, const DeclContext *dc, - llvm::SmallPtrSetImpl &shadowed) { +static void recordShadowedDeclsAfterTypeMatch( + ArrayRef decls, + const DeclContext *dc, + llvm::SmallPtrSetImpl &shadowed) { assert(decls.size() > 1 && "Nothing collided"); // Compare each declaration to every other declaration. This is @@ -491,6 +492,48 @@ recordShadowedDeclsAfterSignatureMatch(ArrayRef decls, const DeclContext *dc, } } +/// Given a set of declarations whose names and generic signatures have matched, +/// figure out which of these declarations have been shadowed by others. +static void recordShadowedDeclsAfterSignatureMatch( + ArrayRef decls, + const DeclContext *dc, + llvm::SmallPtrSetImpl &shadowed) { + assert(decls.size() > 1 && "Nothing collided"); + + // Categorize all of the declarations based on their overload types. + llvm::SmallDenseMap> collisions; + llvm::SmallVector collisionTypes; + + for (auto decl : decls) { + assert(!isa(decl)); + + CanType type; + + // FIXME: The type of a variable or subscript doesn't include + // enough context to distinguish entities from different + // constrained extensions, so use the overload signature's + // type. This is layering a partial fix upon a total hack. + if (auto asd = dyn_cast(decl)) + type = asd->getOverloadSignatureType(); + else + type = decl->getInterfaceType()->getCanonicalType(); + + // Record this declaration based on its signature. + auto &known = collisions[type]; + if (known.size() == 1) { + collisionTypes.push_back(type); + } + known.push_back(decl); + } + + // Check whether we have shadowing for signature collisions. + for (auto type : collisionTypes) { + ArrayRef collidingDecls = collisions[type]; + recordShadowedDeclsAfterTypeMatch(collidingDecls, dc, + shadowed); + } +} + /// Look through the given set of declarations (that all have the same name), /// recording those that are shadowed by another declaration in the /// \c shadowed set. @@ -526,14 +569,23 @@ static void recordShadowedDecls(ArrayRef decls, if (decls.size() < 2) return; + llvm::TinyPtrVector typeDecls; + // Categorize all of the declarations based on their overload signatures. - llvm::SmallDenseMap> collisions; - llvm::SmallVector collisionTypes; - llvm::SmallDenseMap> + llvm::SmallDenseMap> collisions; + llvm::SmallVector collisionSignatures; + llvm::SmallDenseMap> importedInitializerCollisions; - llvm::TinyPtrVector importedInitializerCollectionTypes; + llvm::TinyPtrVector importedInitializerCollisionTypes; for (auto decl : decls) { + if (auto *typeDecl = dyn_cast(decl)) { + typeDecls.push_back(typeDecl); + continue; + } + // Specifically keep track of imported initializers, which can come from // Objective-C init methods, Objective-C factory methods, renamed C // functions, or be synthesized by the importer. @@ -544,50 +596,47 @@ static void recordShadowedDecls(ArrayRef decls, auto nominal = ctor->getDeclContext()->getSelfNominalTypeDecl(); auto &knownInits = importedInitializerCollisions[nominal]; if (knownInits.size() == 1) { - importedInitializerCollectionTypes.push_back(nominal); + importedInitializerCollisionTypes.push_back(nominal); } knownInits.push_back(ctor); } } - CanType signature; + // If the decl is currently being validated, this is likely a recursive + // reference and we'll want to skip ahead so as to avoid having its type + // attempt to desugar itself. + if (decl->isRecursiveValidation()) + continue; - if (!isa(decl)) { - // If the decl is currently being validated, this is likely a recursive - // reference and we'll want to skip ahead so as to avoid having its type - // attempt to desugar itself. - if (decl->isRecursiveValidation()) - continue; - auto ifaceType = decl->getInterfaceType(); - - // FIXME: the canonical type makes a poor signature, because we don't - // canonicalize away default arguments. - signature = ifaceType->getCanonicalType(); - - // FIXME: The type of a variable or subscript doesn't include - // enough context to distinguish entities from different - // constrained extensions, so use the overload signature's - // type. This is layering a partial fix upon a total hack. - if (auto asd = dyn_cast(decl)) - signature = asd->getOverloadSignatureType(); - } + CanGenericSignature signature; + + auto *dc = decl->getInnermostDeclContext(); + if (auto genericSig = dc->getGenericSignatureOfContext()) + signature = genericSig->getCanonicalSignature(); // Record this declaration based on its signature. - auto &known = collisions[signature]; + auto &known = collisions[signature.getPointer()]; if (known.size() == 1) { - collisionTypes.push_back(signature); + collisionSignatures.push_back(signature.getPointer()); } + known.push_back(decl); } + // Check whether we have shadowing for type declarations. + if (typeDecls.size() > 1) { + ArrayRef collidingDecls = typeDecls; + recordShadowedDeclsAfterTypeMatch(collidingDecls, dc, shadowed); + } + // Check whether we have shadowing for signature collisions. - for (auto signature : collisionTypes) { + for (auto signature : collisionSignatures) { ArrayRef collidingDecls = collisions[signature]; recordShadowedDeclsAfterSignatureMatch(collidingDecls, dc, shadowed); } // Check whether we have shadowing for imported initializer collisions. - for (auto nominal : importedInitializerCollectionTypes) { + for (auto nominal : importedInitializerCollisionTypes) { recordShadowedDeclsForImportedInits(importedInitializerCollisions[nominal], shadowed); } @@ -597,15 +646,15 @@ static void recordShadowedDecls(ArrayRef decls, const DeclContext *dc, llvm::SmallPtrSetImpl &shadowed) { // Always considered to have the same signature. - recordShadowedDeclsAfterSignatureMatch(decls, dc, shadowed); + recordShadowedDeclsAfterTypeMatch(decls, dc, shadowed); } static void recordShadowedDecls(ArrayRef decls, const DeclContext *dc, llvm::SmallPtrSetImpl &shadowed) { - // Always considered to have the same signature. - recordShadowedDeclsAfterSignatureMatch(decls, dc, shadowed); + // Always considered to have the same type. + recordShadowedDeclsAfterTypeMatch(decls, dc, shadowed); } template From f025deaee0f196d180d0486ee15aabdd7a316cdd Mon Sep 17 00:00:00 2001 From: martinboehme Date: Thu, 28 May 2020 10:54:12 +0200 Subject: [PATCH 06/28] Update WindowsBuild.md Remove reference to the latest release. --- docs/WindowsBuild.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/WindowsBuild.md b/docs/WindowsBuild.md index 32297b692136c..7cf9dcebb0636 100644 --- a/docs/WindowsBuild.md +++ b/docs/WindowsBuild.md @@ -1,6 +1,6 @@ # Building Swift on Windows -Visual Studio 2017 or newer is needed to build Swift on Windows. As of this writing, Visual Studio 2019 is the most recent release. The free Community edition is sufficient to build Swift. +Visual Studio 2017 or newer is needed to build Swift on Windows. The free Community edition is sufficient to build Swift. The following must take place in the **developer command prompt** (provided by Visual Studio). This shows up as "x64 Native Tools Command Prompt for VS2017" (or VS2019, VS2019 Preview depending on the Visual Studio that you are using) in the Start Menu. From 3767ece85ccd64585ae88ef7e9535c6dd33d5872 Mon Sep 17 00:00:00 2001 From: David Zarzycki Date: Wed, 27 May 2020 11:29:04 -0400 Subject: [PATCH 07/28] [CMake] Simplify two binary variables into one tri-state variable Also remove some ancient logic to detect and ignore requests to use LLD. If people want to explicitly use LLD, they probably have a reason and we shouldn't second guess them. --- CMakeLists.txt | 28 ++++++++++------------- cmake/modules/AddSwift.cmake | 9 +++----- cmake/modules/AddSwiftUnittests.cmake | 10 ++------ stdlib/cmake/modules/AddSwiftStdlib.cmake | 21 +++++------------ test/lit.site.cfg.in | 4 ++-- validation-test/lit.site.cfg.in | 4 ++-- 6 files changed, 27 insertions(+), 49 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index baeb525d1f737..d699d69d04a50 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -147,24 +147,20 @@ set(SWIFT_COMPILER_VERSION "" CACHE STRING set(CLANG_COMPILER_VERSION "" CACHE STRING "The internal version of the Clang compiler") -# Indicate whether Swift should attempt to use the lld linker. -if(CMAKE_SYSTEM_NAME STREQUAL Windows AND NOT CMAKE_HOST_SYSTEM_NAME STREQUAL Windows) - set(SWIFT_ENABLE_LLD_LINKER_default TRUE) +# Which default linker to use. Prefer LLVM_USE_LINKER if it set, otherwise use +# our own defaults. This should only be possible in a unified (not stand alone) +# build environment. +if(LLVM_USE_LINKER) + set(SWIFT_USE_LINKER_default "${LLVM_USE_LINKER}") +elseif(CMAKE_SYSTEM_NAME STREQUAL Windows AND NOT CMAKE_HOST_SYSTEM_NAME STREQUAL Windows) + set(SWIFT_USE_LINKER_default "lld") +elseif(CMAKE_SYSTEM_NAME STREQUAL Darwin) + set(SWIFT_USE_LINKER_default "") else() - set(SWIFT_ENABLE_LLD_LINKER_default FALSE) + set(SWIFT_USE_LINKER_default "gold") endif() -set(SWIFT_ENABLE_LLD_LINKER ${SWIFT_ENABLE_LLD_LINKER_default} CACHE BOOL - "Enable using the lld linker when available") - -# Indicate whether Swift should attempt to use the gold linker. -# This is not used on Darwin. -if(CMAKE_SYSTEM_NAME STREQUAL Darwin OR CMAKE_SYSTEM_NAME STREQUAL Windows) - set(SWIFT_ENABLE_GOLD_LINKER_default FALSE) -else() - set(SWIFT_ENABLE_GOLD_LINKER_default TRUE) -endif() -set(SWIFT_ENABLE_GOLD_LINKER ${SWIFT_ENABLE_GOLD_LINKER_default} CACHE BOOL - "Enable using the gold linker when available") +set(SWIFT_USE_LINKER ${SWIFT_USE_LINKER_default} CACHE STRING + "Build Swift with a non-default linker") set(SWIFT_TOOLS_ENABLE_LTO OFF CACHE STRING "Build Swift tools with LTO. One must specify the form of LTO by setting this to one of: 'full', 'thin'. This diff --git a/cmake/modules/AddSwift.cmake b/cmake/modules/AddSwift.cmake index ae109498abce4..ff24b14198b4a 100644 --- a/cmake/modules/AddSwift.cmake +++ b/cmake/modules/AddSwift.cmake @@ -116,7 +116,7 @@ function(_add_host_variant_c_compile_link_flags name) if(SWIFT_HOST_VARIANT_SDK STREQUAL ANDROID) # lld can handle targeting the android build. However, if lld is not # enabled, then fallback to the linker included in the android NDK. - if(NOT SWIFT_ENABLE_LLD_LINKER) + if(NOT SWIFT_USE_LINKER STREQUAL "lld") swift_android_tools_path(${SWIFT_HOST_VARIANT_ARCH} tools_path) target_compile_options(${name} PRIVATE -B${tools_path}) endif() @@ -368,12 +368,9 @@ function(_add_host_variant_link_flags target) endif() if(NOT SWIFT_COMPILER_IS_MSVC_LIKE) - if(SWIFT_ENABLE_LLD_LINKER) + if(SWIFT_USE_LINKER) target_link_options(${target} PRIVATE - -fuse-ld=lld$<$:.exe>) - elseif(SWIFT_ENABLE_GOLD_LINKER) - target_link_options(${target} PRIVATE - -fuse-ld=gold$<$:.exe>) + -fuse-ld=${SWIFT_USE_LINKER}$<$:.exe>) endif() endif() diff --git a/cmake/modules/AddSwiftUnittests.cmake b/cmake/modules/AddSwiftUnittests.cmake index e088997741eb1..811b455bd3ed2 100644 --- a/cmake/modules/AddSwiftUnittests.cmake +++ b/cmake/modules/AddSwiftUnittests.cmake @@ -57,15 +57,9 @@ function(add_swift_unittest test_dirname) _ENABLE_EXTENDED_ALIGNED_STORAGE) endif() - find_program(LDLLD_PATH "ld.lld") - # Strangely, macOS finds lld and then can't find it when using -fuse-ld= - if(SWIFT_ENABLE_LLD_LINKER AND LDLLD_PATH AND NOT APPLE) + if(SWIFT_USE_LINKER) set_property(TARGET "${test_dirname}" APPEND_STRING PROPERTY - LINK_FLAGS " -fuse-ld=lld") - elseif(SWIFT_ENABLE_GOLD_LINKER AND - "${SWIFT_SDK_${SWIFT_HOST_VARIANT_SDK}_OBJECT_FORMAT}" STREQUAL "ELF") - set_property(TARGET "${test_dirname}" APPEND_STRING PROPERTY - LINK_FLAGS " -fuse-ld=gold") + LINK_FLAGS " -fuse-ld=${SWIFT_USE_LINKER}") endif() if(SWIFT_ANALYZE_CODE_COVERAGE) diff --git a/stdlib/cmake/modules/AddSwiftStdlib.cmake b/stdlib/cmake/modules/AddSwiftStdlib.cmake index 8a4432fb46a0a..05123d1d8e127 100644 --- a/stdlib/cmake/modules/AddSwiftStdlib.cmake +++ b/stdlib/cmake/modules/AddSwiftStdlib.cmake @@ -93,7 +93,7 @@ function(_add_target_variant_c_compile_link_flags) if("${CFLAGS_SDK}" STREQUAL "ANDROID") # lld can handle targeting the android build. However, if lld is not # enabled, then fallback to the linker included in the android NDK. - if(NOT SWIFT_ENABLE_LLD_LINKER) + if(NOT SWIFT_USE_LINKER STREQUAL "lld") swift_android_tools_path(${CFLAGS_ARCH} tools_path) list(APPEND result "-B" "${tools_path}") endif() @@ -414,20 +414,11 @@ function(_add_target_variant_link_flags) list(APPEND library_search_directories "${SWIFT_${LFLAGS_SDK}_${LFLAGS_ARCH}_ICU_I18N_LIBDIR}") endif() - if(NOT SWIFT_COMPILER_IS_MSVC_LIKE) - # FIXME: On Apple platforms, find_program needs to look for "ld64.lld" - find_program(LDLLD_PATH "ld.lld") - if((SWIFT_ENABLE_LLD_LINKER AND LDLLD_PATH AND NOT APPLE) OR - ("${LFLAGS_SDK}" STREQUAL "WINDOWS" AND - NOT "${CMAKE_SYSTEM_NAME}" STREQUAL "WINDOWS")) - list(APPEND result "-fuse-ld=lld") - elseif(SWIFT_ENABLE_GOLD_LINKER AND - "${SWIFT_SDK_${LFLAGS_SDK}_OBJECT_FORMAT}" STREQUAL "ELF") - if(CMAKE_HOST_SYSTEM_NAME STREQUAL Windows) - list(APPEND result "-fuse-ld=gold.exe") - else() - list(APPEND result "-fuse-ld=gold") - endif() + if(SWIFT_USE_LINKER AND NOT SWIFT_COMPILER_IS_MSVC_LIKE) + if(CMAKE_HOST_SYSTEM_NAME STREQUAL Windows) + list(APPEND result "-fuse-ld=${SWIFT_USE_LINKER}.exe") + else() + list(APPEND result "-fuse-ld=${SWIFT_USE_LINKER}") endif() endif() diff --git a/test/lit.site.cfg.in b/test/lit.site.cfg.in index 7df41178b775c..a3b4188eb030f 100644 --- a/test/lit.site.cfg.in +++ b/test/lit.site.cfg.in @@ -116,10 +116,10 @@ if "@SWIFT_ENABLE_SOURCEKIT_TESTS@" == "TRUE": if "@SWIFT_HAVE_LIBXML2@" == "TRUE": config.available_features.add('libxml2') -if "@SWIFT_ENABLE_LLD_LINKER@" == "TRUE": +if "@SWIFT_USE_LINKER@" == "lld": config.android_linker_name = "lld" else: - # even if SWIFT_ENABLE_GOLD_LINKER isn't set, we cannot use BFD for Android + # even if SWIFT_USE_LINKER isn't set, we cannot use BFD for Android config.android_linker_name = "gold" if '@SWIFT_INCLUDE_TOOLS@' == 'TRUE': diff --git a/validation-test/lit.site.cfg.in b/validation-test/lit.site.cfg.in index 6a00ee3ab5ef7..3666329b720d6 100644 --- a/validation-test/lit.site.cfg.in +++ b/validation-test/lit.site.cfg.in @@ -98,10 +98,10 @@ if "@CMAKE_GENERATOR@" == "Xcode": config.available_features.add("CMAKE_GENERATOR=@CMAKE_GENERATOR@") -if "@SWIFT_ENABLE_LLD_LINKER@" == "TRUE": +if "@SWIFT_USE_LINKER@" == "lld": config.android_linker_name = "lld" else: - # even if SWIFT_ENABLE_GOLD_LINKER isn't set, we cannot use BFD for Android + # even if SWIFT_USE_LINKER isn't set, we cannot use BFD for Android config.android_linker_name = "gold" # Let the main config do the real work. From 9526384afffdec8605f9110d6740f588a3d011e1 Mon Sep 17 00:00:00 2001 From: David Zarzycki Date: Thu, 28 May 2020 11:23:18 -0400 Subject: [PATCH 08/28] [IRGen] NFC: Narrow the scope of some code This code is only used on one side of a if/else branch, so let's just move it inside the `else` block. --- lib/IRGen/GenClass.cpp | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/lib/IRGen/GenClass.cpp b/lib/IRGen/GenClass.cpp index 5c1d736fb99e4..4042d8e9feb0d 100644 --- a/lib/IRGen/GenClass.cpp +++ b/lib/IRGen/GenClass.cpp @@ -813,17 +813,6 @@ llvm::Value *irgen::emitClassAllocation(IRGenFunction &IGF, SILType selfType, auto &classLayout = classTI.getClassLayout(IGF.IGM, selfType, /*forBackwardDeployment=*/false); - llvm::Value *size, *alignMask; - if (classLayout.isFixedSize()) { - size = IGF.IGM.getSize(classLayout.getSize()); - alignMask = IGF.IGM.getSize(classLayout.getAlignMask()); - } else { - std::tie(size, alignMask) - = emitClassResilientInstanceSizeAndAlignMask(IGF, - selfType.getClassOrBoundGenericClass(), - metadata); - } - llvm::Type *destType = classLayout.getType()->getPointerTo(); llvm::Value *val = nullptr; if (llvm::Value *Promoted = stackPromote(IGF, classLayout, StackAllocSize, @@ -831,6 +820,17 @@ llvm::Value *irgen::emitClassAllocation(IRGenFunction &IGF, SILType selfType, val = IGF.Builder.CreateBitCast(Promoted, IGF.IGM.RefCountedPtrTy); val = IGF.emitInitStackObjectCall(metadata, val, "reference.new"); } else { + llvm::Value *size, *alignMask; + if (classLayout.isFixedSize()) { + size = IGF.IGM.getSize(classLayout.getSize()); + alignMask = IGF.IGM.getSize(classLayout.getAlignMask()); + } else { + std::tie(size, alignMask) + = emitClassResilientInstanceSizeAndAlignMask(IGF, + selfType.getClassOrBoundGenericClass(), + metadata); + } + // Allocate the object on the heap. std::tie(size, alignMask) = appendSizeForTailAllocatedArrays(IGF, size, alignMask, TailArrays); From 3c3093db75e61115b9198a933b6c5aabebd12884 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Thu, 28 May 2020 12:32:15 -0700 Subject: [PATCH 09/28] [AutoDiff] Clean up VJP basic block utilities. Add a common helper function `VJPEmitter::createTrampolineBasicBlock`. Change `VJPEmitter::buildPullbackValueStructValue` to take an original basic block instead of a terminator instruction. --- .../SILOptimizer/Differentiation/VJPEmitter.h | 18 ++- .../Differentiation/VJPEmitter.cpp | 113 +++++++----------- 2 files changed, 57 insertions(+), 74 deletions(-) diff --git a/include/swift/SILOptimizer/Differentiation/VJPEmitter.h b/include/swift/SILOptimizer/Differentiation/VJPEmitter.h index d2f8f7fabb996..a0ec673b4d73b 100644 --- a/include/swift/SILOptimizer/Differentiation/VJPEmitter.h +++ b/include/swift/SILOptimizer/Differentiation/VJPEmitter.h @@ -117,9 +117,21 @@ class VJPEmitter final /// Get the lowered SIL type of the given nominal type declaration. SILType getNominalDeclLoweredType(NominalTypeDecl *nominal); - /// Build a pullback struct value for the original block corresponding to the - /// given terminator. - StructInst *buildPullbackValueStructValue(TermInst *termInst); + // Creates a trampoline block for given original terminator instruction, the + // pullback struct value for its parent block, and a successor basic block. + // + // The trampoline block has the same arguments as and branches to the remapped + // successor block, but drops the last predecessor enum argument. + // + // Used for cloning branching terminator instructions with specific + // requirements on successor block arguments, where an additional predecessor + // enum argument is not acceptable. + SILBasicBlock *createTrampolineBasicBlock(TermInst *termInst, + StructInst *pbStructVal, + SILBasicBlock *succBB); + + /// Build a pullback struct value for the given original block. + StructInst *buildPullbackValueStructValue(SILBasicBlock *bb); /// Build a predecessor enum instance using the given builder for the given /// original predecessor/successor blocks and pullback struct value. diff --git a/lib/SILOptimizer/Differentiation/VJPEmitter.cpp b/lib/SILOptimizer/Differentiation/VJPEmitter.cpp index 082d6884f39b1..8486b7b2bec14 100644 --- a/lib/SILOptimizer/Differentiation/VJPEmitter.cpp +++ b/lib/SILOptimizer/Differentiation/VJPEmitter.cpp @@ -265,6 +265,31 @@ SILBasicBlock *VJPEmitter::remapBasicBlock(SILBasicBlock *bb) { return vjpBB; } +SILBasicBlock *VJPEmitter::createTrampolineBasicBlock(TermInst *termInst, + StructInst *pbStructVal, + SILBasicBlock *succBB) { + assert(llvm::find(termInst->getSuccessorBlocks(), succBB) != + termInst->getSuccessorBlocks().end() && + "Basic block is not a successor of terminator instruction"); + // Create the trampoline block. + auto *vjpSuccBB = getOpBasicBlock(succBB); + auto *trampolineBB = vjp->createBasicBlockBefore(vjpSuccBB); + for (auto *arg : vjpSuccBB->getArguments().drop_back()) + trampolineBB->createPhiArgument(arg->getType(), arg->getOwnershipKind()); + // In the trampoline block, build predecessor enum value for VJP successor + // block and branch to it. + SILBuilder trampolineBuilder(trampolineBB); + auto *origBB = termInst->getParent(); + auto *succEnumVal = + buildPredecessorEnumValue(trampolineBuilder, origBB, succBB, pbStructVal); + SmallVector forwardedArguments( + trampolineBB->getArguments().begin(), trampolineBB->getArguments().end()); + forwardedArguments.push_back(succEnumVal); + trampolineBuilder.createBranch(termInst->getLoc(), vjpSuccBB, + forwardedArguments); + return trampolineBB; +} + void VJPEmitter::visit(SILInstruction *inst) { if (errorOccurred) return; @@ -290,10 +315,9 @@ SILType VJPEmitter::getNominalDeclLoweredType(NominalTypeDecl *nominal) { return getLoweredType(nominalType); } -StructInst *VJPEmitter::buildPullbackValueStructValue(TermInst *termInst) { - assert(termInst->getFunction() == original); - auto loc = termInst->getFunction()->getLocation(); - auto *origBB = termInst->getParent(); +StructInst *VJPEmitter::buildPullbackValueStructValue(SILBasicBlock *origBB) { + assert(origBB->getParent() == original); + auto loc = origBB->getParent()->getLocation(); auto *vjpBB = BBMap[origBB]; auto *pbStruct = pullbackInfo.getLinearMapStruct(origBB); auto structLoweredTy = getNominalDeclLoweredType(pbStruct); @@ -333,9 +357,11 @@ EnumInst *VJPEmitter::buildPredecessorEnumValue(SILBuilder &builder, void VJPEmitter::visitReturnInst(ReturnInst *ri) { auto loc = ri->getOperand().getLoc(); - auto *origExit = ri->getParent(); auto &builder = getBuilder(); - auto *pbStructVal = buildPullbackValueStructValue(ri); + + // Build pullback struct value for original block. + auto *origExit = ri->getParent(); + auto *pbStructVal = buildPullbackValueStructValue(origExit); // Get the value in the VJP corresponding to the original result. auto *origRetInst = cast(origExit->getTerminator()); @@ -390,7 +416,7 @@ void VJPEmitter::visitBranchInst(BranchInst *bi) { // Build pullback struct value for original block. // Build predecessor enum value for destination block. auto *origBB = bi->getParent(); - auto *pbStructVal = buildPullbackValueStructValue(bi); + auto *pbStructVal = buildPullbackValueStructValue(origBB); auto *enumVal = buildPredecessorEnumValue(getBuilder(), origBB, bi->getDestBB(), pbStructVal); @@ -407,85 +433,30 @@ void VJPEmitter::visitBranchInst(BranchInst *bi) { void VJPEmitter::visitCondBranchInst(CondBranchInst *cbi) { // Build pullback struct value for original block. - // Build predecessor enum values for true/false blocks. - auto *origBB = cbi->getParent(); - auto *pbStructVal = buildPullbackValueStructValue(cbi); - - // Creates a trampoline block for given original successor block. The - // trampoline block has the same arguments as the VJP successor block but - // drops the last predecessor enum argument. The generated `switch_enum` - // instruction branches to the trampoline block, and the trampoline block - // constructs a predecessor enum value and branches to the VJP successor - // block. - auto createTrampolineBasicBlock = - [&](SILBasicBlock *origSuccBB) -> SILBasicBlock * { - auto *vjpSuccBB = getOpBasicBlock(origSuccBB); - // Create the trampoline block. - auto *trampolineBB = vjp->createBasicBlockBefore(vjpSuccBB); - for (auto *arg : vjpSuccBB->getArguments().drop_back()) - trampolineBB->createPhiArgument(arg->getType(), arg->getOwnershipKind()); - // Build predecessor enum value for successor block and branch to it. - SILBuilder trampolineBuilder(trampolineBB); - auto *succEnumVal = buildPredecessorEnumValue(trampolineBuilder, origBB, - origSuccBB, pbStructVal); - SmallVector forwardedArguments( - trampolineBB->getArguments().begin(), - trampolineBB->getArguments().end()); - forwardedArguments.push_back(succEnumVal); - trampolineBuilder.createBranch(cbi->getLoc(), vjpSuccBB, - forwardedArguments); - return trampolineBB; - }; - + auto *pbStructVal = buildPullbackValueStructValue(cbi->getParent()); // Create a new `cond_br` instruction. - getBuilder().createCondBranch(cbi->getLoc(), getOpValue(cbi->getCondition()), - createTrampolineBasicBlock(cbi->getTrueBB()), - createTrampolineBasicBlock(cbi->getFalseBB())); + getBuilder().createCondBranch( + cbi->getLoc(), getOpValue(cbi->getCondition()), + createTrampolineBasicBlock(cbi, pbStructVal, cbi->getTrueBB()), + createTrampolineBasicBlock(cbi, pbStructVal, cbi->getFalseBB())); } void VJPEmitter::visitSwitchEnumInstBase(SwitchEnumInstBase *sei) { // Build pullback struct value for original block. - auto *origBB = sei->getParent(); - auto *pbStructVal = buildPullbackValueStructValue(sei); - - // Creates a trampoline block for given original successor block. The - // trampoline block has the same arguments as the VJP successor block but - // drops the last predecessor enum argument. The generated `switch_enum` - // instruction branches to the trampoline block, and the trampoline block - // constructs a predecessor enum value and branches to the VJP successor - // block. - auto createTrampolineBasicBlock = - [&](SILBasicBlock *origSuccBB) -> SILBasicBlock * { - auto *vjpSuccBB = getOpBasicBlock(origSuccBB); - // Create the trampoline block. - auto *trampolineBB = vjp->createBasicBlockBefore(vjpSuccBB); - for (auto *destArg : vjpSuccBB->getArguments().drop_back()) - trampolineBB->createPhiArgument(destArg->getType(), - destArg->getOwnershipKind()); - // Build predecessor enum value for successor block and branch to it. - SILBuilder trampolineBuilder(trampolineBB); - auto *succEnumVal = buildPredecessorEnumValue(trampolineBuilder, origBB, - origSuccBB, pbStructVal); - SmallVector forwardedArguments( - trampolineBB->getArguments().begin(), - trampolineBB->getArguments().end()); - forwardedArguments.push_back(succEnumVal); - trampolineBuilder.createBranch(sei->getLoc(), vjpSuccBB, - forwardedArguments); - return trampolineBB; - }; + auto *pbStructVal = buildPullbackValueStructValue(sei->getParent()); // Create trampoline successor basic blocks. SmallVector, 4> caseBBs; for (unsigned i : range(sei->getNumCases())) { auto caseBB = sei->getCase(i); - auto *trampolineBB = createTrampolineBasicBlock(caseBB.second); + auto *trampolineBB = + createTrampolineBasicBlock(sei, pbStructVal, caseBB.second); caseBBs.push_back({caseBB.first, trampolineBB}); } // Create trampoline default basic block. SILBasicBlock *newDefaultBB = nullptr; if (auto *defaultBB = sei->getDefaultBBOrNull().getPtrOrNull()) - newDefaultBB = createTrampolineBasicBlock(defaultBB); + newDefaultBB = createTrampolineBasicBlock(sei, pbStructVal, defaultBB); // Create a new `switch_enum` instruction. switch (sei->getKind()) { From 24de636822b7680db442b6fb161356d2b0034f6a Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Thu, 28 May 2020 12:41:55 -0700 Subject: [PATCH 10/28] [AutoDiff] Re-enable control_flow.swift test. This test was disabled in SR-12741 due to iphonesimulator-i386 failures. Enabling the test on other platforms is important to prevent regressions. --- test/AutoDiff/validation-test/control_flow.swift | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/AutoDiff/validation-test/control_flow.swift b/test/AutoDiff/validation-test/control_flow.swift index d9a45e5b143ff..5a2315022e5f3 100644 --- a/test/AutoDiff/validation-test/control_flow.swift +++ b/test/AutoDiff/validation-test/control_flow.swift @@ -1,7 +1,9 @@ // RUN: %target-run-simple-swift // REQUIRES: executable_test -// REQUIRES: SR12741 +// FIXME(SR-12741): Enable test for all platforms after debugging +// iphonesimulator-i386-specific failures. +// REQUIRES: CPU=x86_64 import _Differentiation import StdlibUnittest From c050a26aebd1d9677e1315cf87ef9f595e9a3bd5 Mon Sep 17 00:00:00 2001 From: Saleem Abdulrasool Date: Thu, 28 May 2020 12:16:18 -0700 Subject: [PATCH 11/28] runtime: further isolate runtime from LLVMSupport This cleans up some more `llvm::` leakage in the runtime when built into a static library. With this change we are down to 3 leaking symbols in the static library related to a missed ADT (`StringSwitch`). --- include/swift/Basic/LLVM.h | 4 ++-- include/swift/Demangling/TypeDecoder.h | 14 +++++------ stdlib/public/Reflection/TypeRef.cpp | 2 +- stdlib/public/runtime/Demangle.cpp | 8 +++---- stdlib/public/runtime/MetadataLookup.cpp | 24 +++++++++---------- stdlib/public/runtime/ProtocolConformance.cpp | 2 +- 6 files changed, 27 insertions(+), 27 deletions(-) diff --git a/include/swift/Basic/LLVM.h b/include/swift/Basic/LLVM.h index c88b322b0ecd7..df5db366409dd 100644 --- a/include/swift/Basic/LLVM.h +++ b/include/swift/Basic/LLVM.h @@ -36,8 +36,8 @@ namespace llvm { template class SmallPtrSet; #if !defined(swiftCore_EXPORTS) template class SmallVectorImpl; -#endif template class SmallVector; +#endif template class SmallString; template class SmallSetVector; #if !defined(swiftCore_EXPORTS) @@ -86,8 +86,8 @@ namespace swift { using llvm::SmallPtrSetImpl; using llvm::SmallSetVector; using llvm::SmallString; - using llvm::SmallVector; #if !defined(swiftCore_EXPORTS) + using llvm::SmallVector; using llvm::SmallVectorImpl; #endif using llvm::StringLiteral; diff --git a/include/swift/Demangling/TypeDecoder.h b/include/swift/Demangling/TypeDecoder.h index 8bfc73f4e932d..017fd096e0f74 100644 --- a/include/swift/Demangling/TypeDecoder.h +++ b/include/swift/Demangling/TypeDecoder.h @@ -361,7 +361,7 @@ class TypeDecoder { if (Node->getNumChildren() < 2) return BuiltType(); - SmallVector args; + llvm::SmallVector args; const auto &genericArgs = Node->getChild(1); if (genericArgs->getKind() != NodeKind::TypeList) @@ -474,7 +474,7 @@ class TypeDecoder { return BuiltType(); // Find the protocol list. - SmallVector Protocols; + llvm::SmallVector Protocols; auto TypeList = Node->getChild(0); if (TypeList->getKind() == NodeKind::ProtocolList && TypeList->getNumChildren() >= 1) { @@ -576,7 +576,7 @@ class TypeDecoder { return BuiltType(); bool hasParamFlags = false; - SmallVector, 8> parameters; + llvm::SmallVector, 8> parameters; if (!decodeMangledFunctionInputType(Node->getChild(isThrow ? 1 : 0), parameters, hasParamFlags)) return BuiltType(); @@ -598,9 +598,9 @@ class TypeDecoder { } case NodeKind::ImplFunctionType: { auto calleeConvention = ImplParameterConvention::Direct_Unowned; - SmallVector, 8> parameters; - SmallVector, 8> results; - SmallVector, 8> errorResults; + llvm::SmallVector, 8> parameters; + llvm::SmallVector, 8> results; + llvm::SmallVector, 8> errorResults; ImplFunctionTypeFlags flags; for (unsigned i = 0; i < Node->getNumChildren(); i++) { @@ -684,7 +684,7 @@ class TypeDecoder { return decodeMangledType(Node->getChild(0)); case NodeKind::Tuple: { - SmallVector elements; + llvm::SmallVector elements; std::string labels; bool variadic = false; for (auto &element : *Node) { diff --git a/stdlib/public/Reflection/TypeRef.cpp b/stdlib/public/Reflection/TypeRef.cpp index b27624338c839..e775c72ac352d 100644 --- a/stdlib/public/Reflection/TypeRef.cpp +++ b/stdlib/public/Reflection/TypeRef.cpp @@ -491,7 +491,7 @@ class DemanglingForTypeRef break; } - SmallVector, 8> inputs; + llvm::SmallVector, 8> inputs; for (const auto ¶m : F->getParameters()) { auto flags = param.getFlags(); auto input = visit(param.getType()); diff --git a/stdlib/public/runtime/Demangle.cpp b/stdlib/public/runtime/Demangle.cpp index 0e3d9987173de..0e820890d5882 100644 --- a/stdlib/public/runtime/Demangle.cpp +++ b/stdlib/public/runtime/Demangle.cpp @@ -34,7 +34,7 @@ swift::_buildDemanglingForContext(const ContextDescriptor *context, NodePointer node = nullptr; // Walk up the context tree. - SmallVector descriptorPath; + llvm::SmallVector descriptorPath; { const ContextDescriptor *parent = context; while (parent) { @@ -285,11 +285,11 @@ _buildDemanglingForNominalType(const Metadata *type, Demangle::Demangler &Dem) { // Gather the complete set of generic arguments that must be written to // form this type. - SmallVector allGenericArgs; + llvm::SmallVector allGenericArgs; gatherWrittenGenericArgs(type, description, allGenericArgs, Dem); // Demangle the generic arguments. - SmallVector demangledGenerics; + llvm::SmallVector demangledGenerics; for (auto genericArg : allGenericArgs) { // When there is no generic argument, put in a placeholder. if (!genericArg) { @@ -470,7 +470,7 @@ swift::_swift_buildDemanglingForMetadata(const Metadata *type, break; } - SmallVector, 8> inputs; + llvm::SmallVector, 8> inputs; for (unsigned i = 0, e = func->getNumParameters(); i < e; ++i) { auto param = func->getParameter(i); auto flags = func->getParameterFlags(i); diff --git a/stdlib/public/runtime/MetadataLookup.cpp b/stdlib/public/runtime/MetadataLookup.cpp index 3aeb5581251f1..85440b3050593 100644 --- a/stdlib/public/runtime/MetadataLookup.cpp +++ b/stdlib/public/runtime/MetadataLookup.cpp @@ -988,8 +988,8 @@ _gatherGenericParameters(const ContextDescriptor *context, // requirements and fill in the generic arguments vector. if (!genericParamCounts.empty()) { // Compute the set of generic arguments "as written". - SmallVector allGenericArgs; - + llvm::SmallVector allGenericArgs; + // If we have a parent, gather it's generic arguments "as written". if (parent) { gatherWrittenGenericArgs(parent, parent->getTypeContextDescriptor(), @@ -1183,16 +1183,16 @@ class DecodedMetadataBuilder { if (!descriptor) return BuiltType(); auto outerContext = descriptor->Parent.get(); - - SmallVector allGenericArgs; + + llvm::SmallVector allGenericArgs; for (auto argSet : genericArgs) { allGenericArgs.append(argSet.begin(), argSet.end()); } // Gather the generic parameters we need to parameterize the opaque decl. - SmallVector genericParamCounts; - SmallVector allGenericArgsVec; - + llvm::SmallVector genericParamCounts; + llvm::SmallVector allGenericArgsVec; + if (!_gatherGenericParameters(outerContext, allGenericArgs, BuiltType(), /* no parent */ @@ -1291,8 +1291,8 @@ class DecodedMetadataBuilder { // Figure out the various levels of generic parameters we have in // this type. - SmallVector genericParamCounts; - SmallVector allGenericArgsVec; + llvm::SmallVector genericParamCounts; + llvm::SmallVector allGenericArgsVec; if (!_gatherGenericParameters(typeDecl, genericArgs, @@ -1366,8 +1366,8 @@ class DecodedMetadataBuilder { BuiltType createFunctionType(llvm::ArrayRef> params, BuiltType result, FunctionTypeFlags flags) const { - SmallVector paramTypes; - SmallVector paramFlags; + llvm::SmallVector paramTypes; + llvm::SmallVector paramFlags; // Fill in the parameters. paramTypes.reserve(params.size()); @@ -2032,7 +2032,7 @@ void swift::gatherWrittenGenericArgs( // canonicalized away. Use same-type requirements to reconstitute them. // Retrieve the mapping information needed for depth/index -> flat index. - SmallVector genericParamCounts; + llvm::SmallVector genericParamCounts; (void)_gatherGenericParameterCounts(description, genericParamCounts, BorrowFrom); diff --git a/stdlib/public/runtime/ProtocolConformance.cpp b/stdlib/public/runtime/ProtocolConformance.cpp index 5d951c2c4c876..bef53aa1ea684 100644 --- a/stdlib/public/runtime/ProtocolConformance.cpp +++ b/stdlib/public/runtime/ProtocolConformance.cpp @@ -162,7 +162,7 @@ template<> const WitnessTable * ProtocolConformanceDescriptor::getWitnessTable(const Metadata *type) const { // If needed, check the conditional requirements. - SmallVector conditionalArgs; + llvm::SmallVector conditionalArgs; if (hasConditionalRequirements()) { SubstGenericParametersFromMetadata substitutions(type); bool failed = From d5d076db6a2381fc9f205b978b0722457f794c34 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Thu, 28 May 2020 12:44:36 -0700 Subject: [PATCH 12/28] [AutoDiff] Support differentiation of branching cast instructions. Support differentiation of `is` and `as?` operators. These operators lower to branching cast SIL instructions, requiring control flow differentiation support. Resolves SR-12898. --- .../SILOptimizer/Differentiation/VJPEmitter.h | 6 + .../DifferentiableActivityAnalysis.cpp | 22 +++- .../Differentiation/VJPEmitter.cpp | 41 +++++++ .../Mandatory/Differentiation.cpp | 8 +- .../SILOptimizer/activity_analysis.swift | 110 ++++++++++++++++++ .../validation-test/control_flow.swift | 21 ++++ 6 files changed, 201 insertions(+), 7 deletions(-) diff --git a/include/swift/SILOptimizer/Differentiation/VJPEmitter.h b/include/swift/SILOptimizer/Differentiation/VJPEmitter.h index a0ec673b4d73b..475ac7b3f7272 100644 --- a/include/swift/SILOptimizer/Differentiation/VJPEmitter.h +++ b/include/swift/SILOptimizer/Differentiation/VJPEmitter.h @@ -153,6 +153,12 @@ class VJPEmitter final void visitSwitchEnumAddrInst(SwitchEnumAddrInst *seai); + void visitCheckedCastBranchInst(CheckedCastBranchInst *ccbi); + + void visitCheckedCastValueBranchInst(CheckedCastValueBranchInst *ccvbi); + + void visitCheckedCastAddrBranchInst(CheckedCastAddrBranchInst *ccabi); + // If an `apply` has active results or active inout arguments, replace it // with an `apply` of its VJP. void visitApplyInst(ApplyInst *ai); diff --git a/lib/SILOptimizer/Analysis/DifferentiableActivityAnalysis.cpp b/lib/SILOptimizer/Analysis/DifferentiableActivityAnalysis.cpp index 8358319ef16b2..c903dbeb6435c 100644 --- a/lib/SILOptimizer/Analysis/DifferentiableActivityAnalysis.cpp +++ b/lib/SILOptimizer/Analysis/DifferentiableActivityAnalysis.cpp @@ -193,12 +193,26 @@ void DifferentiableActivityInfo::propagateVaried( if (auto *destBBArg = cbi->getArgForOperand(operand)) setVariedAndPropagateToUsers(destBBArg, i); } - // Handle `switch_enum`. - else if (auto *sei = dyn_cast(inst)) { - if (isVaried(sei->getOperand(), i)) - for (auto *succBB : sei->getSuccessorBlocks()) + // Handle `checked_cast_addr_br`. + // Propagate variedness from source operand to destination operand, in + // addition to all successor block arguments. + else if (auto *ccabi = dyn_cast(inst)) { + if (isVaried(ccabi->getSrc(), i)) { + setVariedAndPropagateToUsers(ccabi->getDest(), i); + for (auto *succBB : ccabi->getSuccessorBlocks()) for (auto *arg : succBB->getArguments()) setVariedAndPropagateToUsers(arg, i); + } + } + // Handle all other terminators: if any operand is active, propagate + // variedness to all successor block arguments. This logic may be incorrect + // for some terminator instructions, so special cases must be defined above. + else if (auto *termInst = dyn_cast(inst)) { + for (auto &op : termInst->getAllOperands()) + if (isVaried(op.get(), i)) + for (auto *succBB : termInst->getSuccessorBlocks()) + for (auto *arg : succBB->getArguments()) + setVariedAndPropagateToUsers(arg, i); } // Handle everything else. else { diff --git a/lib/SILOptimizer/Differentiation/VJPEmitter.cpp b/lib/SILOptimizer/Differentiation/VJPEmitter.cpp index 8486b7b2bec14..ec6b82cc7fb17 100644 --- a/lib/SILOptimizer/Differentiation/VJPEmitter.cpp +++ b/lib/SILOptimizer/Differentiation/VJPEmitter.cpp @@ -481,6 +481,47 @@ void VJPEmitter::visitSwitchEnumAddrInst(SwitchEnumAddrInst *seai) { visitSwitchEnumInstBase(seai); } +void VJPEmitter::visitCheckedCastBranchInst(CheckedCastBranchInst *ccbi) { + // Build pullback struct value for original block. + auto *pbStructVal = buildPullbackValueStructValue(ccbi->getParent()); + // Create a new `checked_cast_branch` instruction. + getBuilder().createCheckedCastBranch( + ccbi->getLoc(), ccbi->isExact(), getOpValue(ccbi->getOperand()), + getOpType(ccbi->getTargetLoweredType()), + getOpASTType(ccbi->getTargetFormalType()), + createTrampolineBasicBlock(ccbi, pbStructVal, ccbi->getSuccessBB()), + createTrampolineBasicBlock(ccbi, pbStructVal, ccbi->getFailureBB()), + ccbi->getTrueBBCount(), ccbi->getFalseBBCount()); +} + +void VJPEmitter::visitCheckedCastValueBranchInst( + CheckedCastValueBranchInst *ccvbi) { + // Build pullback struct value for original block. + auto *pbStructVal = buildPullbackValueStructValue(ccvbi->getParent()); + // Create a new `checked_cast_value_branch` instruction. + getBuilder().createCheckedCastValueBranch( + ccvbi->getLoc(), getOpValue(ccvbi->getOperand()), + getOpASTType(ccvbi->getSourceFormalType()), + getOpType(ccvbi->getTargetLoweredType()), + getOpASTType(ccvbi->getTargetFormalType()), + createTrampolineBasicBlock(ccvbi, pbStructVal, ccvbi->getSuccessBB()), + createTrampolineBasicBlock(ccvbi, pbStructVal, ccvbi->getFailureBB())); +} + +void VJPEmitter::visitCheckedCastAddrBranchInst( + CheckedCastAddrBranchInst *ccabi) { + // Build pullback struct value for original block. + auto *pbStructVal = buildPullbackValueStructValue(ccabi->getParent()); + // Create a new `checked_cast_addr_branch` instruction. + getBuilder().createCheckedCastAddrBranch( + ccabi->getLoc(), ccabi->getConsumptionKind(), getOpValue(ccabi->getSrc()), + getOpASTType(ccabi->getSourceFormalType()), getOpValue(ccabi->getDest()), + getOpASTType(ccabi->getTargetFormalType()), + createTrampolineBasicBlock(ccabi, pbStructVal, ccabi->getSuccessBB()), + createTrampolineBasicBlock(ccabi, pbStructVal, ccabi->getFailureBB()), + ccabi->getTrueBBCount(), ccabi->getFalseBBCount()); +} + void VJPEmitter::visitApplyInst(ApplyInst *ai) { // If callee should not be differentiated, do standard cloning. if (!pullbackInfo.shouldDifferentiateApplySite(ai)) { diff --git a/lib/SILOptimizer/Mandatory/Differentiation.cpp b/lib/SILOptimizer/Mandatory/Differentiation.cpp index 89588cd7b110d..3914cc3380e2f 100644 --- a/lib/SILOptimizer/Mandatory/Differentiation.cpp +++ b/lib/SILOptimizer/Mandatory/Differentiation.cpp @@ -152,10 +152,12 @@ static bool diagnoseUnsupportedControlFlow(ADContext &context, // Diagnose unsupported branching terminators. for (auto &bb : *original) { auto *term = bb.getTerminator(); - // Supported terminators are: `br`, `cond_br`, `switch_enum`, - // `switch_enum_addr`. + // Check supported branching terminators. if (isa(term) || isa(term) || - isa(term) || isa(term)) + isa(term) || isa(term) || + isa(term) || + isa(term) || + isa(term)) continue; // If terminator is an unsupported branching terminator, emit an error. if (term->isBranch()) { diff --git a/test/AutoDiff/SILOptimizer/activity_analysis.swift b/test/AutoDiff/SILOptimizer/activity_analysis.swift index 6c1fc31f23cf4..b4360484123a5 100644 --- a/test/AutoDiff/SILOptimizer/activity_analysis.swift +++ b/test/AutoDiff/SILOptimizer/activity_analysis.swift @@ -122,6 +122,116 @@ func TF_954(_ x: Float) -> Float { // CHECK: [ACTIVE] %40 = begin_access [read] [static] %2 : $*Float // CHECK: [ACTIVE] %41 = load [trivial] %40 : $*Float +//===----------------------------------------------------------------------===// +// Branching cast instructions +//===----------------------------------------------------------------------===// + +@differentiable +func checked_cast_branch(_ x: Float) -> Float { + // expected-warning @+1 {{'is' test is always true}} + if Int.self is Any.Type { + return x + x + } + return x * x +} + +// CHECK-LABEL: [AD] Activity info for ${{.*}}checked_cast_branch{{.*}} at (source=0 parameters=(0)) +// CHECK: bb0: +// CHECK: [ACTIVE] %0 = argument of bb0 : $Float +// CHECK: [NONE] %2 = metatype $@thin Int.Type +// CHECK: [NONE] %3 = metatype $@thick Int.Type +// CHECK: bb1: +// CHECK: [NONE] %5 = argument of bb1 : $@thick Any.Type +// CHECK: [NONE] %6 = integer_literal $Builtin.Int1, -1 +// CHECK: bb2: +// CHECK: [NONE] %8 = argument of bb2 : $@thick Int.Type +// CHECK: [NONE] %9 = integer_literal $Builtin.Int1, 0 +// CHECK: bb3: +// CHECK: [NONE] %11 = argument of bb3 : $Builtin.Int1 +// CHECK: [NONE] %12 = metatype $@thin Bool.Type +// CHECK: [NONE] // function_ref Bool.init(_builtinBooleanLiteral:) +// CHECK: [NONE] %14 = apply %13(%11, %12) : $@convention(method) (Builtin.Int1, @thin Bool.Type) -> Bool +// CHECK: [NONE] %15 = struct_extract %14 : $Bool, #Bool._value +// CHECK: bb4: +// CHECK: [USEFUL] %17 = metatype $@thin Float.Type +// CHECK: [NONE] // function_ref static Float.+ infix(_:_:) +// CHECK: [ACTIVE] %19 = apply %18(%0, %0, %17) : $@convention(method) (Float, Float, @thin Float.Type) -> Float +// CHECK: bb5: +// CHECK: [USEFUL] %21 = metatype $@thin Float.Type +// CHECK: [NONE] // function_ref static Float.* infix(_:_:) +// CHECK: [ACTIVE] %23 = apply %22(%0, %0, %21) : $@convention(method) (Float, Float, @thin Float.Type) -> Float + +// CHECK-LABEL: sil hidden [ossa] @${{.*}}checked_cast_branch{{.*}} : $@convention(thin) (Float) -> Float { +// CHECK: checked_cast_br %3 : $@thick Int.Type to Any.Type, bb1, bb2 +// CHECK: } + +@differentiable +func checked_cast_addr_nonactive_result(_ x: T) -> T { + if let _ = x as? Float { + // Do nothing with `y: Float?` value. + } + return x +} + +// CHECK-LABEL: [AD] Activity info for ${{.*}}checked_cast_addr_nonactive_result{{.*}} at (source=0 parameters=(0)) +// CHECK: bb0: +// CHECK: [ACTIVE] %0 = argument of bb0 : $*T +// CHECK: [ACTIVE] %1 = argument of bb0 : $*T +// CHECK: [VARIED] %3 = alloc_stack $T +// CHECK: [VARIED] %5 = alloc_stack $Float +// CHECK: bb1: +// CHECK: [VARIED] %7 = load [trivial] %5 : $*Float +// CHECK: [VARIED] %8 = enum $Optional, #Optional.some!enumelt, %7 : $Float +// CHECK: bb2: +// CHECK: [NONE] %11 = enum $Optional, #Optional.none!enumelt +// CHECK: bb3: +// CHECK: [VARIED] %14 = argument of bb3 : $Optional +// CHECK: bb4: +// CHECK: bb5: +// CHECK: [VARIED] %18 = argument of bb5 : $Float +// CHECK: bb6: +// CHECK: [NONE] %22 = tuple () + +// CHECK-LABEL: sil hidden [ossa] @${{.*}}checked_cast_addr_nonactive_result{{.*}} : $@convention(thin) (@in_guaranteed T) -> @out T { +// CHECK: checked_cast_addr_br take_always T in %3 : $*T to Float in %5 : $*Float, bb1, bb2 +// CHECK: } + +// expected-error @+1 {{function is not differentiable}} +@differentiable +// expected-note @+1 {{when differentiating this function definition}} +func checked_cast_addr_active_result(x: T) -> T { + // expected-note @+1 {{differentiating enum values is not yet supported}} + if let y = x as? Float { + // Use `y: Float?` value in an active way. + return y as! T + } + return x +} + +// CHECK-LABEL: [AD] Activity info for ${{.*}}checked_cast_addr_active_result{{.*}} at (source=0 parameters=(0)) +// CHECK: bb0: +// CHECK: [ACTIVE] %0 = argument of bb0 : $*T +// CHECK: [ACTIVE] %1 = argument of bb0 : $*T +// CHECK: [ACTIVE] %3 = alloc_stack $T +// CHECK: [ACTIVE] %5 = alloc_stack $Float +// CHECK: bb1: +// CHECK: [ACTIVE] %7 = load [trivial] %5 : $*Float +// CHECK: [ACTIVE] %8 = enum $Optional, #Optional.some!enumelt, %7 : $Float +// CHECK: bb2: +// CHECK: [USEFUL] %11 = enum $Optional, #Optional.none!enumelt +// CHECK: bb3: +// CHECK: [ACTIVE] %14 = argument of bb3 : $Optional +// CHECK: bb4: +// CHECK: [ACTIVE] %16 = argument of bb4 : $Float +// CHECK: [ACTIVE] %19 = alloc_stack $Float +// CHECK: bb5: +// CHECK: bb6: +// CHECK: [NONE] %27 = tuple () + +// CHECK-LABEL: sil hidden [ossa] @${{.*}}checked_cast_addr_active_result{{.*}} : $@convention(thin) (@in_guaranteed T) -> @out T { +// CHECK: checked_cast_addr_br take_always T in %3 : $*T to Float in %5 : $*Float, bb1, bb2 +// CHECK: } + //===----------------------------------------------------------------------===// // Array literal differentiation //===----------------------------------------------------------------------===// diff --git a/test/AutoDiff/validation-test/control_flow.swift b/test/AutoDiff/validation-test/control_flow.swift index 5a2315022e5f3..d2042a1135216 100644 --- a/test/AutoDiff/validation-test/control_flow.swift +++ b/test/AutoDiff/validation-test/control_flow.swift @@ -715,4 +715,25 @@ ControlFlowTests.test("Loops") { expectEqual((24, 28), valueWithGradient(at: 2, in: { x in nested_loop2(x, count: 4) })) } +ControlFlowTests.test("BranchingCastInstructions") { + // checked_cast_br + func typeCheckOperator(_ x: Float, _ metatype: T.Type) -> Float { + if metatype is Int.Type { + return x + x + } + return x * x + } + expectEqual((6, 2), valueWithGradient(at: 3, in: { typeCheckOperator($0, Int.self) })) + expectEqual((9, 6), valueWithGradient(at: 3, in: { typeCheckOperator($0, Float.self) })) + + // checked_cast_addr_br + func conditionalCast(_ x: T) -> T { + if let _ = x as? Float { + // Do nothing with `y: Float?` value. + } + return x + } + expectEqual((3, 1), valueWithGradient(at: Float(3), in: conditionalCast)) +} + runAllTests() From 56c77df926f73f3890c821aca8667c93599d7ade Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alexis=20Laferrie=CC=80re?= Date: Wed, 27 May 2020 14:30:34 -0700 Subject: [PATCH 13/28] [Serialization] Skip SPI documentation in swiftdoc files Hide comments from SPI decls in all swiftdoc files. This applies the same restrictions as private declarations. This is a temporary solution, a long term fix is to emit both a public and an internal swiftdoc file. rdar://63729195 --- lib/Serialization/SerializeDoc.cpp | 7 ++++++- test/Serialization/comments-hidden.swift | 16 ++++++++++++++-- 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/lib/Serialization/SerializeDoc.cpp b/lib/Serialization/SerializeDoc.cpp index c16b70012b5a4..6eaf7ab4cd443 100644 --- a/lib/Serialization/SerializeDoc.cpp +++ b/lib/Serialization/SerializeDoc.cpp @@ -344,6 +344,11 @@ static bool shouldIncludeDecl(Decl *D, bool ExcludeDoubleUnderscore) { if (VD->getEffectiveAccess() < swift::AccessLevel::Public) return false; } + + // Skip SPI decls. + if (D->isSPI()) + return false; + if (auto *ED = dyn_cast(D)) { return shouldIncludeDecl(ED->getExtendedNominal(), ExcludeDoubleUnderscore); } @@ -745,7 +750,7 @@ Result.X.Column = Locs->X.Column; }; // .swiftdoc doesn't include comments for double underscored symbols, but // for .swiftsourceinfo, having the source location for these symbols isn't - // a concern becuase these symbols are in .swiftinterface anyway. + // a concern because these symbols are in .swiftinterface anyway. if (!shouldIncludeDecl(D, /*ExcludeDoubleUnderscore*/false)) return false; if (!shouldSerializeSourceLoc(D)) diff --git a/test/Serialization/comments-hidden.swift b/test/Serialization/comments-hidden.swift index 9ad0e52fd170c..8eb92e04b4b8a 100644 --- a/test/Serialization/comments-hidden.swift +++ b/test/Serialization/comments-hidden.swift @@ -43,6 +43,8 @@ public class PublicClass { public init(label __name: String) {} /// Public Filter Subscript Documentation NotForNormal NotForTesting public subscript(label __name: String) -> Int { return 0 } + /// SPI Function Documentation NotForNormal NotForTesting + @_spi(SPI) public func f_spi() { } } public extension PublicClass { @@ -64,6 +66,16 @@ private class PrivateClass { private func f_private() { } } +/// SPI Documentation NotForNormal NotForTesting +@_spi(SPI) public class SPIClass { + /// SPI Function Documentation NotForNormal NotForTesting + public func f_spi() { } +} + +/// SPI Extension Documentation NotForNormal NotForTesting +@_spi(SPI) public extension PublicClass { +} + // NORMAL-NEGATIVE-NOT: NotForNormal // NORMAL-NEGATIVE-NOT: NotForTesting // NORMAL: PublicClass Documentation @@ -74,7 +86,7 @@ private class PrivateClass { // TESTING-NEGATIVE-NOT: NotForTesting // TESTING: PublicClass Documentation // TESTING: Public Function Documentation -// TESTINH: Public Init Documentation +// TESTING: Public Init Documentation // TESTING: Public Subscript Documentation // TESTING: Internal Function Documentation // TESTING: InternalClass Documentation @@ -85,4 +97,4 @@ private class PrivateClass { // SOURCE-LOC: comments-hidden.swift:41:10: Subscript/PublicClass.subscript RawComment=none BriefComment=none DocCommentAsXML=none // SOURCE-LOC: comments-hidden.swift:43:10: Constructor/PublicClass.init RawComment=none BriefComment=none DocCommentAsXML=none // SOURCE-LOC: comments-hidden.swift:45:10: Subscript/PublicClass.subscript RawComment=none BriefComment=none DocCommentAsXML=none -// SOURCE-LOC: comments-hidden.swift:50:15: Func/-= RawComment=none BriefComment=none DocCommentAsXML=none +// SOURCE-LOC: comments-hidden.swift:52:15: Func/-= RawComment=none BriefComment=none DocCommentAsXML=none From b89ef78ce5be633981f7935ab3e728a578cb5120 Mon Sep 17 00:00:00 2001 From: Joe Groff Date: Thu, 28 May 2020 14:10:39 -0700 Subject: [PATCH 14/28] Adding `final` to public API is API-stable. Client code can't override or subclass a `public` declaration already, so although the ABI differs, the API is the same whether something is `final` or not. --- include/swift/AST/Attr.def | 2 +- test/api-digester/Outputs/Cake.txt | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/include/swift/AST/Attr.def b/include/swift/AST/Attr.def index 8054500c3b772..95504c3477426 100644 --- a/include/swift/AST/Attr.def +++ b/include/swift/AST/Attr.def @@ -126,7 +126,7 @@ DECL_ATTR(available, Available, CONTEXTUAL_SIMPLE_DECL_ATTR(final, Final, OnClass | OnFunc | OnAccessor | OnVar | OnSubscript | DeclModifier | - ABIBreakingToAdd | ABIBreakingToRemove | APIBreakingToAdd | APIStableToRemove, + ABIBreakingToAdd | ABIBreakingToRemove | APIStableToAdd | APIStableToRemove, 2) DECL_ATTR(objc, ObjC, OnAbstractFunction | OnClass | OnProtocol | OnExtension | OnVar | diff --git a/test/api-digester/Outputs/Cake.txt b/test/api-digester/Outputs/Cake.txt index 2ca2dc6739d81..a794079a71f9e 100644 --- a/test/api-digester/Outputs/Cake.txt +++ b/test/api-digester/Outputs/Cake.txt @@ -40,7 +40,6 @@ cake: TypeAlias TChangesFromIntToString.T has underlying type change from Swift. /* Decl Attribute changes */ cake: Enum IceKind is now without @frozen cake: Func C1.foo1() is now not static -cake: Func FinalFuncContainer.NewFinalFunc() is now with final cake: Func HasMutatingMethodClone.foo() has self access kind changing from Mutating to NonMutating cake: Func S1.foo1() has self access kind changing from NonMutating to Mutating cake: Func S1.foo3() is now static From d5c40bf2310be40cd654ca757ae366028abf368f Mon Sep 17 00:00:00 2001 From: Mishal Shah Date: Thu, 28 May 2020 14:10:40 -0700 Subject: [PATCH 15/28] Revert "Disable objc_mangling.swift and SwiftObjectNSObject.swift test on tvOS" This reverts commit 7e2c2452ad218f0554c85eb8c93a5641c44ce1b9. --- test/Interpreter/SDK/objc_mangling.swift | 1 - test/stdlib/SwiftObjectNSObject.swift | 1 - 2 files changed, 2 deletions(-) diff --git a/test/Interpreter/SDK/objc_mangling.swift b/test/Interpreter/SDK/objc_mangling.swift index 887263faf75e4..c9a8c3f8d6c90 100644 --- a/test/Interpreter/SDK/objc_mangling.swift +++ b/test/Interpreter/SDK/objc_mangling.swift @@ -8,7 +8,6 @@ // rdar://problem/56959761 // UNSUPPORTED: OS=watchos -// UNSUPPORTED: OS=tvos import Foundation diff --git a/test/stdlib/SwiftObjectNSObject.swift b/test/stdlib/SwiftObjectNSObject.swift index 68b3268d6b8b2..695a357427a0d 100644 --- a/test/stdlib/SwiftObjectNSObject.swift +++ b/test/stdlib/SwiftObjectNSObject.swift @@ -23,7 +23,6 @@ // rdar://problem/56959761 // UNSUPPORTED: OS=watchos -// UNSUPPORTED: OS=tvos import Foundation From d4f48718c1253c116d3d997097d2135c305ef566 Mon Sep 17 00:00:00 2001 From: Saleem Abdulrasool Date: Thu, 28 May 2020 13:29:20 -0700 Subject: [PATCH 16/28] stdlib: add `StringSwitch` to LLVMSupport fork This is used in the standard library for the reflection. Add the missing header. Somehow this was missed in the dynamic version of the standard library. --- stdlib/include/llvm/ADT/StringSwitch.h | 198 +++++++++++++++++++++++++ 1 file changed, 198 insertions(+) create mode 100644 stdlib/include/llvm/ADT/StringSwitch.h diff --git a/stdlib/include/llvm/ADT/StringSwitch.h b/stdlib/include/llvm/ADT/StringSwitch.h new file mode 100644 index 0000000000000..5c4e7c1714783 --- /dev/null +++ b/stdlib/include/llvm/ADT/StringSwitch.h @@ -0,0 +1,198 @@ +//===--- StringSwitch.h - Switch-on-literal-string Construct --------------===/ +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +//===----------------------------------------------------------------------===/ +// +// This file implements the StringSwitch template, which mimics a switch() +// statement whose cases are string literals. +// +//===----------------------------------------------------------------------===/ +#ifndef LLVM_ADT_STRINGSWITCH_H +#define LLVM_ADT_STRINGSWITCH_H + +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Compiler.h" +#include +#include + +inline namespace __swift { inline namespace __runtime { +namespace llvm { + +/// A switch()-like statement whose cases are string literals. +/// +/// The StringSwitch class is a simple form of a switch() statement that +/// determines whether the given string matches one of the given string +/// literals. The template type parameter \p T is the type of the value that +/// will be returned from the string-switch expression. For example, +/// the following code switches on the name of a color in \c argv[i]: +/// +/// \code +/// Color color = StringSwitch(argv[i]) +/// .Case("red", Red) +/// .Case("orange", Orange) +/// .Case("yellow", Yellow) +/// .Case("green", Green) +/// .Case("blue", Blue) +/// .Case("indigo", Indigo) +/// .Cases("violet", "purple", Violet) +/// .Default(UnknownColor); +/// \endcode +template +class StringSwitch { + /// The string we are matching. + const StringRef Str; + + /// The pointer to the result of this switch statement, once known, + /// null before that. + Optional Result; + +public: + explicit StringSwitch(StringRef S) + : Str(S), Result() { } + + // StringSwitch is not copyable. + StringSwitch(const StringSwitch &) = delete; + + // StringSwitch is not assignable due to 'Str' being 'const'. + void operator=(const StringSwitch &) = delete; + void operator=(StringSwitch &&other) = delete; + + StringSwitch(StringSwitch &&other) + : Str(other.Str), Result(std::move(other.Result)) { } + + ~StringSwitch() = default; + + // Case-sensitive case matchers + StringSwitch &Case(StringLiteral S, T Value) { + if (!Result && Str == S) { + Result = std::move(Value); + } + return *this; + } + + StringSwitch& EndsWith(StringLiteral S, T Value) { + if (!Result && Str.endswith(S)) { + Result = std::move(Value); + } + return *this; + } + + StringSwitch& StartsWith(StringLiteral S, T Value) { + if (!Result && Str.startswith(S)) { + Result = std::move(Value); + } + return *this; + } + + StringSwitch &Cases(StringLiteral S0, StringLiteral S1, T Value) { + return Case(S0, Value).Case(S1, Value); + } + + StringSwitch &Cases(StringLiteral S0, StringLiteral S1, StringLiteral S2, + T Value) { + return Case(S0, Value).Cases(S1, S2, Value); + } + + StringSwitch &Cases(StringLiteral S0, StringLiteral S1, StringLiteral S2, + StringLiteral S3, T Value) { + return Case(S0, Value).Cases(S1, S2, S3, Value); + } + + StringSwitch &Cases(StringLiteral S0, StringLiteral S1, StringLiteral S2, + StringLiteral S3, StringLiteral S4, T Value) { + return Case(S0, Value).Cases(S1, S2, S3, S4, Value); + } + + StringSwitch &Cases(StringLiteral S0, StringLiteral S1, StringLiteral S2, + StringLiteral S3, StringLiteral S4, StringLiteral S5, + T Value) { + return Case(S0, Value).Cases(S1, S2, S3, S4, S5, Value); + } + + StringSwitch &Cases(StringLiteral S0, StringLiteral S1, StringLiteral S2, + StringLiteral S3, StringLiteral S4, StringLiteral S5, + StringLiteral S6, T Value) { + return Case(S0, Value).Cases(S1, S2, S3, S4, S5, S6, Value); + } + + StringSwitch &Cases(StringLiteral S0, StringLiteral S1, StringLiteral S2, + StringLiteral S3, StringLiteral S4, StringLiteral S5, + StringLiteral S6, StringLiteral S7, T Value) { + return Case(S0, Value).Cases(S1, S2, S3, S4, S5, S6, S7, Value); + } + + StringSwitch &Cases(StringLiteral S0, StringLiteral S1, StringLiteral S2, + StringLiteral S3, StringLiteral S4, StringLiteral S5, + StringLiteral S6, StringLiteral S7, StringLiteral S8, + T Value) { + return Case(S0, Value).Cases(S1, S2, S3, S4, S5, S6, S7, S8, Value); + } + + StringSwitch &Cases(StringLiteral S0, StringLiteral S1, StringLiteral S2, + StringLiteral S3, StringLiteral S4, StringLiteral S5, + StringLiteral S6, StringLiteral S7, StringLiteral S8, + StringLiteral S9, T Value) { + return Case(S0, Value).Cases(S1, S2, S3, S4, S5, S6, S7, S8, S9, Value); + } + + // Case-insensitive case matchers. + StringSwitch &CaseLower(StringLiteral S, T Value) { + if (!Result && Str.equals_lower(S)) + Result = std::move(Value); + + return *this; + } + + StringSwitch &EndsWithLower(StringLiteral S, T Value) { + if (!Result && Str.endswith_lower(S)) + Result = Value; + + return *this; + } + + StringSwitch &StartsWithLower(StringLiteral S, T Value) { + if (!Result && Str.startswith_lower(S)) + Result = std::move(Value); + + return *this; + } + + StringSwitch &CasesLower(StringLiteral S0, StringLiteral S1, T Value) { + return CaseLower(S0, Value).CaseLower(S1, Value); + } + + StringSwitch &CasesLower(StringLiteral S0, StringLiteral S1, StringLiteral S2, + T Value) { + return CaseLower(S0, Value).CasesLower(S1, S2, Value); + } + + StringSwitch &CasesLower(StringLiteral S0, StringLiteral S1, StringLiteral S2, + StringLiteral S3, T Value) { + return CaseLower(S0, Value).CasesLower(S1, S2, S3, Value); + } + + StringSwitch &CasesLower(StringLiteral S0, StringLiteral S1, StringLiteral S2, + StringLiteral S3, StringLiteral S4, T Value) { + return CaseLower(S0, Value).CasesLower(S1, S2, S3, S4, Value); + } + + LLVM_NODISCARD + R Default(T Value) { + if (Result) + return std::move(*Result); + return Value; + } + + LLVM_NODISCARD + operator R() { + assert(Result && "Fell off the end of a string-switch"); + return std::move(*Result); + } +}; + +} // end namespace llvm +}} + +#endif // LLVM_ADT_STRINGSWITCH_H From be8674ea73551cfa1e91ab3a3e3f52f712c3428b Mon Sep 17 00:00:00 2001 From: Joe Groff Date: Thu, 28 May 2020 18:19:03 -0700 Subject: [PATCH 17/28] Make an internal KeyPath helper final. And remove an unnecessary override, so that further work will allow this method not to need a vtable entry. --- stdlib/public/core/KeyPath.swift | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/stdlib/public/core/KeyPath.swift b/stdlib/public/core/KeyPath.swift index dcedf5aa4c79d..53930842b2981 100644 --- a/stdlib/public/core/KeyPath.swift +++ b/stdlib/public/core/KeyPath.swift @@ -134,6 +134,7 @@ public class AnyKeyPath: Hashable, _AppendKeyPath { // Prevent normal initialization. We use tail allocation via // allocWithTailElems(). + @available(*, unavailable) internal init() { _internalInvariantFailure("use _create(...)") } @@ -158,7 +159,7 @@ public class AnyKeyPath: Hashable, _AppendKeyPath { return result } - internal func withBuffer(_ f: (KeyPathBuffer) throws -> T) rethrows -> T { + final internal func withBuffer(_ f: (KeyPathBuffer) throws -> T) rethrows -> T { defer { _fixLifetime(self) } let base = UnsafeRawPointer(Builtin.projectTailElems(self, Int32.self)) @@ -348,14 +349,6 @@ public class ReferenceWritableKeyPath< internal final override class var kind: Kind { return .reference } - internal final override func _projectMutableAddress( - from base: UnsafePointer - ) -> (pointer: UnsafeMutablePointer, owner: AnyObject?) { - // Since we're a ReferenceWritableKeyPath, we know we don't mutate the base - // in practice. - return _projectMutableAddress(from: base.pointee) - } - @usableFromInline internal final func _projectMutableAddress(from origBase: Root) -> (pointer: UnsafeMutablePointer, owner: AnyObject?) { From 93ff8b0d9679bdce633a1a5cd8fdab34ff6dc4ae Mon Sep 17 00:00:00 2001 From: Erik Eckstein Date: Tue, 26 May 2020 10:44:07 +0200 Subject: [PATCH 18/28] stdlib: make sure that SetAlgebra.init(sequence) is on the fast path. In order to fully optimize OptionSet literals, it's important that this function is inlined and fully optimized. So far this was done by chance, but with COW representation it needs a hint to the optimizer. --- stdlib/public/core/SetAlgebra.swift | 2 ++ test/SILOptimizer/optionset.swift | 14 ++++++++------ 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/stdlib/public/core/SetAlgebra.swift b/stdlib/public/core/SetAlgebra.swift index 8ef4c914d4e7a..a7c206af3213f 100644 --- a/stdlib/public/core/SetAlgebra.swift +++ b/stdlib/public/core/SetAlgebra.swift @@ -409,6 +409,8 @@ extension SetAlgebra { public init(_ sequence: __owned S) where S.Element == Element { self.init() + // Needed to fully optimize OptionSet literals. + _onFastPath() for e in sequence { insert(e) } } diff --git a/test/SILOptimizer/optionset.swift b/test/SILOptimizer/optionset.swift index 1ee43c2ef2135..6e332bc27281e 100644 --- a/test/SILOptimizer/optionset.swift +++ b/test/SILOptimizer/optionset.swift @@ -14,6 +14,7 @@ public struct TestOptions: OptionSet { // CHECK: sil @{{.*}}returnTestOptions{{.*}} // CHECK-NEXT: bb0: +// CHECK-NEXT: builtin // CHECK-NEXT: integer_literal {{.*}}, 15 // CHECK-NEXT: struct $Int // CHECK-NEXT: struct $TestOptions @@ -22,18 +23,19 @@ public func returnTestOptions() -> TestOptions { return [.first, .second, .third, .fourth] } -// CHECK: sil @{{.*}}returnEmptyTestOptions{{.*}} -// CHECK-NEXT: bb0: -// CHECK-NEXT: integer_literal {{.*}}, 0 -// CHECK-NEXT: struct $Int -// CHECK-NEXT: struct $TestOptions -// CHECK-NEXT: return +// CHECK: sil @{{.*}}returnEmptyTestOptions{{.*}} +// CHECK: [[ZERO:%[0-9]+]] = integer_literal {{.*}}, 0 +// CHECK: [[ZEROINT:%[0-9]+]] = struct $Int ([[ZERO]] +// CHECK: [[TO:%[0-9]+]] = struct $TestOptions ([[ZEROINT]] +// CHECK: return [[TO]] +// CHECK: } // end sil function {{.*}}returnEmptyTestOptions{{.*}} public func returnEmptyTestOptions() -> TestOptions { return [] } // CHECK: alloc_global @{{.*}}globalTestOptions{{.*}} // CHECK-NEXT: global_addr +// CHECK-NEXT: builtin // CHECK-NEXT: integer_literal {{.*}}, 15 // CHECK-NEXT: struct $Int // CHECK-NEXT: struct $TestOptions From f6ec448583672ced2bac64ae0fd9660d0f99c06f Mon Sep 17 00:00:00 2001 From: Erik Eckstein Date: Tue, 26 May 2020 11:53:57 +0200 Subject: [PATCH 19/28] stdlib: Prevent storing into the empty array singleton when replacing an array sub-sequence. In the corner case of a 0 sized replacement in an empty array, we did write the 0 count back to the array singleton. This is not a big problem right now, because it would just overwrite a 0 with a 0, but it shouldn't be done. But with COW representation in Array, it would break the sanity checks. --- stdlib/public/core/ArrayBufferProtocol.swift | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/stdlib/public/core/ArrayBufferProtocol.swift b/stdlib/public/core/ArrayBufferProtocol.swift index 7a48162a84d8a..8ed718b8cbe70 100644 --- a/stdlib/public/core/ArrayBufferProtocol.swift +++ b/stdlib/public/core/ArrayBufferProtocol.swift @@ -150,7 +150,10 @@ extension _ArrayBufferProtocol where Indices == Range{ let eraseCount = subrange.count let growth = newCount - eraseCount - self.count = oldCount + growth + // This check will prevent storing a 0 count to the empty array singleton. + if growth != 0 { + self.count = oldCount + growth + } let elements = self.subscriptBaseAddress let oldTailIndex = subrange.upperBound From 68728dcb7d6588059a80c98e4a813023116259cf Mon Sep 17 00:00:00 2001 From: Erik Eckstein Date: Tue, 26 May 2020 15:42:38 +0200 Subject: [PATCH 20/28] stdlib: move the new-buffer creation function from Array to ArrayBuffer This has two advantages: 1. It does not force the Array in memory (to pass it as inout self to the non-inlinable _createNewBuffer). 2. The new _consumeAndCreateNew is annotated to consume self. This helps to reduce unnecessary retains/releases. The change applies for Array and ContiguousArray. --- stdlib/public/core/Array.swift | 32 ++--------- stdlib/public/core/ArrayBuffer.swift | 54 +++++++++++++++++++ stdlib/public/core/ContiguousArray.swift | 31 ++--------- .../public/core/ContiguousArrayBuffer.swift | 54 +++++++++++++++++++ test/IRGen/multithread_module.swift | 4 +- test/SILOptimizer/array_contentof_opt.swift | 10 ++-- 6 files changed, 125 insertions(+), 60 deletions(-) diff --git a/stdlib/public/core/Array.swift b/stdlib/public/core/Array.swift index 147b056665c48..bb7c72cd474a8 100644 --- a/stdlib/public/core/Array.swift +++ b/stdlib/public/core/Array.swift @@ -346,8 +346,7 @@ extension Array { @_semantics("array.make_mutable") internal mutating func _makeMutableAndUnique() { if _slowPath(!_buffer.isMutableAndUniquelyReferenced()) { - _createNewBuffer(bufferIsUnique: false, minimumCapacity: count, - growForAppend: false) + _buffer = _buffer._consumeAndCreateNew() } } @@ -1049,34 +1048,13 @@ extension Array: RangeReplaceableCollection { /// If `growForAppend` is true, the new capacity is calculated using /// `_growArrayCapacity`, but at least kept at `minimumCapacity`. @_alwaysEmitIntoClient - @inline(never) internal mutating func _createNewBuffer( bufferIsUnique: Bool, minimumCapacity: Int, growForAppend: Bool ) { - let newCapacity = _growArrayCapacity(oldCapacity: _getCapacity(), - minimumCapacity: minimumCapacity, - growForAppend: growForAppend) - let count = _getCount() - _internalInvariant(newCapacity >= count) - - let newBuffer = _ContiguousArrayBuffer( - _uninitializedCount: count, minimumCapacity: newCapacity) - - if bufferIsUnique { - _internalInvariant(_buffer.isUniquelyReferenced()) - - // As an optimization, if the original buffer is unique, we can just move - // the elements instead of copying. - let dest = newBuffer.firstElementAddress - dest.moveInitialize(from: _buffer.firstElementAddress, - count: count) - _buffer.count = 0 - } else { - _buffer._copyContents( - subRange: 0.. _ArrayBuffer { + return _consumeAndCreateNew(bufferIsUnique: false, + minimumCapacity: count, + growForAppend: false) + } + + /// Creates and returns a new uniquely referenced buffer which is a copy of + /// this buffer. + /// + /// If `bufferIsUnique` is true, the buffer is assumed to be uniquely + /// referenced and the elements are moved - instead of copied - to the new + /// buffer. + /// The `minimumCapacity` is the lower bound for the new capacity. + /// If `growForAppend` is true, the new capacity is calculated using + /// `_growArrayCapacity`, but at least kept at `minimumCapacity`. + /// + /// This buffer is consumed, i.e. it's released. + @_alwaysEmitIntoClient + @inline(never) + @_semantics("optimize.sil.specialize.owned2guarantee.never") + internal __consuming func _consumeAndCreateNew( + bufferIsUnique: Bool, minimumCapacity: Int, growForAppend: Bool + ) -> _ArrayBuffer { + let newCapacity = _growArrayCapacity(oldCapacity: capacity, + minimumCapacity: minimumCapacity, + growForAppend: growForAppend) + let c = count + _internalInvariant(newCapacity >= c) + + let newBuffer = _ContiguousArrayBuffer( + _uninitializedCount: c, minimumCapacity: newCapacity) + + if bufferIsUnique { + // As an optimization, if the original buffer is unique, we can just move + // the elements instead of copying. + let dest = newBuffer.firstElementAddress + dest.moveInitialize(from: firstElementAddress, + count: c) + _native.count = 0 + } else { + _copyContents( + subRange: 0..= count) - - let newBuffer = _ContiguousArrayBuffer( - _uninitializedCount: count, minimumCapacity: newCapacity) - - if bufferIsUnique { - _internalInvariant(_buffer.isUniquelyReferenced()) - - // As an optimization, if the original buffer is unique, we can just move - // the elements instead of copying. - let dest = newBuffer.firstElementAddress - dest.moveInitialize(from: _buffer.firstElementAddress, - count: count) - _buffer.count = 0 - } else { - _buffer._copyContents( - subRange: 0..: _ArrayBufferProtocol { return _isUnique(&_storage) } + /// Creates and returns a new uniquely referenced buffer which is a copy of + /// this buffer. + /// + /// This buffer is consumed, i.e. it's released. + @_alwaysEmitIntoClient + @inline(never) + @_semantics("optimize.sil.specialize.owned2guarantee.never") + internal __consuming func _consumeAndCreateNew() -> _ContiguousArrayBuffer { + return _consumeAndCreateNew(bufferIsUnique: false, + minimumCapacity: count, + growForAppend: false) + } + + /// Creates and returns a new uniquely referenced buffer which is a copy of + /// this buffer. + /// + /// If `bufferIsUnique` is true, the buffer is assumed to be uniquely + /// referenced and the elements are moved - instead of copied - to the new + /// buffer. + /// The `minimumCapacity` is the lower bound for the new capacity. + /// If `growForAppend` is true, the new capacity is calculated using + /// `_growArrayCapacity`, but at least kept at `minimumCapacity`. + /// + /// This buffer is consumed, i.e. it's released. + @_alwaysEmitIntoClient + @inline(never) + @_semantics("optimize.sil.specialize.owned2guarantee.never") + internal __consuming func _consumeAndCreateNew( + bufferIsUnique: Bool, minimumCapacity: Int, growForAppend: Bool + ) -> _ContiguousArrayBuffer { + let newCapacity = _growArrayCapacity(oldCapacity: capacity, + minimumCapacity: minimumCapacity, + growForAppend: growForAppend) + let c = count + _internalInvariant(newCapacity >= c) + + let newBuffer = _ContiguousArrayBuffer( + _uninitializedCount: c, minimumCapacity: newCapacity) + + if bufferIsUnique { + // As an optimization, if the original buffer is unique, we can just move + // the elements instead of copying. + let dest = newBuffer.firstElementAddress + dest.moveInitialize(from: firstElementAddress, + count: c) + count = 0 + } else { + _copyContents( + subRange: 0..". diff --git a/test/SILOptimizer/array_contentof_opt.swift b/test/SILOptimizer/array_contentof_opt.swift index 7d2af225a2a79..1b1a3d5b9c649 100644 --- a/test/SILOptimizer/array_contentof_opt.swift +++ b/test/SILOptimizer/array_contentof_opt.swift @@ -1,4 +1,4 @@ -// RUN: %target-swift-frontend -O -sil-verify-all -emit-sil -Xllvm '-sil-inline-never-functions=$sSa6append' %s | %FileCheck %s +// RUN: %target-swift-frontend -O -sil-verify-all -emit-sil -Xllvm '-sil-inline-never-functions=$sSa6appendyy' %s | %FileCheck %s // REQUIRES: swift_stdlib_no_asserts,optimized_stdlib // This is an end-to-end test of the Array.append(contentsOf:) -> @@ -24,7 +24,7 @@ public func testInt(_ a: inout [Int]) { } // CHECK-LABEL: sil @{{.*}}testThreeInts -// CHECK-DAG: [[FR:%[0-9]+]] = function_ref @${{(sSa15reserveCapacityyySiFSi_Tg5|sSa16_createNewBuffer)}} +// CHECK-DAG: [[FR:%[0-9]+]] = function_ref @${{.*(reserveCapacity|_createNewBuffer)}} // CHECK-DAG: apply [[FR]] // CHECK-DAG: [[F:%[0-9]+]] = function_ref @$sSa6appendyyxnFSi_Tg5 // CHECK-DAG: apply [[F]] @@ -37,7 +37,7 @@ public func testThreeInts(_ a: inout [Int]) { // CHECK-LABEL: sil @{{.*}}testTooManyInts // CHECK-NOT: apply -// CHECK: [[F:%[0-9]+]] = function_ref @$sSa6append10contentsOfyqd__n_t7ElementQyd__RszSTRd__lFSi_SaySiGTg5 +// CHECK: [[F:%[0-9]+]] = function_ref @${{.*append.*contentsOf.*}} // CHECK-NOT: apply // CHECK: apply [[F]] // CHECK-NOT: apply @@ -65,12 +65,12 @@ public func dontPropagateContiguousArray(_ a: inout ContiguousArray) { // Check if the specialized Array.append(contentsOf:) is reasonably optimized for Array. -// CHECK-LABEL: sil shared {{.*}}@$sSa6append10contentsOfyqd__n_t7ElementQyd__RszSTRd__lFSi_SaySiGTg5Tf4gn_n +// CHECK-LABEL: sil shared {{.*}}@$sSa6append10contentsOfyqd__n_t7ElementQyd__RszSTRd__lFSi_SaySiGTg5 // There should only be a single call to _createNewBuffer or reserveCapacityForAppend/reserveCapacityImpl. // CHECK-NOT: apply -// CHECK: [[F:%[0-9]+]] = function_ref @{{.*(_createNewBuffer|reserveCapacity).*}} +// CHECK: [[F:%[0-9]+]] = function_ref @{{.*(_consumeAndCreateNew|reserveCapacity).*}} // CHECK-NEXT: apply [[F]] // CHECK-NOT: apply From 56c8ebabbdbd1465b4825a378acf369f31c67835 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Tue, 26 May 2020 17:53:02 -0700 Subject: [PATCH 21/28] [SIL] Fix `alloc_stack [dynamic_lifetime]` attribute cloning. Make `SILCloner:visitAllocStack` correctly propagate the `[dynamic_lifetime]` attribute. Resolves SR-12886: differentiation transform error related to the `VJPEmitter` subclass of `SILCloner`. --- include/swift/SIL/SILCloner.h | 6 +-- ...6-clone-alloc-stack-dynamic-lifetime.swift | 42 +++++++++++++++++++ 2 files changed, 45 insertions(+), 3 deletions(-) create mode 100644 test/AutoDiff/compiler_crashers_fixed/sr12886-clone-alloc-stack-dynamic-lifetime.swift diff --git a/include/swift/SIL/SILCloner.h b/include/swift/SIL/SILCloner.h index 474961ebaf76d..010aefc58805f 100644 --- a/include/swift/SIL/SILCloner.h +++ b/include/swift/SIL/SILCloner.h @@ -802,9 +802,9 @@ SILCloner::visitAllocStackInst(AllocStackInst *Inst) { Loc = MandatoryInlinedLocation::getAutoGeneratedLocation(); VarInfo = None; } - recordClonedInstruction(Inst, - getBuilder().createAllocStack( - Loc, getOpType(Inst->getElementType()), VarInfo)); + recordClonedInstruction(Inst, getBuilder().createAllocStack( + Loc, getOpType(Inst->getElementType()), + VarInfo, Inst->hasDynamicLifetime())); } template diff --git a/test/AutoDiff/compiler_crashers_fixed/sr12886-clone-alloc-stack-dynamic-lifetime.swift b/test/AutoDiff/compiler_crashers_fixed/sr12886-clone-alloc-stack-dynamic-lifetime.swift new file mode 100644 index 0000000000000..67f441445ae4b --- /dev/null +++ b/test/AutoDiff/compiler_crashers_fixed/sr12886-clone-alloc-stack-dynamic-lifetime.swift @@ -0,0 +1,42 @@ +// RUN: %target-build-swift %s +// RUN: %target-swift-frontend -emit-sil %s | %FileCheck %s + +// SR-12493: SIL memory lifetime verification error due to +// `SILCloner::visitAllocStack` not copying the `[dynamic_lifetime]` attribute. + +import _Differentiation + +enum Enum { + case a +} + +struct Tensor: Differentiable { + @noDerivative var x: T + @noDerivative var optional: Int? + + init(_ x: T, _ e: Enum) { + self.x = x + switch e { + case .a: optional = 1 + } + } + + // Definite initialization triggers for this initializer. + @differentiable + init(_ x: T, _ other: Self) { + self = Self(x, Enum.a) + } +} + +// Check that `allock_stack [dynamic_lifetime]` attribute is correctly cloned. + +// CHECK-LABEL: sil hidden @$s4main6TensorVyACyxGx_ADtcfC : $@convention(method) (@in T, @in Tensor, @thin Tensor.Type) -> @out Tensor { +// CHECK: [[SELF_ALLOC:%.*]] = alloc_stack [dynamic_lifetime] $Tensor, var, name "self" + +// CHECK-LABEL: sil hidden @AD__$s4main6TensorVyACyxGx_ADtcfC__vjp_src_0_wrt_1_l : $@convention(method) <Ï„_0_0> (@in Ï„_0_0, @in Tensor<Ï„_0_0>, @thin Tensor<Ï„_0_0>.Type) -> (@out Tensor<Ï„_0_0>, @owned @callee_guaranteed @substituted <Ï„_0_0, Ï„_0_1> (@in_guaranteed Ï„_0_0) -> @out Ï„_0_1 for .TangentVector, Tensor<Ï„_0_0>.TangentVector>) { +// CHECK: [[SELF_ALLOC:%.*]] = alloc_stack [dynamic_lifetime] $Tensor<Ï„_0_0>, var, name "self" + +// Original error: +// SIL memory lifetime failure in @AD__$s5crash6TensorVyACyxGx_ADtcfC__vjp_src_0_wrt_1_l: memory is not initialized, but should +// memory location: %29 = struct_element_addr %5 : $*Tensor<Ï„_0_0>, #Tensor.x // user: %30 +// at instruction: destroy_addr %29 : $*Ï„_0_0 // id: %30 From f9c5d7ae6c618abfeecb3e3c4e35cfddfbbe8310 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Fri, 29 May 2020 01:59:52 -0700 Subject: [PATCH 22/28] [AutoDiff] Derive `Differentiable.zeroTangentVectorInitializer`. (#31823) `Differentiable` conformance derivation now supports `Differentiable.zeroTangentVectorInitializer`. There are two potential cases: 1. Memberwise derivation: done when `TangentVector` can be initialized memberwise. 2. `{ TangentVector.zero }` derivation: done as a fallback. `zeroTangentVectorInitializer` is a closure that produces a zero tangent vector, capturing minimal necessary information from `self`. It is an instance property, unlike the static property `AdditiveArithmetic.zero`, and should be used by the differentiation transform for correctness. Remove `Differentiable.zeroTangentVectorInitializer` dummy default implementation. Update stdlib `Differentiable` conformances and tests. Clean up DerivedConformanceDifferentiable.cpp cruft. Resolves TF-1007. Progress towards TF-1008: differentiation correctness for projection operations. --- docs/DifferentiableProgramming.md | 36 +- include/swift/AST/KnownIdentifiers.def | 1 + lib/Sema/CodeSynthesis.cpp | 19 + lib/Sema/CodeSynthesis.h | 15 + lib/Sema/DebuggerTestingTransform.cpp | 30 +- lib/Sema/DerivedConformanceDifferentiable.cpp | 652 ++++++++++++------ lib/Sema/DerivedConformances.cpp | 9 +- lib/Sema/DerivedConformances.h | 6 +- lib/Sema/TypeCheckProtocol.cpp | 41 +- lib/Sema/TypeChecker.h | 6 + .../Differentiation/AnyDifferentiable.swift | 9 + .../ArrayDifferentiation.swift | 4 +- .../Differentiation/Differentiable.swift | 15 - .../FloatingPointDifferentiation.swift.gyb | 5 + .../SIMDDifferentiation.swift.gyb | 5 + .../Inputs/b.swift | 1 + ...ived_zero_tangent_vector_initializer.swift | 249 +++++++ .../zero_tangent_vector_initializer.swift | 59 ++ 18 files changed, 886 insertions(+), 276 deletions(-) create mode 100644 test/AutoDiff/Sema/DerivedConformances/derived_zero_tangent_vector_initializer.swift create mode 100644 test/AutoDiff/validation-test/zero_tangent_vector_initializer.swift diff --git a/docs/DifferentiableProgramming.md b/docs/DifferentiableProgramming.md index c40dd3da310ff..0cec90aa47712 100644 --- a/docs/DifferentiableProgramming.md +++ b/docs/DifferentiableProgramming.md @@ -1079,11 +1079,6 @@ public extension Differentiable where Self == TangentVector { mutating func move(along direction: TangentVector) { self += direction } - - @noDerivative - var zeroTangentVectorInitializer: () -> TangentVector { - { .zero } - } } ``` @@ -1144,8 +1139,8 @@ extension Array: Differentiable where Element: Differentiable { @noDerivative public var zeroTangentVectorInitializer: () -> TangentVector { - { [count = self.count] in - TangentVector(Array(repeating: .zero, count: count)) + { [zeroInits = map(\.zeroTangentVectorInitializer)] in + TangentVector(zeroInits.map { $0() }) } } } @@ -1238,8 +1233,15 @@ the same effective access level as their corresponding original properties. A `move(along:)` method is synthesized with a body that calls `move(along:)` for each pair of the original property and its corresponding property in -`TangentVector`. Similarly, `zeroTangentVector` is synthesized to return a -tangent vector that consists of each stored property's `zeroTangentVector`. +`TangentVector`. + +Similarly, when memberwise derivation is possible, +`zeroTangentVectorInitializer` is synthesized to return a closure that captures +and calls each stored property's `zeroTangentVectorInitializer` closure. +When memberwise derivation is not possible (e.g. for custom user-defined +`TangentVector` types), `zeroTangentVectorInitializer` is synthesized as a +`{ TangentVector.zero }` closure. + Here's an example: ```swift @@ -1251,14 +1253,17 @@ struct Foo: @memberwise Differentiable { @noDerivative let helperVariable: T // The compiler synthesizes: + // // struct TangentVector: Differentiable, AdditiveArithmetic { // var x: T.TangentVector // var y: U.TangentVector // } + // // mutating func move(along direction: TangentVector) { // x.move(along: direction.x) // y.move(along: direction.y) // } + // // @noDerivative // var zeroTangentVectorInitializer: () -> TangentVector { // { [xTanInit = x.zeroTangentVectorInitializer, @@ -1278,8 +1283,8 @@ properties are declared to conform to `AdditiveArithmetic`. There are no `@noDerivative` stored properties. In these cases, the compiler will make `TangentVector` be a type alias for Self. -Method `move(along:)` and property `zeroTangentVector` will not be synthesized -because a default implementation already exists. +Method `move(along:)` will not be synthesized because a default implementation +already exists. ```swift struct Point: @memberwise Differentiable, @memberwise AdditiveArithmetic { @@ -1287,7 +1292,16 @@ struct Point: @memberwise Differentiable, @memberwise AdditiveArithmeti var x, y: T // The compiler synthesizes: + // // typealias TangentVector = Self + // + // @noDerivative + // var zeroTangentVectorInitializer: () -> TangentVector { + // { [xTanInit = x.zeroTangentVectorInitializer, + // yTanInit = y.zeroTangentVectorInitializer] in + // TangentVector(x: xTanInit(), y: yTanInit()) + // } + // } } ``` diff --git a/include/swift/AST/KnownIdentifiers.def b/include/swift/AST/KnownIdentifiers.def index 7dd491d870d22..24c9f9c3d979c 100644 --- a/include/swift/AST/KnownIdentifiers.def +++ b/include/swift/AST/KnownIdentifiers.def @@ -223,6 +223,7 @@ IDENTIFIER(move) IDENTIFIER(pullback) IDENTIFIER(TangentVector) IDENTIFIER(zero) +IDENTIFIER(zeroTangentVectorInitializer) #undef IDENTIFIER #undef IDENTIFIER_ diff --git a/lib/Sema/CodeSynthesis.cpp b/lib/Sema/CodeSynthesis.cpp index 225de9b9751f8..325132ea9fbaa 100644 --- a/lib/Sema/CodeSynthesis.cpp +++ b/lib/Sema/CodeSynthesis.cpp @@ -1414,3 +1414,22 @@ void swift::addFixedLayoutAttr(NominalTypeDecl *nominal) { // Add `@_fixed_layout` to the nominal. nominal->getAttrs().add(new (C) FixedLayoutAttr(/*Implicit*/ true)); } + +Expr *DiscriminatorFinder::walkToExprPost(Expr *E) { + auto *ACE = dyn_cast(E); + if (!ACE) + return E; + + unsigned Discriminator = ACE->getDiscriminator(); + assert(Discriminator != AbstractClosureExpr::InvalidDiscriminator && + "Existing closures should have valid discriminators"); + if (Discriminator >= NextDiscriminator) + NextDiscriminator = Discriminator + 1; + return E; +} + +unsigned DiscriminatorFinder::getNextDiscriminator() { + if (NextDiscriminator == AbstractClosureExpr::InvalidDiscriminator) + llvm::report_fatal_error("Out of valid closure discriminators"); + return NextDiscriminator++; +} diff --git a/lib/Sema/CodeSynthesis.h b/lib/Sema/CodeSynthesis.h index a73bb87903e9c..95db8aab697f0 100644 --- a/lib/Sema/CodeSynthesis.h +++ b/lib/Sema/CodeSynthesis.h @@ -18,6 +18,7 @@ #ifndef SWIFT_TYPECHECKING_CODESYNTHESIS_H #define SWIFT_TYPECHECKING_CODESYNTHESIS_H +#include "swift/AST/ASTWalker.h" #include "swift/AST/ForeignErrorConvention.h" #include "swift/Basic/ExternalUnion.h" #include "swift/Basic/LLVM.h" @@ -75,6 +76,20 @@ bool hasLetStoredPropertyWithInitialValue(NominalTypeDecl *nominal); /// Add `@_fixed_layout` attribute to the nominal type, if possible. void addFixedLayoutAttr(NominalTypeDecl *nominal); +/// Find available closure discriminators. +/// +/// The parser typically takes care of assigning unique discriminators to +/// closures, but the parser is unavailable during semantic analysis. +class DiscriminatorFinder : public ASTWalker { + unsigned NextDiscriminator = 0; + +public: + Expr *walkToExprPost(Expr *E) override; + + // Get the next available closure discriminator. + unsigned getNextDiscriminator(); +}; + } // end namespace swift #endif diff --git a/lib/Sema/DebuggerTestingTransform.cpp b/lib/Sema/DebuggerTestingTransform.cpp index 2e5719cfcef87..6f9e1c2ed073e 100644 --- a/lib/Sema/DebuggerTestingTransform.cpp +++ b/lib/Sema/DebuggerTestingTransform.cpp @@ -15,6 +15,7 @@ /// //===----------------------------------------------------------------------===// +#include "CodeSynthesis.h" #include "swift/AST/ASTContext.h" #include "swift/AST/ASTNode.h" #include "swift/AST/ASTWalker.h" @@ -33,35 +34,6 @@ using namespace swift; namespace { -/// Find available closure discriminators. -/// -/// The parser typically takes care of assigning unique discriminators to -/// closures, but the parser is unavailable to this transform. -class DiscriminatorFinder : public ASTWalker { - unsigned NextDiscriminator = 0; - -public: - Expr *walkToExprPost(Expr *E) override { - auto *ACE = dyn_cast(E); - if (!ACE) - return E; - - unsigned Discriminator = ACE->getDiscriminator(); - assert(Discriminator != AbstractClosureExpr::InvalidDiscriminator && - "Existing closures should have valid discriminators"); - if (Discriminator >= NextDiscriminator) - NextDiscriminator = Discriminator + 1; - return E; - } - - // Get the next available closure discriminator. - unsigned getNextDiscriminator() { - if (NextDiscriminator == AbstractClosureExpr::InvalidDiscriminator) - llvm::report_fatal_error("Out of valid closure discriminators"); - return NextDiscriminator++; - } -}; - /// Instrument decls with sanity-checks which the debugger can evaluate. class DebuggerTestingTransform : public ASTWalker { ASTContext &Ctx; diff --git a/lib/Sema/DerivedConformanceDifferentiable.cpp b/lib/Sema/DerivedConformanceDifferentiable.cpp index 6254ea6cb95f6..78312f6779e86 100644 --- a/lib/Sema/DerivedConformanceDifferentiable.cpp +++ b/lib/Sema/DerivedConformanceDifferentiable.cpp @@ -17,15 +17,14 @@ #include "CodeSynthesis.h" #include "TypeChecker.h" -#include "DerivedConformances.h" #include "swift/AST/AutoDiff.h" #include "swift/AST/Decl.h" #include "swift/AST/Expr.h" #include "swift/AST/Module.h" #include "swift/AST/ParameterList.h" #include "swift/AST/Pattern.h" -#include "swift/AST/ProtocolConformance.h" #include "swift/AST/PropertyWrappers.h" +#include "swift/AST/ProtocolConformance.h" #include "swift/AST/Stmt.h" #include "swift/AST/Types.h" #include "DerivedConformances.h" @@ -36,7 +35,8 @@ using namespace swift; /// differentiation, except the ones tagged `@noDerivative`. static void getStoredPropertiesForDifferentiation(NominalTypeDecl *nominal, DeclContext *DC, - SmallVectorImpl &result) { + SmallVectorImpl &result, + bool includeLetProperties = false) { auto &C = nominal->getASTContext(); auto *diffableProto = C.getProtocol(KnownProtocolKind::Differentiable); for (auto *vd : nominal->getStoredProperties()) { @@ -52,9 +52,10 @@ getStoredPropertiesForDifferentiation(NominalTypeDecl *nominal, DeclContext *DC, // Skip stored properties with `@noDerivative` attribute. if (vd->getAttrs().hasAttribute()) continue; - // Skip `let` stored properties. `mutating func move(along:)` cannot be - // synthesized to update these properties. - if (vd->isLet()) + // Skip `let` stored properties if requested. + // `mutating func move(along:)` cannot be synthesized to update `let` + // properties. + if (!includeLetProperties && vd->isLet()) continue; if (vd->getInterfaceType()->hasError()) continue; @@ -77,107 +78,150 @@ static StructDecl *convertToStructDecl(ValueDecl *v) { typeDecl->getDeclaredInterfaceType()->getAnyNominal()); } -/// Get the `Differentiable` protocol `TangentVector` associated type for the -/// given `VarDecl`. -/// TODO: Generalize and move function to shared place for use with other -/// derived conformances. -static Type getTangentVectorType(VarDecl *decl, DeclContext *DC) { - auto &C = decl->getASTContext(); +/// Get the `Differentiable` protocol `TangentVector` associated type witness +/// for the given interface type and declaration context. +static Type getTangentVectorInterfaceType(Type contextualType, + DeclContext *DC) { + auto &C = contextualType->getASTContext(); auto *diffableProto = C.getProtocol(KnownProtocolKind::Differentiable); - auto varType = DC->mapTypeIntoContext(decl->getValueInterfaceType()); - auto conf = TypeChecker::conformsToProtocol(varType, diffableProto, DC); + assert(diffableProto && "`Differentiable` protocol not found"); + auto conf = + TypeChecker::conformsToProtocol(contextualType, diffableProto, DC); + assert(conf && "Contextual type must conform to `Differentiable`"); if (!conf) return nullptr; - Type tangentType = conf.getTypeWitnessByName(varType, C.Id_TangentVector); - return tangentType; + auto tanType = conf.getTypeWitnessByName(contextualType, C.Id_TangentVector); + return tanType->hasArchetype() ? tanType->mapTypeOutOfContext() : tanType; } -// Get the `Differentiable` protocol associated `TangentVector` struct for the -// given nominal `DeclContext`. Asserts that the `TangentVector` struct type -// exists. -static StructDecl *getTangentVectorStructDecl(DeclContext *DC) { - assert(DC->getSelfNominalTypeDecl() && "Must be a nominal `DeclContext`"); - auto &C = DC->getASTContext(); +/// Returns true iff the given nominal type declaration can derive +/// `TangentVector` as `Self` in the given conformance context. +static bool canDeriveTangentVectorAsSelf(NominalTypeDecl *nominal, + DeclContext *DC) { + // `Self` must not be a class declaraiton. + if (nominal->getSelfClassDecl()) + return false; + + auto nominalTypeInContext = + DC->mapTypeIntoContext(nominal->getDeclaredInterfaceType()); + auto &C = nominal->getASTContext(); auto *diffableProto = C.getProtocol(KnownProtocolKind::Differentiable); - assert(diffableProto && "`Differentiable` protocol not found"); - auto conf = TypeChecker::conformsToProtocol(DC->getSelfTypeInContext(), - diffableProto, DC); - assert(conf && "Nominal must conform to `Differentiable`"); - auto assocType = - conf.getTypeWitnessByName(DC->getSelfTypeInContext(), C.Id_TangentVector); - assert(assocType && "`Differentiable.TangentVector` type not found"); - auto *structDecl = dyn_cast(assocType->getAnyNominal()); - assert(structDecl && "Associated type must be a struct type"); - return structDecl; + auto *addArithProto = C.getProtocol(KnownProtocolKind::AdditiveArithmetic); + // `Self` must conform to `AdditiveArithmetic`. + if (!TypeChecker::conformsToProtocol(nominalTypeInContext, addArithProto, DC)) + return false; + for (auto *field : nominal->getStoredProperties()) { + // `Self` must not have any `@noDerivative` stored properties. + if (field->getAttrs().hasAttribute()) + return false; + // `Self` must have all stored properties satisfy `Self == TangentVector`. + auto fieldType = DC->mapTypeIntoContext(field->getValueInterfaceType()); + auto conf = TypeChecker::conformsToProtocol(fieldType, diffableProto, DC); + if (!conf) + return false; + auto tangentType = conf.getTypeWitnessByName(fieldType, C.Id_TangentVector); + if (!fieldType->isEqual(tangentType)) + return false; + } + return true; +} + +// Synthesizable `Differentiable` protocol requirements. +enum class DifferentiableRequirement { + // associatedtype TangentVector + TangentVector, + // mutating func move(along direction: TangentVector) + MoveAlong, + // var zeroTangentVectorInitializer: () -> TangentVector + ZeroTangentVectorInitializer, +}; + +static DifferentiableRequirement +getDifferentiableRequirementKind(ValueDecl *requirement) { + auto &C = requirement->getASTContext(); + if (requirement->getBaseName() == C.Id_TangentVector) + return DifferentiableRequirement::TangentVector; + if (requirement->getBaseName() == C.Id_move) + return DifferentiableRequirement::MoveAlong; + if (requirement->getBaseName() == C.Id_zeroTangentVectorInitializer) + return DifferentiableRequirement::ZeroTangentVectorInitializer; + llvm_unreachable("Invalid `Differentiable` protocol requirement"); } bool DerivedConformance::canDeriveDifferentiable(NominalTypeDecl *nominal, - DeclContext *DC) { + DeclContext *DC, + ValueDecl *requirement) { // Experimental differentiable programming must be enabled. if (auto *SF = DC->getParentSourceFile()) if (!isDifferentiableProgrammingEnabled(*SF)) return false; - // Nominal type must be a struct or class. (No stored properties is okay.) - if (!isa(nominal) && !isa(nominal)) - return false; - auto &C = nominal->getASTContext(); - auto *diffableProto = C.getProtocol(KnownProtocolKind::Differentiable); - auto *addArithProto = C.getProtocol(KnownProtocolKind::AdditiveArithmetic); - // Nominal type must not customize `TangentVector` to anything other than - // `Self`. Otherwise, synthesis is semantically unsupported. - auto tangentDecls = nominal->lookupDirect(C.Id_TangentVector); - auto nominalTypeInContext = - DC->mapTypeIntoContext(nominal->getDeclaredInterfaceType()); + auto reqKind = getDifferentiableRequirementKind(requirement); - auto isValidAssocTypeCandidate = [&](ValueDecl *v) -> StructDecl * { + auto &C = nominal->getASTContext(); + // If there are any `TangentVector` type witness candidates, check whether + // there exists only a single valid candidate. + bool canUseTangentVectorAsSelf = canDeriveTangentVectorAsSelf(nominal, DC); + auto isValidTangentVectorCandidate = [&](ValueDecl *v) -> bool { + // If the requirement is `var zeroTangentVectorInitializer` and + // the candidate is a type declaration that conforms to + // `AdditiveArithmetic`, return true. + if (reqKind == DifferentiableRequirement::ZeroTangentVectorInitializer) { + if (auto *tangentVectorTypeDecl = dyn_cast(v)) { + auto tangentType = DC->mapTypeIntoContext( + tangentVectorTypeDecl->getDeclaredInterfaceType()); + auto *addArithProto = + C.getProtocol(KnownProtocolKind::AdditiveArithmetic); + if (TypeChecker::conformsToProtocol(tangentType, addArithProto, DC)) + return true; + } + } // Valid candidate must be a struct or a typealias to a struct. auto *structDecl = convertToStructDecl(v); if (!structDecl) - return nullptr; + return false; // Valid candidate must either: // 1. Be implicit (previously synthesized). if (structDecl->isImplicit()) - return structDecl; - // 2. Equal nominal's implicit parent. - // This can occur during mutually recursive constraints. Example: - // `X == X.TangentVector`. - if (nominal->isImplicit() && structDecl == nominal->getDeclContext() && - TypeChecker::conformsToProtocol(structDecl->getDeclaredInterfaceType(), - diffableProto, DC)) - return structDecl; - // 3. Equal nominal and conform to `AdditiveArithmetic`. - if (structDecl == nominal) { - // Check conformance to `AdditiveArithmetic`. - if (TypeChecker::conformsToProtocol(nominalTypeInContext, addArithProto, - DC)) - return structDecl; - } + return true; + // 2. Equal nominal, when the nominal can derive `TangentVector` as `Self`. + // Nominal type must not customize `TangentVector` to anything other than + // `Self`. Otherwise, synthesis is semantically unsupported. + if (structDecl == nominal && canUseTangentVectorAsSelf) + return true; // Otherwise, candidate is invalid. - return nullptr; + return false; }; - - auto invalidTangentDecls = llvm::partition( - tangentDecls, [&](ValueDecl *v) { return isValidAssocTypeCandidate(v); }); - - auto validTangentDeclCount = - std::distance(tangentDecls.begin(), invalidTangentDecls); - auto invalidTangentDeclCount = - std::distance(invalidTangentDecls, tangentDecls.end()); - - // There cannot be any invalid `TangentVector` types. + auto tangentDecls = nominal->lookupDirect(C.Id_TangentVector); // There can be at most one valid `TangentVector` type. - if (invalidTangentDeclCount != 0 || validTangentDeclCount > 1) + if (tangentDecls.size() > 1) return false; + // There cannot be any invalid `TangentVector` types. + if (tangentDecls.size() == 1) { + auto *tangentDecl = tangentDecls.front(); + if (!isValidTangentVectorCandidate(tangentDecl)) + return false; + } + bool hasValidTangentDecl = !tangentDecls.empty(); + + // Check requirement-specific derivation conditions. + if (reqKind == DifferentiableRequirement::ZeroTangentVectorInitializer) { + // If there is a valid `TangentVector` type witness (conforming to + // `AdditiveArithmetic`), return true. + if (hasValidTangentDecl) + return true; + // Otherwise, fallback on `TangentVector` struct derivation conditions. + } - // All stored properties not marked with `@noDerivative`: - // - Must conform to `Differentiable`. - // - Must not have any `let` stored properties with an initial value. - // - This restriction may be lifted later with support for "true" memberwise - // initializers that initialize all stored properties, including initial - // value information. + // Check `TangentVector` struct derivation conditions. + // Nominal type must be a struct or class. (No stored properties is okay.) + if (!isa(nominal) && !isa(nominal)) + return false; + // If there are no `TangentVector` candidates, derivation is possible if all + // differentiation stored properties conform to `Differentiable`. SmallVector diffProperties; getStoredPropertiesForDifferentiation(nominal, DC, diffProperties); + auto *diffableProto = C.getProtocol(KnownProtocolKind::Differentiable); return llvm::all_of(diffProperties, [&](VarDecl *v) { if (v->getInterfaceType()->hasError()) return false; @@ -186,18 +230,16 @@ bool DerivedConformance::canDeriveDifferentiable(NominalTypeDecl *nominal, }); } -/// Synthesize body for a `Differentiable` method requirement. +/// Synthesize body for `move(along:)`. static std::pair -deriveBodyDifferentiable_method(AbstractFunctionDecl *funcDecl, - Identifier methodName, - Identifier methodParamLabel) { +deriveBodyDifferentiable_move(AbstractFunctionDecl *funcDecl, void *) { + auto &C = funcDecl->getASTContext(); auto *parentDC = funcDecl->getParent(); auto *nominal = parentDC->getSelfNominalTypeDecl(); - auto &C = nominal->getASTContext(); - // Get method protocol requirement. + // Get `Differentiable.move(along:)` protocol requirement. auto *diffProto = C.getProtocol(KnownProtocolKind::Differentiable); - auto *methodReq = getProtocolRequirement(diffProto, methodName); + auto *requirement = getProtocolRequirement(diffProto, C.Id_move); // Get references to `self` and parameter declarations. auto *selfDecl = funcDecl->getImplicitSelfDecl(); @@ -210,9 +252,8 @@ deriveBodyDifferentiable_method(AbstractFunctionDecl *funcDecl, SmallVector diffProperties; getStoredPropertiesForDifferentiation(nominal, parentDC, diffProperties); - // Create call expression applying a member method to a parameter member. - // Format: `.method(.)`. - // Example: `x.move(along: direction.x)`. + // Create call expression applying a member `move(along:)` method to a + // parameter member: `self..move(along: direction.)`. auto createMemberMethodCallExpr = [&](VarDecl *member) -> Expr * { auto *module = nominal->getModuleContext(); auto memberType = @@ -220,27 +261,24 @@ deriveBodyDifferentiable_method(AbstractFunctionDecl *funcDecl, auto confRef = module->lookupConformance(memberType, diffProto); assert(confRef && "Member does not conform to `Differentiable`"); - // Get member type's method, e.g. `Member.move(along:)`. - // Use protocol requirement declaration for the method by default: this - // will be dynamically dispatched. - ValueDecl *memberMethodDecl = methodReq; - // If conformance reference is concrete, then use concrete witness - // declaration for the operator. + // Get member type's requirement witness: `.move(along:)`. + ValueDecl *memberWitnessDecl = requirement; if (confRef.isConcrete()) - memberMethodDecl = confRef.getConcrete()->getWitnessDecl(methodReq); - assert(memberMethodDecl && "Member method declaration must exist"); - auto *memberMethodDRE = - new (C) DeclRefExpr(memberMethodDecl, DeclNameLoc(), /*Implicit*/ true); + if (auto *witness = confRef.getConcrete()->getWitnessDecl(requirement)) + memberWitnessDecl = witness; + assert(memberWitnessDecl && "Member witness declaration must exist"); + auto *memberMethodDRE = new (C) + DeclRefExpr(memberWitnessDecl, DeclNameLoc(), /*Implicit*/ true); memberMethodDRE->setFunctionRefKind(FunctionRefKind::SingleApply); - // Create reference to member method: `x.move(along:)`. + // Create reference to member method: `self..move(along:)`. Expr *memberExpr = new (C) MemberRefExpr(selfDRE, SourceLoc(), member, DeclNameLoc(), /*Implicit*/ true); auto *memberMethodExpr = new (C) DotSyntaxCallExpr(memberMethodDRE, SourceLoc(), memberExpr); - // Create reference to parameter member: `direction.x`. + // Create reference to parameter member: `direction.`. VarDecl *paramMember = nullptr; auto *paramNominal = paramDecl->getType()->getAnyNominal(); assert(paramNominal && "Parameter should have a nominal type"); @@ -255,14 +293,14 @@ deriveBodyDifferentiable_method(AbstractFunctionDecl *funcDecl, auto *paramMemberExpr = new (C) MemberRefExpr(paramDRE, SourceLoc(), paramMember, DeclNameLoc(), /*Implicit*/ true); - // Create expression: `x.move(along: direction.x)`. + // Create expression: `self..move(along: direction.)`. return CallExpr::createImplicit(C, memberMethodExpr, {paramMemberExpr}, - {methodParamLabel}); + {C.Id_along}); }; - // Create array of member method call expressions. - llvm::SmallVector memberMethodCallExprs; - llvm::SmallVector memberNames; + // Collect member `move(along:)` method call expressions. + SmallVector memberMethodCallExprs; + SmallVector memberNames; for (auto *member : diffProperties) { memberMethodCallExprs.push_back(createMemberMethodCallExpr(member)); memberNames.push_back(member->getName()); @@ -272,11 +310,229 @@ deriveBodyDifferentiable_method(AbstractFunctionDecl *funcDecl, return std::pair(braceStmt, false); } -/// Synthesize body for `move(along:)`. +/// Synthesize body for `var zeroTangentVectorInitializer` getter. static std::pair -deriveBodyDifferentiable_move(AbstractFunctionDecl *funcDecl, void *) { +deriveBodyDifferentiable_zeroTangentVectorInitializer( + AbstractFunctionDecl *funcDecl, void *) { auto &C = funcDecl->getASTContext(); - return deriveBodyDifferentiable_method(funcDecl, C.Id_move, C.Id_along); + auto *parentDC = funcDecl->getParent(); + auto *nominal = parentDC->getSelfNominalTypeDecl(); + + // Get method protocol requirement. + auto *diffProto = C.getProtocol(KnownProtocolKind::Differentiable); + auto *requirement = + getProtocolRequirement(diffProto, C.Id_zeroTangentVectorInitializer); + + auto nominalType = + parentDC->mapTypeIntoContext(nominal->getDeclaredInterfaceType()); + auto conf = TypeChecker::conformsToProtocol(nominalType, diffProto, parentDC); + auto tangentType = conf.getTypeWitnessByName(nominalType, C.Id_TangentVector); + auto *tangentTypeExpr = TypeExpr::createImplicit(tangentType, C); + + // Get differentiation properties. + SmallVector diffProperties; + getStoredPropertiesForDifferentiation(nominal, parentDC, diffProperties, + /*includeLetProperties*/ true); + + // Check whether memberwise derivation of `zeroTangentVectorInitializer` is + // possible. + bool canPerformMemberwiseDerivation = [&]() -> bool { + // Memberwise derivation is possible only for struct `TangentVector` types. + auto *tangentTypeDecl = tangentType->getAnyNominal(); + if (!tangentTypeDecl || !tangentTypeDecl->getSelfStructDecl()) + return false; + // Get effective memberwise initializer. + auto *memberwiseInitDecl = + tangentTypeDecl->getEffectiveMemberwiseInitializer(); + // Return false if number of memberwise initializer parameters does not + // equal number of differentiation properties. + if (memberwiseInitDecl->getParameters()->size() != diffProperties.size()) + return false; + // Iterate over all initializer parameters and differentiation properties. + for (auto pair : llvm::zip(memberwiseInitDecl->getParameters()->getArray(), + diffProperties)) { + auto *initParam = std::get<0>(pair); + auto *diffProp = std::get<1>(pair); + // Return false if parameter label does not equal property name. + if (initParam->getParameterName() != diffProp->getName()) + return false; + auto diffPropContextualType = + parentDC->mapTypeIntoContext(diffProp->getValueInterfaceType()); + auto diffPropTangentType = + getTangentVectorInterfaceType(diffPropContextualType, parentDC); + // Return false if parameter type does not equal property tangent type. + if (!initParam->getValueInterfaceType()->isEqual(diffPropTangentType)) + return false; + } + return true; + }(); + + // If memberwise derivation is not possible, synthesize + // `{ TangentVector.zero }` as a fallback. + if (!canPerformMemberwiseDerivation) { + auto *module = nominal->getModuleContext(); + auto *addArithProto = C.getProtocol(KnownProtocolKind::AdditiveArithmetic); + auto confRef = module->lookupConformance(tangentType, addArithProto); + assert(confRef && + "`TangentVector` does not conform to `AdditiveArithmetic`"); + auto *zeroDecl = getProtocolRequirement(addArithProto, C.Id_zero); + // If conformance reference is concrete, then use concrete witness + // declaration for the operator. + if (confRef.isConcrete()) + if (auto *witnessDecl = confRef.getConcrete()->getWitnessDecl(zeroDecl)) + zeroDecl = witnessDecl; + assert(zeroDecl && "Member method declaration must exist"); + auto *zeroExpr = + new (C) MemberRefExpr(tangentTypeExpr, SourceLoc(), zeroDecl, + DeclNameLoc(), /*Implicit*/ true); + + // Create closure expression. + DiscriminatorFinder DF; + for (Decl *D : parentDC->getParentSourceFile()->getTopLevelDecls()) + D->walk(DF); + auto discriminator = DF.getNextDiscriminator(); + auto resultTy = funcDecl->getMethodInterfaceType() + ->castTo() + ->getResult(); + + auto *closureParams = ParameterList::createEmpty(C); + auto *closure = new (C) ClosureExpr( + SourceRange(), /*capturedSelfDecl*/ nullptr, closureParams, SourceLoc(), + SourceLoc(), SourceLoc(), TypeExpr::createImplicit(resultTy, C), + discriminator, funcDecl); + closure->setImplicit(); + auto *closureReturn = new (C) ReturnStmt(SourceLoc(), zeroExpr, true); + auto *closureBody = + BraceStmt::create(C, SourceLoc(), {closureReturn}, SourceLoc(), true); + closure->setBody(closureBody, /*isSingleExpression=*/true); + + ASTNode returnStmt = new (C) ReturnStmt(SourceLoc(), closure, true); + auto *braceStmt = + BraceStmt::create(C, SourceLoc(), returnStmt, SourceLoc(), true); + return std::pair(braceStmt, false); + } + + // Otherwise, perform memberwise derivation. + // Get effective memberwise initializer: `Nominal.init(...)`. + auto *tangentTypeDecl = tangentType->getAnyNominal(); + auto *memberwiseInitDecl = + tangentTypeDecl->getEffectiveMemberwiseInitializer(); + assert(memberwiseInitDecl && "Memberwise initializer must exist"); + auto *initDRE = + new (C) DeclRefExpr(memberwiseInitDecl, DeclNameLoc(), /*Implicit*/ true); + initDRE->setFunctionRefKind(FunctionRefKind::SingleApply); + auto *initExpr = new (C) ConstructorRefCallExpr(initDRE, tangentTypeExpr); + + // Get references to `self` and parameter declarations. + auto *selfDecl = funcDecl->getImplicitSelfDecl(); + + // Create `self..zeroTangentVectorInitializer` capture list entry. + auto createMemberZeroTanInitCaptureListEntry = + [&](VarDecl *member) -> CaptureListEntry { + // Create `_zeroTangentVectorInitializer` capture var declaration. + auto memberCaptureName = C.getIdentifier(std::string(member->getNameStr()) + + "_zeroTangentVectorInitializer"); + auto *memberZeroTanInitCaptureDecl = new (C) VarDecl( + /*isStatic*/ false, VarDecl::Introducer::Let, /*isCaptureList*/ true, + SourceLoc(), memberCaptureName, funcDecl); + memberZeroTanInitCaptureDecl->setImplicit(); + auto *memberZeroTanInitPattern = + NamedPattern::createImplicit(C, memberZeroTanInitCaptureDecl); + + auto *module = nominal->getModuleContext(); + auto memberType = + parentDC->mapTypeIntoContext(member->getValueInterfaceType()); + auto confRef = module->lookupConformance(memberType, diffProto); + assert(confRef && "Member does not conform to `Differentiable`"); + + // Get member type's `zeroTangentVectorInitializer` requirement witness. + ValueDecl *memberWitnessDecl = requirement; + if (confRef.isConcrete()) + if (auto *witness = confRef.getConcrete()->getWitnessDecl(requirement)) + memberWitnessDecl = witness; + assert(memberWitnessDecl && "Member witness declaration must exist"); + + // .zeroTangentVectorInitializer + auto *selfDRE = + new (C) DeclRefExpr(selfDecl, DeclNameLoc(), /*Implicit*/ true); + auto *memberExpr = + new (C) MemberRefExpr(selfDRE, SourceLoc(), member, DeclNameLoc(), + /*Implicit*/ true); + auto *memberZeroTangentVectorInitExpr = + new (C) MemberRefExpr(memberExpr, SourceLoc(), memberWitnessDecl, + DeclNameLoc(), /*Implicit*/ true); + auto *memberZeroTanInitPBD = PatternBindingDecl::createImplicit( + C, StaticSpellingKind::None, memberZeroTanInitPattern, + memberZeroTangentVectorInitExpr, funcDecl); + CaptureListEntry captureEntry(memberZeroTanInitCaptureDecl, + memberZeroTanInitPBD); + return captureEntry; + }; + + // Create `_zeroTangentVectorInitializer()` call expression. + auto createMemberZeroTanInitCallExpr = + [&](CaptureListEntry memberZeroTanInitEntry) -> Expr * { + // _zeroTangentVectorInitializer + auto *memberZeroTanInitDRE = new (C) DeclRefExpr( + memberZeroTanInitEntry.Var, DeclNameLoc(), /*Implicit*/ true); + // _zeroTangentVectorInitializer() + auto *memberZeroTangentVector = + CallExpr::createImplicit(C, memberZeroTanInitDRE, {}, {}); + return memberZeroTangentVector; + }; + + // Collect member zero tangent vector expressions. + SmallVector memberNames; + SmallVector memberZeroTanExprs; + SmallVector memberZeroTanInitCaptures; + for (auto *member : diffProperties) { + memberNames.push_back(member->getName()); + auto memberZeroTanInitCapture = + createMemberZeroTanInitCaptureListEntry(member); + memberZeroTanInitCaptures.push_back(memberZeroTanInitCapture); + memberZeroTanExprs.push_back( + createMemberZeroTanInitCallExpr(memberZeroTanInitCapture)); + } + + // Create `zeroTangentVectorInitializer` closure body: + // `TangentVector(x: x_zeroTangentVectorInitializer(), ...)`. + auto *callExpr = + CallExpr::createImplicit(C, initExpr, memberZeroTanExprs, memberNames); + + // Create closure expression: + // `{ TangentVector(x: x_zeroTangentVectorInitializer(), ...) }`. + DiscriminatorFinder DF; + for (Decl *D : parentDC->getParentSourceFile()->getTopLevelDecls()) + D->walk(DF); + auto discriminator = DF.getNextDiscriminator(); + auto resultTy = funcDecl->getMethodInterfaceType() + ->castTo() + ->getResult(); + auto *closureParams = ParameterList::createEmpty(C); + auto *closure = new (C) ClosureExpr( + SourceRange(), /*capturedSelfDecl*/ nullptr, closureParams, SourceLoc(), + SourceLoc(), SourceLoc(), TypeExpr::createImplicit(resultTy, C), + discriminator, funcDecl); + closure->setImplicit(); + auto *closureReturn = new (C) ReturnStmt(SourceLoc(), callExpr, true); + auto *closureBody = + BraceStmt::create(C, SourceLoc(), {closureReturn}, SourceLoc(), true); + closure->setBody(closureBody, /*isSingleExpression=*/true); + + // Create capture list expression: + // ``` + // { [x_zeroTangentVectorInitializer = x.zeroTangentVectorInitializer, ...] in + // TangentVector(x: x_zeroTangentVectorInitializer(), ...) + // } + // ``` + auto *captureList = + CaptureListExpr::create(C, memberZeroTanInitCaptures, closure); + captureList->setImplicit(); + + ASTNode returnStmt = new (C) ReturnStmt(SourceLoc(), captureList, true); + auto *braceStmt = + BraceStmt::create(C, SourceLoc(), returnStmt, SourceLoc(), true); + return std::pair(braceStmt, false); } /// Synthesize function declaration for a `Differentiable` method requirement. @@ -316,15 +572,41 @@ static ValueDecl *deriveDifferentiable_method( static ValueDecl *deriveDifferentiable_move(DerivedConformance &derived) { auto &C = derived.Context; auto *parentDC = derived.getConformanceContext(); - - auto *tangentDecl = getTangentVectorStructDecl(parentDC); - auto tangentType = tangentDecl->getDeclaredInterfaceType(); - + auto tangentType = + getTangentVectorInterfaceType(parentDC->getSelfTypeInContext(), parentDC); return deriveDifferentiable_method( derived, C.Id_move, C.Id_along, C.Id_direction, tangentType, C.TheEmptyTupleType, {deriveBodyDifferentiable_move, nullptr}); } +/// Synthesize the `zeroTangentVectorInitializer` computed property declaration. +static ValueDecl * +deriveDifferentiable_zeroTangentVectorInitializer(DerivedConformance &derived) { + auto &C = derived.Context; + auto *parentDC = derived.getConformanceContext(); + + auto tangentType = + getTangentVectorInterfaceType(parentDC->getSelfTypeInContext(), parentDC); + auto returnType = FunctionType::get({}, tangentType); + + VarDecl *propDecl; + PatternBindingDecl *pbDecl; + std::tie(propDecl, pbDecl) = derived.declareDerivedProperty( + C.Id_zeroTangentVectorInitializer, returnType, returnType, + /*isStatic*/ false, /*isFinal*/ true); + + // Define the getter. + auto *getterDecl = + derived.addGetterToReadOnlyDerivedProperty(propDecl, returnType); + // Add an implicit `@noDerivative` attribute. + // `zeroTangentVectorInitializer` getter calls should never be differentiated. + getterDecl->getAttrs().add(new (C) NoDerivativeAttr(/*Implicit*/ true)); + getterDecl->setBodySynthesizer( + &deriveBodyDifferentiable_zeroTangentVectorInitializer); + derived.addMembersToConformanceContext({propDecl, pbDecl}); + return propDecl; +} + /// Return associated `TangentVector` struct for a nominal type, if it exists. /// If not, synthesize the struct. static StructDecl * @@ -368,24 +650,22 @@ getOrSynthesizeTangentVectorStruct(DerivedConformance &derived, Identifier id) { for (auto *member : diffProperties) { // Add this member's corresponding `TangentVector` type to the parent's // `TangentVector` struct. + // Note: `newMember` is not marked as implicit here, because that + // incorrectly affects memberwise initializer synthesis. auto *newMember = new (C) VarDecl( member->isStatic(), member->getIntroducer(), member->isCaptureList(), /*NameLoc*/ SourceLoc(), member->getName(), structDecl); - // NOTE: `newMember` is not marked as implicit here, because that affects - // memberwise initializer synthesis. - - auto memberAssocType = getTangentVectorType(member, parentDC); - auto memberAssocInterfaceType = memberAssocType->hasArchetype() - ? memberAssocType->mapTypeOutOfContext() - : memberAssocType; - auto memberAssocContextualType = - parentDC->mapTypeIntoContext(memberAssocInterfaceType); - newMember->setInterfaceType(memberAssocInterfaceType); + + auto memberContextualType = + parentDC->mapTypeIntoContext(member->getValueInterfaceType()); + auto memberTanType = + getTangentVectorInterfaceType(memberContextualType, parentDC); + newMember->setInterfaceType(memberTanType); Pattern *memberPattern = NamedPattern::createImplicit(C, newMember); - memberPattern->setType(memberAssocContextualType); - memberPattern = TypedPattern::createImplicit(C, memberPattern, - memberAssocContextualType); - memberPattern->setType(memberAssocContextualType); + memberPattern->setType(memberTanType); + memberPattern = + TypedPattern::createImplicit(C, memberPattern, memberTanType); + memberPattern->setType(memberTanType); auto *memberBinding = PatternBindingDecl::createImplicit( C, StaticSpellingKind::None, memberPattern, /*initExpr*/ nullptr, structDecl); @@ -582,13 +862,6 @@ getOrSynthesizeTangentVectorStructType(DerivedConformance &derived) { addAssociatedTypeAliasDecl(C.Id_TangentVector, tangentStruct, tangentStruct, C); - // Sanity checks for synthesized struct. - assert(DerivedConformance::canDeriveAdditiveArithmetic(tangentStruct, - parentDC) && - "Should be able to derive `AdditiveArithmetic`"); - assert(DerivedConformance::canDeriveDifferentiable(tangentStruct, parentDC) && - "Should be able to derive `Differentiable`"); - // Return the `TangentVector` struct type. return parentDC->mapTypeIntoContext( tangentStruct->getDeclaredInterfaceType()); @@ -599,82 +872,75 @@ static Type deriveDifferentiable_TangentVectorStruct(DerivedConformance &derived) { auto *parentDC = derived.getConformanceContext(); auto *nominal = derived.Nominal; - auto &C = nominal->getASTContext(); - - // Get all stored properties for differentation. - SmallVector diffProperties; - getStoredPropertiesForDifferentiation(nominal, parentDC, diffProperties); - // If any member has an invalid `TangentVector` type, return nullptr. - for (auto *member : diffProperties) - if (!getTangentVectorType(member, parentDC)) - return nullptr; - - // Prevent re-synthesis during repeated calls. - // FIXME: Investigate why this is necessary to prevent duplicate synthesis. - auto lookup = nominal->lookupDirect(C.Id_TangentVector); - if (lookup.size() == 1) - if (auto *structDecl = convertToStructDecl(lookup.front())) - if (structDecl->isImplicit()) - return structDecl->getDeclaredInterfaceType(); - - // Check whether at least one `@noDerivative` stored property exists. - unsigned numStoredProperties = - std::distance(nominal->getStoredProperties().begin(), - nominal->getStoredProperties().end()); - bool hasNoDerivativeStoredProp = diffProperties.size() != numStoredProperties; - - // Check conditions for returning `Self`. - // - `Self` is not a class type. - // - No `@noDerivative` stored properties exist. - // - All stored properties must have `TangentVector` type equal to `Self`. - // - Parent type must also conform to `AdditiveArithmetic`. - bool allMembersAssocTypeEqualsSelf = - llvm::all_of(diffProperties, [&](VarDecl *member) { - auto memberAssocType = getTangentVectorType(member, parentDC); - return member->getType()->isEqual(memberAssocType); - }); - - auto *addArithProto = C.getProtocol(KnownProtocolKind::AdditiveArithmetic); - auto nominalConformsToAddArith = TypeChecker::conformsToProtocol( - parentDC->getSelfTypeInContext(), addArithProto, parentDC); - - // Return `Self` if conditions are met. - if (!hasNoDerivativeStoredProp && !nominal->getSelfClassDecl() && - allMembersAssocTypeEqualsSelf && nominalConformsToAddArith) { - auto selfType = parentDC->getSelfTypeInContext(); - auto *aliasDecl = - new (C) TypeAliasDecl(SourceLoc(), SourceLoc(), C.Id_TangentVector, - SourceLoc(), {}, parentDC); - aliasDecl->setUnderlyingType(selfType); - aliasDecl->setImplicit(); - aliasDecl->copyFormalAccessFrom(nominal, /*sourceIsParentContext*/ true); - derived.addMembersToConformanceContext({aliasDecl}); - return selfType; - } + // If nominal type can derive `TangentVector` as the contextual `Self` type, + // return it. + if (canDeriveTangentVectorAsSelf(nominal, parentDC)) + return parentDC->getSelfTypeInContext(); // Otherwise, get or synthesize `TangentVector` struct type. return getOrSynthesizeTangentVectorStructType(derived); } ValueDecl *DerivedConformance::deriveDifferentiable(ValueDecl *requirement) { + // Diagnose unknown requirements. + if (requirement->getBaseName() != Context.Id_move && + requirement->getBaseName() != Context.Id_zeroTangentVectorInitializer) { + Context.Diags.diagnose(requirement->getLoc(), + diag::broken_differentiable_requirement); + return nullptr; + } // Diagnose conformances in disallowed contexts. if (checkAndDiagnoseDisallowedContext(requirement)) return nullptr; - if (requirement->getBaseName() == Context.Id_move) - return deriveDifferentiable_move(*this); - Context.Diags.diagnose(requirement->getLoc(), - diag::broken_differentiable_requirement); + + // Start an error diagnostic before attempting derivation. + // If derivation succeeds, cancel the diagnostic. + DiagnosticTransaction diagnosticTransaction(Context.Diags); + ConformanceDecl->diagnose(diag::type_does_not_conform, + Nominal->getDeclaredType(), getProtocolType()); + requirement->diagnose(diag::no_witnesses, + getProtocolRequirementKind(requirement), + requirement->getName(), getProtocolType(), + /*AddFixIt=*/false); + + // If derivation is possible, cancel the diagnostic and perform derivation. + if (canDeriveDifferentiable(Nominal, getConformanceContext(), requirement)) { + diagnosticTransaction.abort(); + if (requirement->getBaseName() == Context.Id_move) + return deriveDifferentiable_move(*this); + if (requirement->getBaseName() == Context.Id_zeroTangentVectorInitializer) + return deriveDifferentiable_zeroTangentVectorInitializer(*this); + } + + // Otheriwse, return nullptr. return nullptr; } Type DerivedConformance::deriveDifferentiable(AssociatedTypeDecl *requirement) { + // Diagnose unknown requirements. + if (requirement->getBaseName() != Context.Id_TangentVector) { + Context.Diags.diagnose(requirement->getLoc(), + diag::broken_differentiable_requirement); + return nullptr; + } // Diagnose conformances in disallowed contexts. if (checkAndDiagnoseDisallowedContext(requirement)) return nullptr; - if (requirement->getBaseName() == Context.Id_TangentVector) + + // Start an error diagnostic before attempting derivation. + // If derivation succeeds, cancel the diagnostic. + DiagnosticTransaction diagnosticTransaction(Context.Diags); + ConformanceDecl->diagnose(diag::type_does_not_conform, + Nominal->getDeclaredType(), getProtocolType()); + requirement->diagnose(diag::no_witnesses_type, requirement->getName()); + + // If derivation is possible, cancel the diagnostic and perform derivation. + if (canDeriveDifferentiable(Nominal, getConformanceContext(), requirement)) { + diagnosticTransaction.abort(); return deriveDifferentiable_TangentVectorStruct(*this); - Context.Diags.diagnose(requirement->getLoc(), - diag::broken_differentiable_requirement); + } + + // Otherwise, return nullptr. return nullptr; } diff --git a/lib/Sema/DerivedConformances.cpp b/lib/Sema/DerivedConformances.cpp index 771fcf01fd4be..46f6ad101ce1d 100644 --- a/lib/Sema/DerivedConformances.cpp +++ b/lib/Sema/DerivedConformances.cpp @@ -74,8 +74,11 @@ bool DerivedConformance::derivesProtocolConformance(DeclContext *DC, if (*derivableKind == KnownDerivableProtocolKind::AdditiveArithmetic) return canDeriveAdditiveArithmetic(Nominal, DC); + // Eagerly return true here. Actual synthesis conditions are checked in + // `DerivedConformance::deriveDifferentiable`: they are complicated and depend + // on the requirement being derived. if (*derivableKind == KnownDerivableProtocolKind::Differentiable) - return canDeriveDifferentiable(Nominal, DC); + return true; if (auto *enumDecl = dyn_cast(Nominal)) { switch (*derivableKind) { @@ -227,6 +230,10 @@ ValueDecl *DerivedConformance::getDerivableRequirement(NominalTypeDecl *nominal, if (name.isSimpleName(ctx.Id_intValue)) return getRequirement(KnownProtocolKind::CodingKey); + // Differentiable.zeroTangentVectorInitializer + if (name.isSimpleName(ctx.Id_zeroTangentVectorInitializer)) + return getRequirement(KnownProtocolKind::Differentiable); + // AdditiveArithmetic.zero if (name.isSimpleName(ctx.Id_zero)) return getRequirement(KnownProtocolKind::AdditiveArithmetic); diff --git a/lib/Sema/DerivedConformances.h b/lib/Sema/DerivedConformances.h index 0073f8edbbfd3..1f6296ee84fcb 100644 --- a/lib/Sema/DerivedConformances.h +++ b/lib/Sema/DerivedConformances.h @@ -107,10 +107,12 @@ class DerivedConformance { /// \returns the derived member, which will also be added to the type. ValueDecl *deriveAdditiveArithmetic(ValueDecl *requirement); - /// Determine if a Differentiable requirement can be derived for a type. + /// Determine if a Differentiable requirement can be derived for a nominal + /// type. /// /// \returns True if the requirement can be derived. - static bool canDeriveDifferentiable(NominalTypeDecl *type, DeclContext *DC); + static bool canDeriveDifferentiable(NominalTypeDecl *type, DeclContext *DC, + ValueDecl *requirement); /// Derive a Differentiable requirement for a nominal type. /// diff --git a/lib/Sema/TypeCheckProtocol.cpp b/lib/Sema/TypeCheckProtocol.cpp index de0c2a454d382..2eb55f5da0b62 100644 --- a/lib/Sema/TypeCheckProtocol.cpp +++ b/lib/Sema/TypeCheckProtocol.cpp @@ -2039,19 +2039,17 @@ static Type getRequirementTypeForDisplay(ModuleDecl *module, return substType(type, /*result*/false); } -/// Retrieve the kind of requirement described by the given declaration, -/// for use in some diagnostics. -static diag::RequirementKind getRequirementKind(ValueDecl *VD) { - if (isa(VD)) - return diag::RequirementKind::Constructor; +diag::RequirementKind +swift::getProtocolRequirementKind(ValueDecl *Requirement) { + assert(Requirement->isProtocolRequirement()); - if (isa(VD)) + if (isa(Requirement)) + return diag::RequirementKind::Constructor; + if (isa(Requirement)) return diag::RequirementKind::Func; - - if (isa(VD)) + if (isa(Requirement)) return diag::RequirementKind::Var; - - assert(isa(VD) && "Unhandled requirement kind"); + assert(isa(Requirement) && "Unhandled requirement kind"); return diag::RequirementKind::Subscript; } @@ -2254,7 +2252,7 @@ diagnoseMatch(ModuleDecl *module, NormalProtocolConformance *conformance, case MatchKind::KindConflict: diags.diagnose(match.Witness, diag::protocol_witness_kind_conflict, - getRequirementKind(req)); + getProtocolRequirementKind(req)); break; case MatchKind::WitnessInvalid: @@ -3053,13 +3051,14 @@ diagnoseMissingWitnesses(MissingWitnessDiagnosisKind Kind) { // If the protocol member decl is in the same file of the stub, // we can directly associate the fixit with the note issued to the // requirement. - Diags.diagnose(VD, diag::no_witnesses, getRequirementKind(VD), - VD->getName(), RequirementType, true) + Diags + .diagnose(VD, diag::no_witnesses, getProtocolRequirementKind(VD), + VD->getName(), RequirementType, true) .fixItInsertAfter(FixitLocation, FixIt); } else { // Otherwise, we have to issue another note to carry the fixit, // because editor may assume the fixit is in the same file with the note. - Diags.diagnose(VD, diag::no_witnesses, getRequirementKind(VD), + Diags.diagnose(VD, diag::no_witnesses, getProtocolRequirementKind(VD), VD->getName(), RequirementType, false); if (EditorMode) { Diags.diagnose(ComplainLoc, diag::missing_witnesses_general) @@ -3067,7 +3066,7 @@ diagnoseMissingWitnesses(MissingWitnessDiagnosisKind Kind) { } } } else { - Diags.diagnose(VD, diag::no_witnesses, getRequirementKind(VD), + Diags.diagnose(VD, diag::no_witnesses, getProtocolRequirementKind(VD), VD->getName(), RequirementType, true); } } @@ -3425,11 +3424,8 @@ ConformanceChecker::resolveWitnessViaLookup(ValueDecl *requirement) { auto &diags = DC->getASTContext().Diags; diags.diagnose(getLocForDiagnosingWitness(conformance, witness), - diagKind, - getRequirementKind(requirement), - witness->getName(), - isSetter, - requiredAccess, + diagKind, getProtocolRequirementKind(requirement), + witness->getName(), isSetter, requiredAccess, protoAccessScope.accessLevelForDiagnostics(), proto->getName()); if (auto *decl = dyn_cast(witness)) { @@ -3619,9 +3615,8 @@ ConformanceChecker::resolveWitnessViaLookup(ValueDecl *requirement) { diagnosticMessage = diag::ambiguous_witnesses_wrong_name; } diags.diagnose(requirement, diagnosticMessage, - getRequirementKind(requirement), - requirement->getName(), - reqType); + getProtocolRequirementKind(requirement), + requirement->getName(), reqType); // Diagnose each of the matches. for (const auto &match : matches) diff --git a/lib/Sema/TypeChecker.h b/lib/Sema/TypeChecker.h index ab1e52a13eeb6..cbae4220752e6 100644 --- a/lib/Sema/TypeChecker.h +++ b/lib/Sema/TypeChecker.h @@ -1319,6 +1319,12 @@ class EncodedDiagnosticMessage { const StringRef Message; }; +/// Returns the protocol requirement kind of the given declaration. +/// Used in diagnostics. +/// +/// Asserts that the given declaration is a protocol requirement. +diag::RequirementKind getProtocolRequirementKind(ValueDecl *Requirement); + /// Returns true if the given method is an valid implementation of a /// @dynamicCallable attribute requirement. The method is given to be defined /// as one of the following: `dynamicallyCall(withArguments:)` or diff --git a/stdlib/public/Differentiation/AnyDifferentiable.swift b/stdlib/public/Differentiation/AnyDifferentiable.swift index 421212f90037f..3a116104256c0 100644 --- a/stdlib/public/Differentiation/AnyDifferentiable.swift +++ b/stdlib/public/Differentiation/AnyDifferentiable.swift @@ -24,6 +24,7 @@ import Swift internal protocol _AnyDifferentiableBox { // `Differentiable` requirements. mutating func _move(along direction: AnyDerivative) + var _zeroTangentVectorInitializer: () -> AnyDerivative { get } /// The underlying base value, type-erased to `Any`. var _typeErasedBase: Any { get } @@ -59,6 +60,10 @@ internal struct _ConcreteDifferentiableBox: _AnyDifferentiabl } _base.move(along: directionBase) } + + var _zeroTangentVectorInitializer: () -> AnyDerivative { + { AnyDerivative(_base.zeroTangentVector) } + } } public struct AnyDifferentiable: Differentiable { @@ -103,6 +108,10 @@ public struct AnyDifferentiable: Differentiable { public mutating func move(along direction: TangentVector) { _box._move(along: direction) } + + public var zeroTangentVectorInitializer: () -> TangentVector { + _box._zeroTangentVectorInitializer + } } //===----------------------------------------------------------------------===// diff --git a/stdlib/public/Differentiation/ArrayDifferentiation.swift b/stdlib/public/Differentiation/ArrayDifferentiation.swift index cd1f20e308798..fbaef9c34fe80 100644 --- a/stdlib/public/Differentiation/ArrayDifferentiation.swift +++ b/stdlib/public/Differentiation/ArrayDifferentiation.swift @@ -168,8 +168,8 @@ extension Array: Differentiable where Element: Differentiable { /// A closure that produces a `TangentVector` of zeros with the same /// `count` as `self`. public var zeroTangentVectorInitializer: () -> TangentVector { - { [count = self.count] in - TangentVector(.init(repeating: .zero, count: count)) + { [zeroInits = map(\.zeroTangentVectorInitializer)] in + TangentVector(zeroInits.map { $0() }) } } } diff --git a/stdlib/public/Differentiation/Differentiable.swift b/stdlib/public/Differentiation/Differentiable.swift index 077144e40f1e9..1341e034811ef 100644 --- a/stdlib/public/Differentiation/Differentiable.swift +++ b/stdlib/public/Differentiation/Differentiable.swift @@ -80,21 +80,6 @@ public extension Differentiable where TangentVector == Self { } public extension Differentiable { - // This is a temporary solution enabling the addition of - // `zeroTangentVectorInitializer` without implementing derived conformances. - // This property will produce incorrect results when tangent vectors depend - // on instance-specific information from `self`. - // TODO: Implement derived conformances and remove this default - // implementation. - @available(*, deprecated, message: """ - `zeroTangentVectorInitializer` derivation has not been implemented; this \ - default implementation is not correct when tangent vectors depend on \ - instance-specific information from `self` and should not be used - """) - var zeroTangentVectorInitializer: () -> TangentVector { - { TangentVector.zero } - } - /// A tangent vector initialized using `zeroTangentVectorInitializer`. /// `move(along: zeroTangentVector)` should not modify `self`. var zeroTangentVector: TangentVector { zeroTangentVectorInitializer() } diff --git a/stdlib/public/Differentiation/FloatingPointDifferentiation.swift.gyb b/stdlib/public/Differentiation/FloatingPointDifferentiation.swift.gyb index e4d4ab68350a1..c4ea3fb01f06c 100644 --- a/stdlib/public/Differentiation/FloatingPointDifferentiation.swift.gyb +++ b/stdlib/public/Differentiation/FloatingPointDifferentiation.swift.gyb @@ -42,6 +42,11 @@ extension ${Self}: Differentiable { public mutating func move(along direction: TangentVector) { self += direction } + + @inlinable + public var zeroTangentVectorInitializer: () -> TangentVector { + { 0 } + } } //===----------------------------------------------------------------------===// diff --git a/stdlib/public/Differentiation/SIMDDifferentiation.swift.gyb b/stdlib/public/Differentiation/SIMDDifferentiation.swift.gyb index d70b7201d722d..a60ab8dd461ed 100644 --- a/stdlib/public/Differentiation/SIMDDifferentiation.swift.gyb +++ b/stdlib/public/Differentiation/SIMDDifferentiation.swift.gyb @@ -31,6 +31,11 @@ where Scalar.TangentVector: BinaryFloatingPoint { public typealias TangentVector = SIMD${n} + + @inlinable + public var zeroTangentVectorInitializer: () -> TangentVector { + { .init(repeating: 0) } + } } //===----------------------------------------------------------------------===// diff --git a/test/AutoDiff/Sema/DerivativeRegistrationCrossModule/Inputs/b.swift b/test/AutoDiff/Sema/DerivativeRegistrationCrossModule/Inputs/b.swift index 3d8acb8d2979c..90244a1973c87 100644 --- a/test/AutoDiff/Sema/DerivativeRegistrationCrossModule/Inputs/b.swift +++ b/test/AutoDiff/Sema/DerivativeRegistrationCrossModule/Inputs/b.swift @@ -4,6 +4,7 @@ import a extension Struct: Differentiable { public struct TangentVector: Differentiable & AdditiveArithmetic {} public mutating func move(along _: TangentVector) {} + public var zeroTangentVectorInitializer: () -> TangentVector { { .zero } } @usableFromInline @derivative(of: method, wrt: x) diff --git a/test/AutoDiff/Sema/DerivedConformances/derived_zero_tangent_vector_initializer.swift b/test/AutoDiff/Sema/DerivedConformances/derived_zero_tangent_vector_initializer.swift new file mode 100644 index 0000000000000..98ac63b275aaf --- /dev/null +++ b/test/AutoDiff/Sema/DerivedConformances/derived_zero_tangent_vector_initializer.swift @@ -0,0 +1,249 @@ +// RUN: %target-swift-frontend -emit-silgen %s | %FileCheck %s + +// Check `Differentiable.zeroTangentVectorInitializer` derivation. +// +// There are two cases: +// 1. Memberwise derivation. +// +// var zeroTangentVectorInitializer: () -> TangentVector { +// { [xZeroTanInit = x.zeroTangentVectorInitializer, +// yZeroTanInit = y.zeroTangentVectorInitializer, ...] in +// return TangentVector(x: xZeroTanInit(), y: yZeroTanInit(), ...) +// } +// } +// +// 2. `{ TangentVector.zero }` fallback derivation. +// +// var zeroTangentVectorInitializer: () -> TangentVector { +// { TangentVector.zero } +// } + +import _Differentiation + +// - MARK: Structs + +struct MemberwiseTangentVectorStruct: Differentiable { + var x: Float + var y: Double + + // Expected memberwise `zeroTangentVectorInitializer` synthesis (1). +} + +struct SelfTangentVectorStruct: Differentiable & AdditiveArithmetic { + var x: Float + var y: Double + typealias TangentVector = Self + + // Expected memberwise `zeroTangentVectorInitializer` synthesis (1). +} + +struct CustomTangentVectorStruct: Differentiable { + var x: T + var y: U + + typealias TangentVector = T.TangentVector + mutating func move(along direction: TangentVector) {} + + // Expected fallback `zeroTangentVectorInitializer` synthesis (2). +} + +// - MARK: Classes + +class MemberwiseTangentVectorClass: Differentiable { + var x: Float = 0.0 + var y: Double = 0.0 + + // Expected memberwise `zeroTangentVectorInitializer` synthesis (1). +} + +final class SelfTangentVectorClass: Differentiable & AdditiveArithmetic { + var x: Float = 0.0 + var y: Double = 0.0 + typealias TangentVector = SelfTangentVectorClass + + static func ==(lhs: SelfTangentVectorClass, rhs: SelfTangentVectorClass) -> Bool { fatalError() } + static var zero: Self { fatalError() } + static func +(lhs: SelfTangentVectorClass, rhs: SelfTangentVectorClass) -> Self { fatalError() } + static func -(lhs: SelfTangentVectorClass, rhs: SelfTangentVectorClass) -> Self { fatalError() } + + // Expected memberwise `zeroTangentVectorInitializer` synthesis (1). +} + +class CustomTangentVectorClass: Differentiable { + var x: T + var y: U + + init(x: T, y: U) { + self.x = x + self.y = y + } + + typealias TangentVector = T.TangentVector + func move(along direction: TangentVector) {} + + // Expected fallback `zeroTangentVectorInitializer` synthesis (2). +} + +// - MARK: Enums + +enum SelfTangentVectorEnum: Differentiable & AdditiveArithmetic { + case a([Float]) + case b([Float], Float) + case c + + typealias TangentVector = SelfTangentVectorEnum + + static func ==(lhs: Self, rhs: Self) -> Bool { fatalError() } + static var zero: Self { fatalError() } + static func +(lhs: Self, rhs: Self) -> Self { fatalError() } + static func -(lhs: Self, rhs: Self) -> Self { fatalError() } + + // TODO(TF-1012): Implement memberwise `zeroTangentVectorInitializer` synthesis for enums. + // Expected fallback `zeroTangentVectorInitializer` synthesis (2). +} + +enum CustomTangentVectorEnum: Differentiable { + case a(T) + + typealias TangentVector = T.TangentVector + mutating func move(along direction: TangentVector) {} + + // Expected fallback `zeroTangentVectorInitializer` synthesis (2). +} + +// CHECK-LABEL: // MemberwiseTangentVectorStruct.zeroTangentVectorInitializer.getter +// CHECK-NEXT: sil hidden [_semantics "autodiff.nonvarying"] [ossa] @${{.*}}MemberwiseTangentVectorStructV0bgH11InitializerAC0gH0Vycvg : $@convention(method) (MemberwiseTangentVectorStruct) -> @owned @callee_guaranteed () -> MemberwiseTangentVectorStruct.TangentVector { +// CHECK: bb0([[SELF:%.*]] : $MemberwiseTangentVectorStruct): +// CHECK: [[X_PROP:%.*]] = struct_extract [[SELF]] : $MemberwiseTangentVectorStruct, #MemberwiseTangentVectorStruct.x +// CHECK: [[X_ZERO_INIT_FN:%.*]] = function_ref @$sSf{{.*}}E28zeroTangentVectorInitializerSfycvg : $@convention(method) (Float) -> @owned @callee_guaranteed () -> Float +// CHECK: [[X_ZERO_INIT:%.*]] = apply [[X_ZERO_INIT_FN]]([[X_PROP]]) +// CHECK: [[Y_PROP:%.*]] = struct_extract [[SELF]] : $MemberwiseTangentVectorStruct, #MemberwiseTangentVectorStruct.y +// CHECK: [[Y_ZERO_INIT_FN:%.*]] = function_ref @$sSd{{.*}}E28zeroTangentVectorInitializerSdycvg : $@convention(method) (Double) -> @owned @callee_guaranteed () -> Double +// CHECK: [[Y_ZERO_INIT:%.*]] = apply [[Y_ZERO_INIT_FN]]([[Y_PROP]]) +// CHECK: // function_ref closure #1 in MemberwiseTangentVectorStruct.zeroTangentVectorInitializer.getter +// CHECK: [[CLOSURE_FN:%.*]] = function_ref @${{.*}}MemberwiseTangentVectorStructV0bgH11InitializerAC0gH0VycvgAFycfU_ +// CHECK: [[X_ZERO_INIT_COPY:%.*]] = copy_value [[X_ZERO_INIT]] +// CHECK: [[Y_ZERO_INIT_COPY:%.*]] = copy_value [[Y_ZERO_INIT]] +// CHECK: [[ZERO_INIT:%.*]] = partial_apply [callee_guaranteed] [[CLOSURE_FN]]([[X_ZERO_INIT_COPY]], [[Y_ZERO_INIT_COPY]]) +// CHECK: return [[ZERO_INIT]] : $@callee_guaranteed () -> MemberwiseTangentVectorStruct.TangentVector +// CHECK: } + +// CHECK-LABEL: // SelfTangentVectorStruct.zeroTangentVectorInitializer.getter +// CHECK-NEXT: sil hidden [_semantics "autodiff.nonvarying"] [ossa] @${{.*}}SelfTangentVectorStructV0bgH11InitializerACycvg : $@convention(method) (SelfTangentVectorStruct) -> @owned @callee_guaranteed () -> SelfTangentVectorStruct { +// CHECK: bb0([[SELF:%.*]] : $SelfTangentVectorStruct): +// CHECK: [[X_PROP:%.*]] = struct_extract [[SELF]] : $SelfTangentVectorStruct, #SelfTangentVectorStruct.x +// CHECK: [[X_ZERO_INIT_FN:%.*]] = function_ref @$sSf{{.*}}E28zeroTangentVectorInitializerSfycvg : $@convention(method) (Float) -> @owned @callee_guaranteed () -> Float +// CHECK: [[X_ZERO_INIT:%.*]] = apply [[X_ZERO_INIT_FN]]([[X_PROP]]) +// CHECK: [[Y_PROP:%.*]] = struct_extract [[SELF]] : $SelfTangentVectorStruct, #SelfTangentVectorStruct.y +// CHECK: [[Y_ZERO_INIT_FN:%.*]] = function_ref @$sSd{{.*}}E28zeroTangentVectorInitializerSdycvg : $@convention(method) (Double) -> @owned @callee_guaranteed () -> Double +// CHECK: [[Y_ZERO_INIT:%.*]] = apply [[Y_ZERO_INIT_FN]]([[Y_PROP]]) +// CHECK: // function_ref closure #2 in SelfTangentVectorStruct.zeroTangentVectorInitializer.getter +// CHECK: [[CLOSURE_FN:%.*]] = function_ref @${{.*}}SelfTangentVectorStructV0bgH11InitializerACycvgACycfU0_ +// CHECK: [[X_ZERO_INIT_COPY:%.*]] = copy_value [[X_ZERO_INIT]] +// CHECK: [[Y_ZERO_INIT_COPY:%.*]] = copy_value [[Y_ZERO_INIT]] +// CHECK: [[ZERO_INIT:%.*]] = partial_apply [callee_guaranteed] [[CLOSURE_FN]]([[X_ZERO_INIT_COPY]], [[Y_ZERO_INIT_COPY]]) +// CHECK: return [[ZERO_INIT]] : $@callee_guaranteed () -> SelfTangentVectorStruct +// CHECK: } + +// CHECK-LABEL: // CustomTangentVectorStruct.zeroTangentVectorInitializer.getter +// CHECK-NEXT: sil hidden [_semantics "autodiff.nonvarying"] [ossa] @${{.*}}CustomTangentVectorStructV0bgH11Initializer0gH0Qzycvg : $@convention(method) (@in_guaranteed CustomTangentVectorStruct) -> @owned @callee_guaranteed @substituted <Ï„_0_0> () -> @out Ï„_0_0 for { +// CHECK: bb0([[SELF:%.*]] : $*CustomTangentVectorStruct): +// CHECK: // function_ref closure #3 in CustomTangentVectorStruct.zeroTangentVectorInitializer.getter +// CHECK: function_ref @${{.*}}CustomTangentVectorStructV0bgH11Initializer0gH0QzycvgAFycfU1_ +// CHECK: } + +// CHECK-LABEL: // MemberwiseTangentVectorClass.zeroTangentVectorInitializer.getter +// CHECK-NEXT: sil hidden [_semantics "autodiff.nonvarying"] [ossa] @${{.*}}MemberwiseTangentVectorClassC0bgH11InitializerAC0gH0Vycvg : $@convention(method) (@guaranteed MemberwiseTangentVectorClass) -> @owned @callee_guaranteed () -> MemberwiseTangentVectorClass.TangentVector { +// CHECK: bb0([[SELF:%.*]] : @guaranteed $MemberwiseTangentVectorClass): +// CHECK: [[X_PROP_METHOD:%.*]] = class_method [[SELF]] : $MemberwiseTangentVectorClass, #MemberwiseTangentVectorClass.x!getter +// CHECK: [[X_PROP:%.*]] = apply [[X_PROP_METHOD]]([[SELF]]) +// CHECK: [[X_ZERO_INIT_FN:%.*]] = function_ref @$sSf{{.*}}E28zeroTangentVectorInitializerSfycvg : $@convention(method) (Float) -> @owned @callee_guaranteed () -> Float +// CHECK: [[X_ZERO_INIT:%.*]] = apply [[X_ZERO_INIT_FN]]([[X_PROP]]) +// CHECK: [[Y_PROP_METHOD:%.*]] = class_method [[SELF]] : $MemberwiseTangentVectorClass, #MemberwiseTangentVectorClass.y!getter +// CHECK: [[Y_PROP:%.*]] = apply [[Y_PROP_METHOD]]([[SELF]]) +// CHECK: [[Y_ZERO_INIT_FN:%.*]] = function_ref @$sSd{{.*}}E28zeroTangentVectorInitializerSdycvg : $@convention(method) (Double) -> @owned @callee_guaranteed () -> Double +// CHECK: [[Y_ZERO_INIT:%.*]] = apply [[Y_ZERO_INIT_FN]]([[Y_PROP]]) +// CHECK: // function_ref closure #4 in MemberwiseTangentVectorClass.zeroTangentVectorInitializer.getter +// CHECK: [[CLOSURE_FN:%.*]] = function_ref @${{.*}}MemberwiseTangentVectorClassC0bgH11InitializerAC0gH0VycvgAFycfU2_ +// CHECK: [[X_ZERO_INIT_COPY:%.*]] = copy_value [[X_ZERO_INIT]] +// CHECK: [[Y_ZERO_INIT_COPY:%.*]] = copy_value [[Y_ZERO_INIT]] +// CHECK: [[ZERO_INIT:%.*]] = partial_apply [callee_guaranteed] [[CLOSURE_FN]]([[X_ZERO_INIT_COPY]], [[Y_ZERO_INIT_COPY]]) +// CHECK: return [[ZERO_INIT]] : $@callee_guaranteed () -> MemberwiseTangentVectorClass.TangentVector +// CHECK: } + +// CHECK-LABEL: // SelfTangentVectorClass.zeroTangentVectorInitializer.getter +// CHECK-NEXT: sil hidden [_semantics "autodiff.nonvarying"] [ossa] @${{.*}}SelfTangentVectorClassC0bgH11InitializerACycvg : $@convention(method) (@guaranteed SelfTangentVectorClass) -> @owned @callee_guaranteed () -> @owned SelfTangentVectorClass { +// CHECK: bb0([[SELF:%.*]] : @guaranteed $SelfTangentVectorClass): +// CHECK: // function_ref closure #5 in SelfTangentVectorClass.zeroTangentVectorInitializer.getter +// CHECK: function_ref @${{.*}}SelfTangentVectorClassC0bgH11InitializerACycvgACycfU3_ +// CHECK: } + +// CHECK-LABEL: // CustomTangentVectorClass.zeroTangentVectorInitializer.getter +// CHECK-NEXT: sil hidden [_semantics "autodiff.nonvarying"] [ossa] @${{.*}}CustomTangentVectorClassC0bgH11Initializer0gH0Qzycvg : $@convention(method) (@guaranteed CustomTangentVectorClass) -> @owned @callee_guaranteed @substituted <Ï„_0_0> () -> @out Ï„_0_0 for { +// CHECK: bb0(%0 : @guaranteed $CustomTangentVectorClass): +// CHECK: // function_ref closure #6 in CustomTangentVectorClass.zeroTangentVectorInitializer.getter +// CHECK: function_ref @${{.*}}CustomTangentVectorClassC0bgH11Initializer0gH0QzycvgAFycfU4_ +// CHECK: } + +// CHECK-LABEL: // SelfTangentVectorEnum.zeroTangentVectorInitializer.getter +// CHECK-NEXT: sil hidden [_semantics "autodiff.nonvarying"] [ossa] @${{.*}}SelfTangentVectorEnumO0bgH11InitializerACycvg : $@convention(method) (@guaranteed SelfTangentVectorEnum) -> @owned @callee_guaranteed () -> @owned SelfTangentVectorEnum { +// CHECK: bb0([[SELF:%.*]] : @guaranteed $SelfTangentVectorEnum): +// CHECK: // function_ref closure #7 in SelfTangentVectorEnum.zeroTangentVectorInitializer.getter +// CHECK: function_ref @${{.*}}SelfTangentVectorEnumO0bgH11InitializerACycvgACycfU5_ +// CHECK: } + +// CHECK-LABEL: // CustomTangentVectorEnum.zeroTangentVectorInitializer.getter +// CHECK-NEXT: sil hidden [_semantics "autodiff.nonvarying"] [ossa] @${{.*}}CustomTangentVectorEnumO0bgH11Initializer0gH0Qzycvg : $@convention(method) (@in_guaranteed CustomTangentVectorEnum) -> @owned @callee_guaranteed @substituted <Ï„_0_0> () -> @out Ï„_0_0 for { +// CHECK: bb0([[SELF:%.*]] : $*CustomTangentVectorEnum): +// CHECK: // function_ref closure #8 in CustomTangentVectorEnum.zeroTangentVectorInitializer.getter +// CHECK: function_ref @${{.*}}CustomTangentVectorEnumO0bgH11Initializer0gH0QzycvgAFycfU6_ +// CHECK: } + +// CHECK-LABEL: // closure #1 in MemberwiseTangentVectorStruct.zeroTangentVectorInitializer.getter +// CHECK-NEXT: sil private [ossa] @${{.*}}MemberwiseTangentVectorStructV0bgH11InitializerAC0gH0VycvgAFycfU_ : $@convention(thin) (@guaranteed @callee_guaranteed () -> Float, @guaranteed @callee_guaranteed () -> Double) -> MemberwiseTangentVectorStruct.TangentVector { +// CHECK: // function_ref MemberwiseTangentVectorStruct.TangentVector.init(x:y:) +// CHECK-NOT: // function_ref static {{.*}}.zero.getter +// CHECK-NOT: witness_method {{.*}}, #AdditiveArithmetic.zero!getter +// CHECK: } + +// CHECK-LABEL: // closure #2 in SelfTangentVectorStruct.zeroTangentVectorInitializer.getter +// CHECK-NEXT: sil private [ossa] @${{.*}}SelfTangentVectorStructV0bgH11InitializerACycvgACycfU0_ : $@convention(thin) (@guaranteed @callee_guaranteed () -> Float, @guaranteed @callee_guaranteed () -> Double) -> SelfTangentVectorStruct { +// CHECK: // function_ref SelfTangentVectorStruct.init(x:y:) +// CHECK-NOT: // function_ref static {{.*}}.zero.getter +// CHECK-NOT: witness_method {{.*}}, #AdditiveArithmetic.zero!getter +// CHECK: } + +// CHECK-LABEL: // closure #3 in CustomTangentVectorStruct.zeroTangentVectorInitializer.getter +// CHECK-NEXT: sil private [ossa] @${{.*}}CustomTangentVectorStructV0bgH11Initializer0gH0QzycvgAFycfU1_ : $@convention(thin) () -> @out T.TangentVector { +// CHECK: witness_method $T.TangentVector, #AdditiveArithmetic.zero!getter +// CHECK: } + +// CHECK-LABEL: // closure #4 in MemberwiseTangentVectorClass.zeroTangentVectorInitializer.getter +// CHECK-NEXT: sil private [ossa] @${{.*}}MemberwiseTangentVectorClassC0bgH11InitializerAC0gH0VycvgAFycfU2_ : $@convention(thin) (@guaranteed @callee_guaranteed () -> Float, @guaranteed @callee_guaranteed () -> Double) -> MemberwiseTangentVectorClass.TangentVector { +// CHECK: // function_ref MemberwiseTangentVectorClass.TangentVector.init(x:y:) +// CHECK-NOT: // function_ref static {{.*}}.zero.getter +// CHECK-NOT: witness_method {{.*}}, #AdditiveArithmetic.zero!getter +// CHECK: } + +// CHECK-LABEL: // closure #5 in SelfTangentVectorClass.zeroTangentVectorInitializer.getter +// CHECK-NEXT: sil private [ossa] @${{.*}}SelfTangentVectorClassC0bgH11InitializerACycvgACycfU3_ : $@convention(thin) () -> @owned SelfTangentVectorClass { +// CHECK: // function_ref static SelfTangentVectorClass.zero.getter +// CHECK: function_ref @${{.*}}SelfTangentVectorClassC0B0ACXDvgZ : $@convention(method) (@thick SelfTangentVectorClass.Type) -> @owned SelfTangentVectorClass +// CHECK: } + +// CHECK-LABEL: // closure #6 in CustomTangentVectorClass.zeroTangentVectorInitializer.getter +// CHECK-NEXT: sil private [ossa] @${{.*}}CustomTangentVectorClassC0bgH11Initializer0gH0QzycvgAFycfU4_ : $@convention(thin) () -> @out T.TangentVector { +// CHECK: witness_method $T.TangentVector, #AdditiveArithmetic.zero!getter +// CHECK: } + +// TODO(TF-1012): Implement memberwise `zeroTangentVectorInitializer` synthesis for enums. +// CHECK-LABEL: // closure #7 in SelfTangentVectorEnum.zeroTangentVectorInitializer.getter +// CHECK-NEXT: sil private [ossa] @${{.*}}SelfTangentVectorEnumO0bgH11InitializerACycvgACycfU5_ : $@convention(thin) () -> @owned SelfTangentVectorEnum { +// CHECK: // function_ref static SelfTangentVectorEnum.zero.getter +// CHECK: function_ref @${{.*}}SelfTangentVectorEnumO0B0ACvgZ : $@convention(method) (@thin SelfTangentVectorEnum.Type) -> @owned SelfTangentVectorEnum +// CHECK: } + +// CHECK-LABEL: // closure #8 in CustomTangentVectorEnum.zeroTangentVectorInitializer.getter +// CHECK-NEXT: sil private [ossa] @$s39derived_zero_tangent_vector_initializer23CustomTangentVectorEnumO0bgH11Initializer0gH0QzycvgAFycfU6_ : $@convention(thin) () -> @out T.TangentVector { +// CHECK: witness_method $T.TangentVector, #AdditiveArithmetic.zero!getter +// CHECK: } diff --git a/test/AutoDiff/validation-test/zero_tangent_vector_initializer.swift b/test/AutoDiff/validation-test/zero_tangent_vector_initializer.swift new file mode 100644 index 0000000000000..05b8c8bb514fb --- /dev/null +++ b/test/AutoDiff/validation-test/zero_tangent_vector_initializer.swift @@ -0,0 +1,59 @@ +// RUN: %target-run-simple-swift +// REQUIRES: executable_test + +import _Differentiation +import StdlibUnittest + +var ZeroTangentVectorTests = TestSuite("zeroTangentVectorInitializer") + +struct Generic: Differentiable { + var x: T + var y: U +} + +struct Nested: Differentiable { + var generic: Generic +} + +ZeroTangentVectorTests.test("Derivation") { + typealias G = Generic<[Float], [[Float]]> + + let generic = G(x: [1, 2, 3], y: [[4, 5, 6], [], [2]]) + let genericZero = G.TangentVector(x: [0, 0, 0], y: [[0, 0, 0], [], [0]]) + expectEqual(generic.zeroTangentVector, genericZero) + + let nested = Nested(generic: generic) + let nestedZero = Nested.TangentVector(generic: genericZero) + expectEqual(nested.zeroTangentVector, nestedZero) +} + +// Test differentiation correctness involving projection operations and +// per-instance zeros. +ZeroTangentVectorTests.test("DifferentiationCorrectness") { + struct Struct: Differentiable { + var x, y: [Float] + } + func concatenated(_ lhs: Struct, _ rhs: Struct) -> Struct { + return Struct(x: lhs.x + rhs.x, y: lhs.y + rhs.y) + } + func test(_ s: Struct) -> [Float] { + let result = concatenated(s, s).withDerivative { dresult in + // FIXME(TF-1008): Fix incorrect derivative values for + // "projection operation" operands when differentiation transform uses + // `Differentiable.zeroTangentVectorInitializer`. + // Actual: TangentVector(x: [1.0, 1.0, 1.0], y: []) + // Expected: TangentVector(x: [1.0, 1.0, 1.0], y: [1.0, 1.0, 1.0]) + expectEqual(dresult, Struct.TangentVector(x: [1, 1, 1], y: [1, 1, 1])) + } + return result.x + } + let s = Struct(x: [1, 2, 3], y: [1, 2, 3]) + let pb = pullback(at: s, in: test) + // FIXME(TF-1008): Remove `expectCrash` when differentiation transform uses + // `Differentiable.zeroTangentVectorInitializer`. + expectCrash { + _ = pb([1, 1, 1]) + } +} + +runAllTests() From ff97ae798db922cddaac61ecdd2b72f84a87df6e Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Tue, 26 May 2020 18:50:42 -0700 Subject: [PATCH 23/28] [AutoDiff] Fix SIL locations and debug scopes. Fix SIL locations and debug scopes in `VJPEmitter` for: - Pullback struct `struct` instructions - Predecessor enums `enum` instructions These instructions are not directly cloned from the original function and should have auto-generated locations. Resolves SR-12887: debug scope error for `VJPEmitter`-generated function. --- .../SILOptimizer/Differentiation/VJPEmitter.h | 5 ++-- .../Differentiation/VJPEmitter.cpp | 24 ++++++++++--------- ...vjp-emitter-definite-initialization.swift} | 13 ++++++++-- 3 files changed, 27 insertions(+), 15 deletions(-) rename test/AutoDiff/compiler_crashers_fixed/{sr12886-clone-alloc-stack-dynamic-lifetime.swift => sr12886-sr112887-vjp-emitter-definite-initialization.swift} (76%) diff --git a/include/swift/SILOptimizer/Differentiation/VJPEmitter.h b/include/swift/SILOptimizer/Differentiation/VJPEmitter.h index 475ac7b3f7272..db196e7374b89 100644 --- a/include/swift/SILOptimizer/Differentiation/VJPEmitter.h +++ b/include/swift/SILOptimizer/Differentiation/VJPEmitter.h @@ -130,8 +130,9 @@ class VJPEmitter final StructInst *pbStructVal, SILBasicBlock *succBB); - /// Build a pullback struct value for the given original block. - StructInst *buildPullbackValueStructValue(SILBasicBlock *bb); + /// Build a pullback struct value for the given original terminator + /// instruction. + StructInst *buildPullbackValueStructValue(TermInst *termInst); /// Build a predecessor enum instance using the given builder for the given /// original predecessor/successor blocks and pullback struct value. diff --git a/lib/SILOptimizer/Differentiation/VJPEmitter.cpp b/lib/SILOptimizer/Differentiation/VJPEmitter.cpp index ec6b82cc7fb17..49e85bea7b1d0 100644 --- a/lib/SILOptimizer/Differentiation/VJPEmitter.cpp +++ b/lib/SILOptimizer/Differentiation/VJPEmitter.cpp @@ -315,9 +315,10 @@ SILType VJPEmitter::getNominalDeclLoweredType(NominalTypeDecl *nominal) { return getLoweredType(nominalType); } -StructInst *VJPEmitter::buildPullbackValueStructValue(SILBasicBlock *origBB) { - assert(origBB->getParent() == original); - auto loc = origBB->getParent()->getLocation(); +StructInst *VJPEmitter::buildPullbackValueStructValue(TermInst *termInst) { + assert(termInst->getFunction() == original); + auto loc = RegularLocation::getAutoGeneratedLocation(); + auto origBB = termInst->getParent(); auto *vjpBB = BBMap[origBB]; auto *pbStruct = pullbackInfo.getLinearMapStruct(origBB); auto structLoweredTy = getNominalDeclLoweredType(pbStruct); @@ -326,6 +327,7 @@ StructInst *VJPEmitter::buildPullbackValueStructValue(SILBasicBlock *origBB) { auto *predEnumArg = vjpBB->getArguments().back(); bbPullbackValues.insert(bbPullbackValues.begin(), predEnumArg); } + getBuilder().setCurrentDebugScope(getOpScope(termInst->getDebugScope())); return getBuilder().createStruct(loc, structLoweredTy, bbPullbackValues); } @@ -333,7 +335,7 @@ EnumInst *VJPEmitter::buildPredecessorEnumValue(SILBuilder &builder, SILBasicBlock *predBB, SILBasicBlock *succBB, SILValue pbStructVal) { - auto loc = pbStructVal.getLoc(); + auto loc = RegularLocation::getAutoGeneratedLocation(); auto *succEnum = pullbackInfo.getBranchingTraceDecl(succBB); auto enumLoweredTy = getNominalDeclLoweredType(succEnum); auto *enumEltDecl = @@ -361,7 +363,7 @@ void VJPEmitter::visitReturnInst(ReturnInst *ri) { // Build pullback struct value for original block. auto *origExit = ri->getParent(); - auto *pbStructVal = buildPullbackValueStructValue(origExit); + auto *pbStructVal = buildPullbackValueStructValue(ri); // Get the value in the VJP corresponding to the original result. auto *origRetInst = cast(origExit->getTerminator()); @@ -416,7 +418,7 @@ void VJPEmitter::visitBranchInst(BranchInst *bi) { // Build pullback struct value for original block. // Build predecessor enum value for destination block. auto *origBB = bi->getParent(); - auto *pbStructVal = buildPullbackValueStructValue(origBB); + auto *pbStructVal = buildPullbackValueStructValue(bi); auto *enumVal = buildPredecessorEnumValue(getBuilder(), origBB, bi->getDestBB(), pbStructVal); @@ -433,7 +435,7 @@ void VJPEmitter::visitBranchInst(BranchInst *bi) { void VJPEmitter::visitCondBranchInst(CondBranchInst *cbi) { // Build pullback struct value for original block. - auto *pbStructVal = buildPullbackValueStructValue(cbi->getParent()); + auto *pbStructVal = buildPullbackValueStructValue(cbi); // Create a new `cond_br` instruction. getBuilder().createCondBranch( cbi->getLoc(), getOpValue(cbi->getCondition()), @@ -443,7 +445,7 @@ void VJPEmitter::visitCondBranchInst(CondBranchInst *cbi) { void VJPEmitter::visitSwitchEnumInstBase(SwitchEnumInstBase *sei) { // Build pullback struct value for original block. - auto *pbStructVal = buildPullbackValueStructValue(sei->getParent()); + auto *pbStructVal = buildPullbackValueStructValue(sei); // Create trampoline successor basic blocks. SmallVector, 4> caseBBs; @@ -483,7 +485,7 @@ void VJPEmitter::visitSwitchEnumAddrInst(SwitchEnumAddrInst *seai) { void VJPEmitter::visitCheckedCastBranchInst(CheckedCastBranchInst *ccbi) { // Build pullback struct value for original block. - auto *pbStructVal = buildPullbackValueStructValue(ccbi->getParent()); + auto *pbStructVal = buildPullbackValueStructValue(ccbi); // Create a new `checked_cast_branch` instruction. getBuilder().createCheckedCastBranch( ccbi->getLoc(), ccbi->isExact(), getOpValue(ccbi->getOperand()), @@ -497,7 +499,7 @@ void VJPEmitter::visitCheckedCastBranchInst(CheckedCastBranchInst *ccbi) { void VJPEmitter::visitCheckedCastValueBranchInst( CheckedCastValueBranchInst *ccvbi) { // Build pullback struct value for original block. - auto *pbStructVal = buildPullbackValueStructValue(ccvbi->getParent()); + auto *pbStructVal = buildPullbackValueStructValue(ccvbi); // Create a new `checked_cast_value_branch` instruction. getBuilder().createCheckedCastValueBranch( ccvbi->getLoc(), getOpValue(ccvbi->getOperand()), @@ -511,7 +513,7 @@ void VJPEmitter::visitCheckedCastValueBranchInst( void VJPEmitter::visitCheckedCastAddrBranchInst( CheckedCastAddrBranchInst *ccabi) { // Build pullback struct value for original block. - auto *pbStructVal = buildPullbackValueStructValue(ccabi->getParent()); + auto *pbStructVal = buildPullbackValueStructValue(ccabi); // Create a new `checked_cast_addr_branch` instruction. getBuilder().createCheckedCastAddrBranch( ccabi->getLoc(), ccabi->getConsumptionKind(), getOpValue(ccabi->getSrc()), diff --git a/test/AutoDiff/compiler_crashers_fixed/sr12886-clone-alloc-stack-dynamic-lifetime.swift b/test/AutoDiff/compiler_crashers_fixed/sr12886-sr112887-vjp-emitter-definite-initialization.swift similarity index 76% rename from test/AutoDiff/compiler_crashers_fixed/sr12886-clone-alloc-stack-dynamic-lifetime.swift rename to test/AutoDiff/compiler_crashers_fixed/sr12886-sr112887-vjp-emitter-definite-initialization.swift index 67f441445ae4b..5d233b2b47337 100644 --- a/test/AutoDiff/compiler_crashers_fixed/sr12886-clone-alloc-stack-dynamic-lifetime.swift +++ b/test/AutoDiff/compiler_crashers_fixed/sr12886-sr112887-vjp-emitter-definite-initialization.swift @@ -1,9 +1,14 @@ // RUN: %target-build-swift %s // RUN: %target-swift-frontend -emit-sil %s | %FileCheck %s -// SR-12493: SIL memory lifetime verification error due to +// Test crashes related to differentiation and definite intiialization. + +// SR-12886: SIL memory lifetime verification error due to // `SILCloner::visitAllocStack` not copying the `[dynamic_lifetime]` attribute. +// SR-12887: Debug scope error for pullback struct `struct` instruction +// generated by `VJPEmitter`. + import _Differentiation enum Enum { @@ -36,7 +41,11 @@ struct Tensor: Differentiable { // CHECK-LABEL: sil hidden @AD__$s4main6TensorVyACyxGx_ADtcfC__vjp_src_0_wrt_1_l : $@convention(method) <Ï„_0_0> (@in Ï„_0_0, @in Tensor<Ï„_0_0>, @thin Tensor<Ï„_0_0>.Type) -> (@out Tensor<Ï„_0_0>, @owned @callee_guaranteed @substituted <Ï„_0_0, Ï„_0_1> (@in_guaranteed Ï„_0_0) -> @out Ï„_0_1 for .TangentVector, Tensor<Ï„_0_0>.TangentVector>) { // CHECK: [[SELF_ALLOC:%.*]] = alloc_stack [dynamic_lifetime] $Tensor<Ï„_0_0>, var, name "self" -// Original error: +// SR-12886 original error: // SIL memory lifetime failure in @AD__$s5crash6TensorVyACyxGx_ADtcfC__vjp_src_0_wrt_1_l: memory is not initialized, but should // memory location: %29 = struct_element_addr %5 : $*Tensor<Ï„_0_0>, #Tensor.x // user: %30 // at instruction: destroy_addr %29 : $*Ï„_0_0 // id: %30 + +// SR-12887 original error: +// SIL verification failed: Basic block contains a non-contiguous lexical scope at -Onone: DS == LastSeenScope +// %26 = struct $_AD__$s5crash6TensorVyACyxGx_ADtcfC_bb0__PB__src_0_wrt_1_l<Ï„_0_0> () // users: %34, %28 From 99a0919b345c2deb6177f5895d261e5991df1b0b Mon Sep 17 00:00:00 2001 From: Artem Chikin Date: Fri, 29 May 2020 09:52:22 -0700 Subject: [PATCH 24/28] [Fast Dependency Scanner] Ensure Swift modules don't depend on self. When resolving direct dependencies for a given Swift module, we go over all Clang module dependencies and add, as additional dependencies, their Swift overlays. We find overlays by querying `ASTContext::getModuleDependencies` with the Clang module's name. If the Clang module in question is a dependency of a Swift module with the same name, we will end up adding the Swift module as its own dependence. e.g. - Swift A depends on Clang A - Add Clang A to dependencies of Swift A - We look for overlays of Clang A, by name, and find Swift A - Add Swift A to dependencies of Swift A From what I can tell, the logic upstream is sound, and `getModuleDependencies` is doing the right thing, so this change is simply restricting what gets added when we are looking for overlays. Resolves rdar://problem/63731428 --- lib/FrontendTool/ScanDependencies.cpp | 5 ++++- test/ScanDependencies/module_deps.swift | 5 ++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/lib/FrontendTool/ScanDependencies.cpp b/lib/FrontendTool/ScanDependencies.cpp index 0778dcf3dc757..31ce01280bf81 100644 --- a/lib/FrontendTool/ScanDependencies.cpp +++ b/lib/FrontendTool/ScanDependencies.cpp @@ -130,7 +130,10 @@ static std::vector resolveDirectDependencies( for (const auto &clangDep : allClangModules) { if (auto found = ctx.getModuleDependencies( clangDep, /*onlyClangModule=*/false, cache, ASTDelegate)) { - if (found->getKind() == ModuleDependenciesKind::Swift) + // ASTContext::getModuleDependencies returns dependencies for a module with a given name. + // This Clang module may have the same name as the Swift module we are resolving, so we + // need to make sure we don't add a dependency from a Swift module to itself. + if (found->getKind() == ModuleDependenciesKind::Swift && clangDep != module.first) result.push_back({clangDep, found->getKind()}); } } diff --git a/test/ScanDependencies/module_deps.swift b/test/ScanDependencies/module_deps.swift index 84ff06c2b4c58..72885e5072f9b 100644 --- a/test/ScanDependencies/module_deps.swift +++ b/test/ScanDependencies/module_deps.swift @@ -126,10 +126,9 @@ import G // CHECK: "directDependencies" // CHECK-NEXT: { // CHECK-NEXT: "clang": "G" -// CHECK-NEXT: }, -// CHECK-NEXT: { -// CHECK-NEXT: "swift": "G" // CHECK-NEXT: } +// CHECK-NEXT: ], +// CHECK-NEXT: "details": { // CHECK: "contextHash": "{{.*}}", // CHECK: "commandLine": [ From bdef1cef23d683e463cf5a0cf218628efd4b7a00 Mon Sep 17 00:00:00 2001 From: Nate Chandler Date: Fri, 20 Dec 2019 15:37:39 -0800 Subject: [PATCH 25/28] [metadata prespecialization] Support for classes. When generic metadata for a class is requested in the same module where the class is defined, rather than a call to the generic metadata accessor or to a variant of typeForMangledNode, a call to a new accessor--a canonical specialized generic metadata accessor--is emitted. The new function is defined schematically as follows: MetadataResponse `canonical specialized metadata accessor for C`(MetadataRequest request) { (void)`canonical specialized metadata accessor for superclass(C)`(::Complete) (void)`canonical specialized metadata accessor for generic_argument_class(C, 1)`(::Complete) ... (void)`canonical specialized metadata accessor for generic_argument_class(C, count)`(::Complete) auto *metadata = objc_opt_self(`canonical specialized metadata for C`); return {metadata, MetadataState::Complete}; } where generic_argument_class(C, N) denotes the Nth generic argument which is both (1) itself a specialized generic type and is also (2) a class. These calls to the specialized metadata accessors for these related types ensure that all generic class types are registered with the Objective-C runtime. To enable these new canonical specialized generic metadata accessors, metadata for generic classes is prespecialized as needed. So are the metaclasses and the corresponding rodata. Previously, the lazy objc naming hook was registered during process execution when the first generic class metadata was instantiated. Since that instantiation may occur "before process launch" (i.e. if the generic metadata is prespecialized), the lazy naming hook is now installed at process launch. --- include/swift/ABI/Metadata.h | 2 +- include/swift/AST/Types.h | 12 +- include/swift/Demangling/DemangleNodes.def | 4 + include/swift/IRGen/Linking.h | 31 +- lib/AST/Type.cpp | 11 + lib/Demangling/Demangler.cpp | 6 + lib/Demangling/NodePrinter.cpp | 10 + lib/Demangling/OldRemangler.cpp | 11 + lib/Demangling/Remangler.cpp | 11 + lib/IRGen/ClassTypeInfo.h | 66 +++ lib/IRGen/GenClass.cpp | 247 ++++++---- lib/IRGen/GenClass.h | 8 + lib/IRGen/GenDecl.cpp | 115 ++++- lib/IRGen/GenMeta.cpp | 446 +++++++++++++----- lib/IRGen/GenMeta.h | 11 + lib/IRGen/IRGenMangler.h | 9 + lib/IRGen/IRGenModule.h | 19 +- lib/IRGen/Linking.cpp | 32 +- lib/IRGen/MetadataRequest.cpp | 280 +++++++++-- lib/IRGen/MetadataRequest.h | 20 +- stdlib/public/runtime/Metadata.cpp | 8 +- .../Inputs/isPrespecialized.cpp | 2 +- .../Inputs/isPrespecialized.h | 7 +- .../class-class-flags-run.swift | 47 ++ ...-2nd_argument_distinct_generic_class.swift | 394 ++++++++++++++++ ...t_same_generic_class_different_value.swift | 385 +++++++++++++++ ...gument_same_generic_class_same_value.swift | 383 +++++++++++++++ ...s_arg-2nd_anc_gen-1st-arg_subcls_arg.swift | 327 +++++++++++++ ...-1argument-1st_argument_constant_int.swift | 267 +++++++++++ ...ument-1st_argument_subclass_argument.swift | 315 +++++++++++++ ...ic-1argument-1st_argument_superclass.swift | 322 +++++++++++++ ...1st_argument_generic_class-1argument.swift | 297 ++++++++++++ ...ate-inmodule-1argument-1distinct_use.swift | 141 ++++++ ...module-1argument-1distinct_use_class.swift | 50 ++ ...ule-1argument-1distinct_use_function.swift | 39 ++ ...argument-1distinct_use_generic_class.swift | 306 ++++++++++++ ...c_class_specialized_at_generic_class.swift | 314 ++++++++++++ ...1argument-1distinct_use_generic_enum.swift | 272 +++++++++++ ...rgument-1distinct_use_generic_struct.swift | 277 +++++++++++ ...module-1argument-1distinct_use_tuple.swift | 39 ++ ...generic-1st-argument_constant_double.swift | 351 ++++++++++++++ ...neric-1st-argument_subclass_argument.swift | 335 +++++++++++++ ...or_generic-1st-argument_constant_int.swift | 338 +++++++++++++ ...lass-inmodule-1argument-metatype-run.swift | 28 ++ .../class-inmodule-1argument-run.swift | 74 +++ ...odule-2argument-1super-2argument-run.swift | 85 ++++ .../enum-trailing-flags-run.swift | 2 +- .../struct-trailing-flags-run.swift | 2 +- ...ditional_conformance_subclass_future.swift | 3 +- test/lit.cfg | 1 + 50 files changed, 6459 insertions(+), 303 deletions(-) create mode 100644 lib/IRGen/ClassTypeInfo.h create mode 100644 test/IRGen/prespecialized-metadata/class-class-flags-run.swift create mode 100644 test/IRGen/prespecialized-metadata/class-fileprivate-2argument-1_distinct_use-1st_argument_generic_class-2nd_argument_distinct_generic_class.swift create mode 100644 test/IRGen/prespecialized-metadata/class-fileprivate-2argument-1_distinct_use-1st_argument_generic_class-2nd_argument_same_generic_class_different_value.swift create mode 100644 test/IRGen/prespecialized-metadata/class-fileprivate-2argument-1_distinct_use-1st_argument_generic_class-2nd_argument_same_generic_class_same_value.swift create mode 100644 test/IRGen/prespecialized-metadata/class-fileprivate-inmodule-1arg-2ancs-1distinct_use-1st_anc_gen-1arg-1st_arg_subcls_arg-2nd_anc_gen-1st-arg_subcls_arg.swift create mode 100644 test/IRGen/prespecialized-metadata/class-fileprivate-inmodule-1argument-1ancestor-1distinct_use-1st_ancestor_generic-1argument-1st_argument_constant_int.swift create mode 100644 test/IRGen/prespecialized-metadata/class-fileprivate-inmodule-1argument-1ancestor-1distinct_use-1st_ancestor_generic-1argument-1st_argument_subclass_argument.swift create mode 100644 test/IRGen/prespecialized-metadata/class-fileprivate-inmodule-1argument-1ancestor-1distinct_use-1st_ancestor_generic-1argument-1st_argument_superclass.swift create mode 100644 test/IRGen/prespecialized-metadata/class-fileprivate-inmodule-1argument-1distinct_use-1st_argument_generic_class-1argument.swift create mode 100644 test/IRGen/prespecialized-metadata/class-fileprivate-inmodule-1argument-1distinct_use.swift create mode 100644 test/IRGen/prespecialized-metadata/class-fileprivate-inmodule-1argument-1distinct_use_class.swift create mode 100644 test/IRGen/prespecialized-metadata/class-fileprivate-inmodule-1argument-1distinct_use_function.swift create mode 100644 test/IRGen/prespecialized-metadata/class-fileprivate-inmodule-1argument-1distinct_use_generic_class.swift create mode 100644 test/IRGen/prespecialized-metadata/class-fileprivate-inmodule-1argument-1distinct_use_generic_class_specialized_at_generic_class.swift create mode 100644 test/IRGen/prespecialized-metadata/class-fileprivate-inmodule-1argument-1distinct_use_generic_enum.swift create mode 100644 test/IRGen/prespecialized-metadata/class-fileprivate-inmodule-1argument-1distinct_use_generic_struct.swift create mode 100644 test/IRGen/prespecialized-metadata/class-fileprivate-inmodule-1argument-1distinct_use_tuple.swift create mode 100644 test/IRGen/prespecialized-metadata/class-fileprivate-inmodule-1argument-2ancestors-1distinct_use-1st_ancestor_generic-1argument-1st_argument_constant_int-2nd_ancestor_generic-1st-argument_constant_double.swift create mode 100644 test/IRGen/prespecialized-metadata/class-fileprivate-inmodule-1argument-2ancestors-1distinct_use-1st_ancestor_generic-1argument-1st_argument_constant_int-2nd_ancestor_generic-1st-argument_subclass_argument.swift create mode 100644 test/IRGen/prespecialized-metadata/class-fileprivate-inmodule-1argument-2ancestors-1distinct_use-1st_ancestor_generic-1argument-1st_argument_subclass_argument-2nd_ancestor_generic-1st-argument_constant_int.swift create mode 100644 test/IRGen/prespecialized-metadata/class-inmodule-1argument-metatype-run.swift create mode 100644 test/IRGen/prespecialized-metadata/class-inmodule-1argument-run.swift create mode 100644 test/IRGen/prespecialized-metadata/class-inmodule-2argument-1super-2argument-run.swift diff --git a/include/swift/ABI/Metadata.h b/include/swift/ABI/Metadata.h index 8409c8dcd724e..8049387c4e191 100644 --- a/include/swift/ABI/Metadata.h +++ b/include/swift/ABI/Metadata.h @@ -1025,7 +1025,7 @@ struct TargetAnyClassMetadata : public TargetHeapMetadata { using TargetMetadata::setClassISA; #endif - // Note that ObjC classes does not have a metadata header. + // Note that ObjC classes do not have a metadata header. /// The metadata for the superclass. This is null for the root class. ConstTargetMetadataPointer Superclass; diff --git a/include/swift/AST/Types.h b/include/swift/AST/Types.h index 0bf41995878ab..6e8c2cd1b0330 100644 --- a/include/swift/AST/Types.h +++ b/include/swift/AST/Types.h @@ -852,7 +852,17 @@ class alignas(1 << TypeAlignInBits) TypeBase { /// \returns The superclass of this type, or a null type if it has no /// superclass. Type getSuperclass(bool useArchetypes = true); - + + /// Retrieve the root class of this type by repeatedly retrieving the + /// superclass. + /// + /// \param useArchetypes Whether to use context archetypes for outer generic + /// parameters if the class is nested inside a generic function. + /// + /// \returns The base class of this type, or this type itself if it has no + /// superclasses. + Type getRootClass(bool useArchetypes = true); + /// True if this type is the exact superclass of another type. /// /// \param ty The potential subclass. diff --git a/include/swift/Demangling/DemangleNodes.def b/include/swift/Demangling/DemangleNodes.def index 8321255140d40..5e534c3e901fe 100644 --- a/include/swift/Demangling/DemangleNodes.def +++ b/include/swift/Demangling/DemangleNodes.def @@ -287,5 +287,9 @@ NODE(OpaqueTypeDescriptorAccessorVar) NODE(OpaqueReturnType) CONTEXT_NODE(OpaqueReturnTypeOf) +// Added in Swift 5.3 +NODE(CanonicalSpecializedGenericMetaclass) +NODE(CanonicalSpecializedGenericTypeMetadataAccessFunction) + #undef CONTEXT_NODE #undef NODE diff --git a/include/swift/IRGen/Linking.h b/include/swift/IRGen/Linking.h index acfa09b4b9cf7..fac7f4510c30a 100644 --- a/include/swift/IRGen/Linking.h +++ b/include/swift/IRGen/Linking.h @@ -193,7 +193,7 @@ class LinkEntity { /// The nominal type descriptor for a nominal type. /// The pointer is a NominalTypeDecl*. NominalTypeDescriptor, - + /// The descriptor for an opaque type. /// The pointer is an OpaqueTypeDecl*. OpaqueTypeDescriptor, @@ -295,12 +295,12 @@ class LinkEntity { /// The descriptor for an extension. /// The pointer is an ExtensionDecl*. ExtensionDescriptor, - + /// The descriptor for a runtime-anonymous context. /// The pointer is the DeclContext* of a child of the context that should /// be considered private. AnonymousDescriptor, - + /// A SIL global variable. The pointer is a SILGlobalVariable*. SILGlobalVariable, @@ -384,6 +384,15 @@ class LinkEntity { /// A global function pointer for dynamically replaceable functions. DynamicallyReplaceableFunctionVariable, + + /// A reference to a metaclass-stub for a statically specialized generic + /// class. + /// The pointer is a canonical TypeBase*. + CanonicalSpecializedGenericSwiftMetaclassStub, + + /// An access function for prespecialized type metadata. + /// The pointer is a canonical TypeBase*. + CanonicalSpecializedGenericTypeMetadataAccessFunction, }; friend struct llvm::DenseMapInfo; @@ -1020,6 +1029,22 @@ class LinkEntity { return entity; } + static LinkEntity + forSpecializedGenericSwiftMetaclassStub(CanType concreteType) { + LinkEntity entity; + entity.setForType(Kind::CanonicalSpecializedGenericSwiftMetaclassStub, + concreteType); + return entity; + } + + static LinkEntity + forPrespecializedTypeMetadataAccessFunction(CanType theType) { + LinkEntity entity; + entity.setForType( + Kind::CanonicalSpecializedGenericTypeMetadataAccessFunction, theType); + return entity; + } + void mangle(llvm::raw_ostream &out) const; void mangle(SmallVectorImpl &buffer) const; std::string mangleAsString() const; diff --git a/lib/AST/Type.cpp b/lib/AST/Type.cpp index 7532ea5bbefcc..56f7302be72d1 100644 --- a/lib/AST/Type.cpp +++ b/lib/AST/Type.cpp @@ -1546,6 +1546,17 @@ Type TypeBase::getSuperclass(bool useArchetypes) { return superclassTy.subst(subMap); } +Type TypeBase::getRootClass(bool useArchetypes) { + Type iterator = this; + assert(iterator); + + while (auto superclass = iterator->getSuperclass(useArchetypes)) { + iterator = superclass; + } + + return iterator; +} + bool TypeBase::isExactSuperclassOf(Type ty) { // For there to be a superclass relationship, we must be a class, and // the potential subtype must be a class, superclass-bounded archetype, diff --git a/lib/Demangling/Demangler.cpp b/lib/Demangling/Demangler.cpp index 2ea37950dd96e..f447c2cef583b 100644 --- a/lib/Demangling/Demangler.cpp +++ b/lib/Demangling/Demangler.cpp @@ -1874,6 +1874,9 @@ NodePointer Demangler::demangleMetatype() { case 'A': return createWithChild(Node::Kind::ReflectionMetadataAssocTypeDescriptor, popProtocolConformance()); + case 'b': + return createWithPoppedType( + Node::Kind::CanonicalSpecializedGenericTypeMetadataAccessFunction); case 'B': return createWithChild(Node::Kind::ReflectionMetadataBuiltinDescriptor, popNode(Node::Kind::Type)); @@ -1917,6 +1920,9 @@ NodePointer Demangler::demangleMetatype() { return createWithPoppedType(Node::Kind::TypeMetadataLazyCache); case 'm': return createWithPoppedType(Node::Kind::Metaclass); + case 'M': + return createWithPoppedType( + Node::Kind::CanonicalSpecializedGenericMetaclass); case 'n': return createWithPoppedType(Node::Kind::NominalTypeDescriptor); case 'o': diff --git a/lib/Demangling/NodePrinter.cpp b/lib/Demangling/NodePrinter.cpp index 655cbac01ea24..22d52337aae39 100644 --- a/lib/Demangling/NodePrinter.cpp +++ b/lib/Demangling/NodePrinter.cpp @@ -540,6 +540,8 @@ class NodePrinter { case Node::Kind::OpaqueTypeDescriptorSymbolicReference: case Node::Kind::OpaqueReturnType: case Node::Kind::OpaqueReturnTypeOf: + case Node::Kind::CanonicalSpecializedGenericMetaclass: + case Node::Kind::CanonicalSpecializedGenericTypeMetadataAccessFunction: return false; } printer_unreachable("bad node kind"); @@ -2418,6 +2420,14 @@ NodePointer NodePrinter::print(NodePointer Node, bool asPrefixContext) { case Node::Kind::AccessorFunctionReference: Printer << "accessor function at " << Node->getIndex(); return nullptr; + case Node::Kind::CanonicalSpecializedGenericMetaclass: + Printer << "specialized generic metaclass for "; + print(Node->getFirstChild()); + return nullptr; + case Node::Kind::CanonicalSpecializedGenericTypeMetadataAccessFunction: + Printer << "canonical specialized generic type metadata accessor for "; + print(Node->getChild(0)); + return nullptr; } printer_unreachable("bad node kind!"); } diff --git a/lib/Demangling/OldRemangler.cpp b/lib/Demangling/OldRemangler.cpp index 09b1896d9f9e1..02e974a926ee0 100644 --- a/lib/Demangling/OldRemangler.cpp +++ b/lib/Demangling/OldRemangler.cpp @@ -2133,6 +2133,17 @@ void Remangler::mangleAccessorFunctionReference(Node *node) { unreachable("can't remangle"); } +void Remangler::mangleCanonicalSpecializedGenericMetaclass(Node *node) { + Buffer << "MM"; + mangleSingleChildNode(node); // type +} + +void Remangler::mangleCanonicalSpecializedGenericTypeMetadataAccessFunction( + Node *node) { + mangleSingleChildNode(node); + Buffer << "Mb"; +} + /// The top-level interface to the remangler. std::string Demangle::mangleNodeOld(NodePointer node) { if (!node) return ""; diff --git a/lib/Demangling/Remangler.cpp b/lib/Demangling/Remangler.cpp index 4e9907d7d9699..3c8832acac374 100644 --- a/lib/Demangling/Remangler.cpp +++ b/lib/Demangling/Remangler.cpp @@ -2533,6 +2533,17 @@ void Remangler::mangleAccessorFunctionReference(Node *node) { unreachable("can't remangle"); } +void Remangler::mangleCanonicalSpecializedGenericMetaclass(Node *node) { + mangleChildNodes(node); + Buffer << "MM"; +} + +void Remangler::mangleCanonicalSpecializedGenericTypeMetadataAccessFunction( + Node *node) { + mangleSingleChildNode(node); + Buffer << "Mb"; +} + } // anonymous namespace /// The top-level interface to the remangler. diff --git a/lib/IRGen/ClassTypeInfo.h b/lib/IRGen/ClassTypeInfo.h new file mode 100644 index 0000000000000..0e47a01df8a43 --- /dev/null +++ b/lib/IRGen/ClassTypeInfo.h @@ -0,0 +1,66 @@ +//===--- ClassTypeInfo.h - The layout info for class types. -----*- C++ -*-===// +// +// This source file is part of the Swift.org open source project +// +// Copyright (c) 2014 - 2020 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors +// +//===----------------------------------------------------------------------===// +// +// This file contains layout information for class types. +// +//===----------------------------------------------------------------------===// + +#ifndef SWIFT_IRGEN_CLASSTYPEINFO_H +#define SWIFT_IRGEN_CLASSTYPEINFO_H + +#include "ClassLayout.h" +#include "HeapTypeInfo.h" + +namespace swift { +namespace irgen { + +/// Layout information for class types. +class ClassTypeInfo : public HeapTypeInfo { + ClassDecl *TheClass; + + // The resilient layout of the class, without making any assumptions + // that violate resilience boundaries. This is used to allocate + // and deallocate instances of the class, and to access fields. + mutable Optional ResilientLayout; + + // A completely fragile layout, used for metadata emission. + mutable Optional FragileLayout; + + /// Can we use swift reference-counting, or do we have to use + /// objc_retain/release? + const ReferenceCounting Refcount; + + ClassLayout generateLayout(IRGenModule &IGM, SILType classType, + bool forBackwardDeployment) const; + +public: + ClassTypeInfo(llvm::PointerType *irType, Size size, SpareBitVector spareBits, + Alignment align, ClassDecl *theClass, + ReferenceCounting refcount) + : HeapTypeInfo(irType, size, std::move(spareBits), align), + TheClass(theClass), Refcount(refcount) {} + + ReferenceCounting getReferenceCounting() const { return Refcount; } + + ClassDecl *getClass() const { return TheClass; } + + const ClassLayout &getClassLayout(IRGenModule &IGM, SILType type, + bool forBackwardDeployment) const; + + StructLayout *createLayoutWithTailElems(IRGenModule &IGM, SILType classType, + ArrayRef tailTypes) const; +}; + +} // namespace irgen +} // namespace swift + +#endif diff --git a/lib/IRGen/GenClass.cpp b/lib/IRGen/GenClass.cpp index 5c1d736fb99e4..db231ee086ced 100644 --- a/lib/IRGen/GenClass.cpp +++ b/lib/IRGen/GenClass.cpp @@ -32,6 +32,7 @@ #include "swift/SIL/SILModule.h" #include "swift/SIL/SILType.h" #include "swift/SIL/SILVTableVisitor.h" +#include "llvm/ADT/None.h" #include "llvm/ADT/SmallString.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Function.h" @@ -40,72 +41,29 @@ #include "Callee.h" #include "ClassLayout.h" +#include "ClassTypeInfo.h" #include "ConstantBuilder.h" #include "Explosion.h" #include "GenFunc.h" +#include "GenHeap.h" #include "GenMeta.h" #include "GenObjC.h" #include "GenPointerAuth.h" #include "GenProto.h" #include "GenType.h" +#include "HeapTypeInfo.h" #include "IRGenDebugInfo.h" #include "IRGenFunction.h" #include "IRGenModule.h" -#include "GenHeap.h" -#include "HeapTypeInfo.h" #include "MemberAccessStrategy.h" #include "MetadataLayout.h" #include "MetadataRequest.h" - using namespace swift; using namespace irgen; -namespace { - /// Layout information for class types. - class ClassTypeInfo : public HeapTypeInfo { - ClassDecl *TheClass; - - // The resilient layout of the class, without making any assumptions - // that violate resilience boundaries. This is used to allocate - // and deallocate instances of the class, and to access fields. - mutable Optional ResilientLayout; - - // A completely fragile layout, used for metadata emission. - mutable Optional FragileLayout; - - /// Can we use swift reference-counting, or do we have to use - /// objc_retain/release? - const ReferenceCounting Refcount; - - ClassLayout generateLayout(IRGenModule &IGM, SILType classType, - bool forBackwardDeployment) const; - - public: - ClassTypeInfo(llvm::PointerType *irType, Size size, - SpareBitVector spareBits, Alignment align, - ClassDecl *theClass, ReferenceCounting refcount) - : HeapTypeInfo(irType, size, std::move(spareBits), align), - TheClass(theClass), Refcount(refcount) {} - - ReferenceCounting getReferenceCounting() const { - return Refcount; - } - - ClassDecl *getClass() const { return TheClass; } - - - const ClassLayout &getClassLayout(IRGenModule &IGM, SILType type, - bool forBackwardDeployment) const; - - StructLayout *createLayoutWithTailElems(IRGenModule &IGM, - SILType classType, - ArrayRef tailTypes) const; - }; -} // end anonymous namespace - /// Return the lowered type for the class's 'self' type within its context. -static SILType getSelfType(const ClassDecl *base) { +SILType irgen::getSelfType(const ClassDecl *base) { auto loweredTy = base->getDeclaredTypeInContext()->getCanonicalType(); return SILType::getPrimitiveObjectType(loweredTy); } @@ -328,12 +286,17 @@ namespace { auto element = ElementLayout::getIncomplete(*eltType); bool isKnownEmpty = !addField(element, LayoutStrategy::Universal); + bool isSpecializedGeneric = + (theClass->isGenericContext() && !classType.getASTType() + ->getRecursiveProperties() + .hasUnboundGeneric()); + // The 'Elements' list only contains superclass fields when we're // building a layout for tail allocation. - if (!superclass || TailTypes) + if (!superclass || TailTypes || isSpecializedGeneric) Elements.push_back(element); - if (!superclass) { + if (!superclass || isSpecializedGeneric) { AllStoredProperties.push_back(var); AllFieldAccesses.push_back(getFieldAccess(isKnownEmpty)); } @@ -964,26 +927,49 @@ namespace { /// category data (category_t), or protocol data (protocol_t). class ClassDataBuilder : public ClassMemberVisitor { IRGenModule &IGM; - PointerUnion TheEntity; + using ClassPair = std::pair; + using ClassUnion = TaggedUnion; + TaggedUnion TheEntity; ExtensionDecl *TheExtension; const ClassLayout *FieldLayout; ClassDecl *getClass() const { - return TheEntity.get(); + const ClassUnion *classUnion; + if (!(classUnion = TheEntity.dyn_cast())) { + return nullptr; + } + if (auto *const *theClass = classUnion->dyn_cast()) { + return *theClass; + } + auto pair = classUnion->get(); + return pair.first; } ProtocolDecl *getProtocol() const { - return TheEntity.get(); + if (auto *const *theProtocol = TheEntity.dyn_cast()) { + return *theProtocol; + } + return nullptr; } - + Optional getSpecializedGenericType() const { + const ClassUnion *classUnion; + if (!(classUnion = TheEntity.dyn_cast())) { + return llvm::None; + } + const ClassPair *classPair; + if (!(classPair = classUnion->dyn_cast())) { + return llvm::None; + } + auto &pair = *classPair; + return pair.second; + } + bool isBuildingClass() const { - return TheEntity.is() && !TheExtension; + return TheEntity.isa() && !TheExtension; } bool isBuildingCategory() const { - return TheEntity.is() && TheExtension; - } - bool isBuildingProtocol() const { - return TheEntity.is(); + return TheEntity.isa() && TheExtension; } + bool isBuildingProtocol() const { return TheEntity.isa(); } bool HasNonTrivialDestructor = false; bool HasNonTrivialConstructor = false; @@ -1063,23 +1049,27 @@ namespace { public: ClassDataBuilder(IRGenModule &IGM, ClassDecl *theClass, const ClassLayout &fieldLayout) - : IGM(IGM), TheEntity(theClass), TheExtension(nullptr), - FieldLayout(&fieldLayout) - { - visitConformances(theClass); - visitMembers(theClass); - - if (Lowering::usesObjCAllocator(theClass)) { - addIVarInitializer(); - addIVarDestroyer(); + : ClassDataBuilder(IGM, ClassUnion(theClass), fieldLayout) {} + + ClassDataBuilder( + IRGenModule &IGM, + TaggedUnion> theUnion, + const ClassLayout &fieldLayout) + : IGM(IGM), TheEntity(theUnion), TheExtension(nullptr), + FieldLayout(&fieldLayout) { + visitConformances(getClass()); + visitMembers(getClass()); + + if (Lowering::usesObjCAllocator(getClass())) { + addIVarInitializer(); + addIVarDestroyer(); } } - + ClassDataBuilder(IRGenModule &IGM, ClassDecl *theClass, ExtensionDecl *theExtension) - : IGM(IGM), TheEntity(theClass), TheExtension(theExtension), - FieldLayout(nullptr) - { + : IGM(IGM), TheEntity(ClassUnion(theClass)), TheExtension(theExtension), + FieldLayout(nullptr) { buildCategoryName(CategoryName); visitConformances(theExtension); @@ -1087,7 +1077,7 @@ namespace { for (Decl *member : TheExtension->getMembers()) visit(member); } - + ClassDataBuilder(IRGenModule &IGM, ProtocolDecl *theProtocol) : IGM(IGM), TheEntity(theProtocol), TheExtension(nullptr) { @@ -1143,7 +1133,12 @@ namespace { } } - llvm::Constant *getMetaclassRefOrNull(ClassDecl *theClass) { + llvm::Constant *getMetaclassRefOrNull(Type specializedGenericType, + ClassDecl *theClass) { + if (specializedGenericType) { + return IGM.getAddrOfCanonicalSpecializedGenericMetaclassObject( + specializedGenericType->getCanonicalType(), NotForDefinition); + } if (theClass->isGenericContext() && !theClass->hasClangNode()) { return llvm::ConstantPointerNull::get(IGM.ObjCClassPtrTy); } else { @@ -1153,9 +1148,20 @@ namespace { void buildMetaclassStub() { assert(FieldLayout && "can't build a metaclass from a category"); + + auto specializedGenericType = getSpecializedGenericType().map( + [](auto canType) { return (Type)canType; }); + // The isa is the metaclass pointer for the root class. - auto rootClass = getRootClassForMetaclass(IGM, TheEntity.get()); - auto rootPtr = getMetaclassRefOrNull(rootClass); + auto rootClass = getRootClassForMetaclass(IGM, getClass()); + Type rootType; + if (specializedGenericType && rootClass->isGenericContext()) { + rootType = + (*specializedGenericType)->getRootClass(/*useArchetypes=*/false); + } else { + rootType = Type(); + } + auto rootPtr = getMetaclassRefOrNull(rootType, rootClass); // The superclass of the metaclass is the metaclass of the // superclass. Note that for metaclass stubs, we can always @@ -1166,10 +1172,16 @@ namespace { llvm::Constant *superPtr; if (getClass()->hasSuperclass()) { auto base = getClass()->getSuperclassDecl(); - superPtr = getMetaclassRefOrNull(base); + if (specializedGenericType && base->isGenericContext()) { + superPtr = getMetaclassRefOrNull( + (*specializedGenericType)->getSuperclass(/*useArchetypes=*/false), + base); + } else { + superPtr = getMetaclassRefOrNull(Type(), base); + } } else { superPtr = getMetaclassRefOrNull( - IGM.getObjCRuntimeBaseForSwiftRootClass(getClass())); + Type(), IGM.getObjCRuntimeBaseForSwiftRootClass(getClass())); } auto dataPtr = emitROData(ForMetaClass, DoesNotHaveUpdateCallback); @@ -1184,9 +1196,16 @@ namespace { }; auto init = llvm::ConstantStruct::get(IGM.ObjCClassStructTy, makeArrayRef(fields)); - auto metaclass = - cast( - IGM.getAddrOfMetaclassObject(getClass(), ForDefinition)); + llvm::Constant *uncastMetaclass; + if (auto theType = getSpecializedGenericType()) { + uncastMetaclass = + IGM.getAddrOfCanonicalSpecializedGenericMetaclassObject( + *theType, ForDefinition); + } else { + uncastMetaclass = + IGM.getAddrOfMetaclassObject(getClass(), ForDefinition); + } + auto metaclass = cast(uncastMetaclass); metaclass->setInitializer(init); } @@ -1349,9 +1368,20 @@ namespace { b.addInt32(0); } - // const uint8_t *ivarLayout; - // GC/ARC layout. TODO. - b.addNullPointer(IGM.Int8PtrTy); + // union { + // const uint8_t *IvarLayout; + // ClassMetadata *NonMetaClass; + // }; + Optional specializedGenericType; + if ((specializedGenericType = getSpecializedGenericType()) && forMeta) { + // ClassMetadata *NonMetaClass; + b.addBitCast(IGM.getAddrOfTypeMetadata(*specializedGenericType), + IGM.Int8PtrTy); + } else { + // const uint8_t *IvarLayout; + // GC/ARC layout. TODO. + b.addNullPointer(IGM.Int8PtrTy); + } // const char *name; // It is correct to use the same name for both class and metaclass. @@ -1383,9 +1413,8 @@ namespace { if (hasUpdater) { // Class _Nullable (*metadataUpdateCallback)(Class _Nonnull cls, // void * _Nullable arg); - auto *impl = IGM.getAddrOfObjCMetadataUpdateFunction( - TheEntity.get(), - NotForDefinition); + auto *impl = IGM.getAddrOfObjCMetadataUpdateFunction(getClass(), + NotForDefinition); const auto &schema = IGM.getOptions().PointerAuth.ObjCMethodListFunctionPointers; b.addSignedPointer(impl, schema, PointerAuthEntity()); @@ -1995,14 +2024,14 @@ namespace { /// Get the name of the class or protocol to mangle into the ObjC symbol /// name. StringRef getEntityName(llvm::SmallVectorImpl &buffer) const { - if (auto theClass = TheEntity.dyn_cast()) { + if (auto theClass = getClass()) { return theClass->getObjCRuntimeName(buffer); } - - if (auto theProtocol = TheEntity.dyn_cast()) { + + if (auto theProtocol = getProtocol()) { return theProtocol->getObjCRuntimeName(buffer); } - + llvm_unreachable("not a class or protocol?!"); } @@ -2137,18 +2166,27 @@ void IRGenModule::emitObjCResilientClassStub(ClassDecl *D) { defineAlias(entity, objcStub); } -/// Emit the private data (RO-data) associated with a class. -llvm::Constant *irgen::emitClassPrivateData(IRGenModule &IGM, - ClassDecl *cls) { +static llvm::Constant *doEmitClassPrivateData( + IRGenModule &IGM, + TaggedUnion> classUnion) { assert(IGM.ObjCInterop && "emitting RO-data outside of interop mode"); - PrettyStackTraceDecl stackTraceRAII("emitting ObjC metadata for", cls); + + ClassDecl *cls; + + if (auto *theClass = classUnion.dyn_cast()) { + cls = *theClass; + } else { + auto pair = classUnion.get>(); + cls = pair.first; + } + SILType selfType = getSelfType(cls); auto &classTI = IGM.getTypeInfo(selfType).as(); // FIXME: For now, always use the fragile layout when emitting metadata. auto &fieldLayout = classTI.getClassLayout(IGM, selfType, /*forBackwardDeployment=*/true); - ClassDataBuilder builder(IGM, cls, fieldLayout); + ClassDataBuilder builder(IGM, classUnion, fieldLayout); // First, build the metaclass object. builder.buildMetaclassStub(); @@ -2171,6 +2209,25 @@ llvm::Constant *irgen::emitClassPrivateData(IRGenModule &IGM, return builder.emitROData(ForClass, hasUpdater); } +llvm::Constant *irgen::emitSpecializedGenericClassPrivateData( + IRGenModule &IGM, ClassDecl *theClass, CanType theType) { + assert(theType->getClassOrBoundGenericClass() == theClass); + assert(theClass->getGenericEnvironment()); + Type ty = theType; + PrettyStackTraceType stackTraceRAII(theClass->getASTContext(), + "emitting ObjC metadata for", ty); + return doEmitClassPrivateData( + IGM, TaggedUnion>( + std::make_pair(theClass, theType))); +} + +/// Emit the private data (RO-data) associated with a class. +llvm::Constant *irgen::emitClassPrivateData(IRGenModule &IGM, ClassDecl *cls) { + PrettyStackTraceDecl stackTraceRAII("emitting ObjC metadata for", cls); + return doEmitClassPrivateData( + IGM, TaggedUnion>(cls)); +} + std::pair irgen::emitClassPrivateDataFields(IRGenModule &IGM, ConstantStructBuilder &init, diff --git a/lib/IRGen/GenClass.h b/lib/IRGen/GenClass.h index 834b10d32ba72..549a4a502ec3c 100644 --- a/lib/IRGen/GenClass.h +++ b/lib/IRGen/GenClass.h @@ -51,6 +51,9 @@ namespace irgen { enum class ClassDeallocationKind : unsigned char; enum class FieldAccess : uint8_t; + /// Return the lowered type for the class's 'self' type within its context. + SILType getSelfType(const ClassDecl *base); + OwnedAddress projectPhysicalClassMemberAddress( IRGenFunction &IGF, llvm::Value *base, SILType baseType, SILType fieldType, VarDecl *field); @@ -118,6 +121,11 @@ namespace irgen { ClassDecl *cls); llvm::Constant *emitClassPrivateData(IRGenModule &IGM, ClassDecl *theClass); + + llvm::Constant *emitSpecializedGenericClassPrivateData(IRGenModule &IGM, + ClassDecl *theClass, + CanType theType); + void emitGenericClassPrivateDataTemplate(IRGenModule &IGM, ClassDecl *theClass, llvm::SmallVectorImpl &fields, diff --git a/lib/IRGen/GenDecl.cpp b/lib/IRGen/GenDecl.cpp index 97ad57e20d1a6..290f1d24a3138 100644 --- a/lib/IRGen/GenDecl.cpp +++ b/lib/IRGen/GenDecl.cpp @@ -38,6 +38,7 @@ #include "clang/AST/DeclCXX.h" #include "clang/AST/GlobalDecl.h" #include "llvm/ADT/SmallString.h" +#include "llvm/IR/DerivedTypes.h" #include "llvm/IR/GlobalAlias.h" #include "llvm/IR/InlineAsm.h" #include "llvm/IR/Module.h" @@ -1164,7 +1165,8 @@ void IRGenerator::emitLazyDefinitions() { !LazyOpaqueTypeDescriptors.empty() || !LazyFieldDescriptors.empty() || !LazyFunctionDefinitions.empty() || - !LazyWitnessTables.empty()) { + !LazyWitnessTables.empty() || + !LazyCanonicalSpecializedMetadataAccessors.empty()) { // Emit any lazy type metadata we require. while (!LazyTypeMetadata.empty()) { @@ -1220,6 +1222,15 @@ void IRGenerator::emitLazyDefinitions() { && "function with externally-visible linkage emitted lazily?"); IGM->emitSILFunction(f); } + + while (!LazyCanonicalSpecializedMetadataAccessors.empty()) { + CanType theType = + LazyCanonicalSpecializedMetadataAccessors.pop_back_val(); + auto *nominal = theType->getAnyNominal(); + assert(nominal); + CurrentIGMPtr IGM = getGenModule(nominal->getDeclContext()); + emitLazyCanonicalSpecializedMetadataAccessor(*IGM.get(), theType); + } } while (!LazyMetadataAccessors.empty()) { @@ -1394,14 +1405,30 @@ void IRGenerator::noteUseOfFieldDescriptor(NominalTypeDecl *type) { LazyFieldDescriptors.push_back(type); } +void IRGenerator::noteUseOfCanonicalSpecializedMetadataAccessor( + CanType forType) { + auto key = forType->getAnyNominal(); + assert(key); + assert(key->isGenericContext()); + auto &enqueuedSpecializedAccessors = + CanonicalSpecializedAccessorsForGenericTypes[key]; + if (llvm::all_of(enqueuedSpecializedAccessors, + [&](CanType enqueued) { return enqueued != forType; })) { + assert(!FinishedEmittingLazyDefinitions); + LazyCanonicalSpecializedMetadataAccessors.insert(forType); + enqueuedSpecializedAccessors.push_back(forType); + } +} + void IRGenerator::noteUseOfSpecializedGenericTypeMetadata(CanType type) { auto key = type->getAnyNominal(); assert(key); - auto &enqueuedSpecializedTypes = this->SpecializationsForGenericTypes[key]; + assert(key->isGenericContext()); + auto &enqueuedSpecializedTypes = CanonicalSpecializationsForGenericTypes[key]; if (llvm::all_of(enqueuedSpecializedTypes, [&](CanType enqueued) { return enqueued != type; })) { assert(!FinishedEmittingLazyDefinitions); - this->LazySpecializedTypeMetadataRecords.push_back(type); + LazySpecializedTypeMetadataRecords.push_back(type); enqueuedSpecializedTypes.push_back(type); } } @@ -3516,6 +3543,22 @@ IRGenModule::getAddrOfMetaclassObject(ClassDecl *decl, return addr; } +llvm::Constant * +IRGenModule::getAddrOfCanonicalSpecializedGenericMetaclassObject( + CanType concreteType, ForDefinition_t forDefinition) { + auto *theClass = concreteType->getClassOrBoundGenericClass(); + assert(theClass && "only classes have metaclasses"); + assert(concreteType->getClassOrBoundGenericClass()->isGenericContext()); + + auto entity = + LinkEntity::forSpecializedGenericSwiftMetaclassStub(concreteType); + + auto DbgTy = DebugTypeInfo::getObjCClass( + theClass, ObjCClassPtrTy, getPointerSize(), getPointerAlignment()); + auto addr = getAddrOfLLVMVariable(entity, forDefinition, DbgTy); + return addr; +} + /// Fetch the declaration of an Objective-C metadata update callback. llvm::Function * IRGenModule::getAddrOfObjCMetadataUpdateFunction(ClassDecl *classDecl, @@ -3640,6 +3683,38 @@ IRGenModule::getAddrOfGenericTypeMetadataAccessFunction( return entry; } +llvm::Function * +IRGenModule::getAddrOfCanonicalSpecializedGenericTypeMetadataAccessFunction( + CanType theType, ForDefinition_t forDefinition) { + assert(shouldPrespecializeGenericMetadata()); + assert(!theType->hasUnboundGenericType()); + auto *nominal = theType->getAnyNominal(); + assert(nominal); + assert(nominal->isGenericContext()); + + IRGen.noteUseOfCanonicalSpecializedMetadataAccessor(theType); + + LinkEntity entity = + LinkEntity::forPrespecializedTypeMetadataAccessFunction(theType); + llvm::Function *&entry = GlobalFuncs[entity]; + if (entry) { + if (forDefinition) + updateLinkageForDefinition(*this, entry, entity); + return entry; + } + + llvm::Type *paramTypesArray[1]; + paramTypesArray[0] = SizeTy; // MetadataRequest + + auto paramTypes = llvm::makeArrayRef(paramTypesArray, 1); + auto functionType = + llvm::FunctionType::get(TypeMetadataResponseTy, paramTypes, false); + Signature signature(functionType, llvm::AttributeList(), SwiftCC); + LinkInfo link = LinkInfo::get(*this, entity, forDefinition); + entry = createFunction(*this, link, signature); + return entry; +} + /// Get or create a type metadata cache variable. These are an /// implementation detail of type metadata access functions. llvm::Constant * @@ -3826,14 +3901,19 @@ ConstantReference IRGenModule::getAddrOfTypeMetadata(CanType concreteType, llvm::Type *defaultVarTy; unsigned adjustmentIndex; - bool fullMetadata = (nominal && requiresForeignTypeMetadata(nominal)) || - (concreteType->getAnyGeneric() && - concreteType->getAnyGeneric()->isGenericContext()); + bool foreign = nominal && requiresForeignTypeMetadata(nominal); + bool fullMetadata = + foreign || (concreteType->getAnyGeneric() && + concreteType->getAnyGeneric()->isGenericContext()); // Foreign classes reference the full metadata with a GEP. if (fullMetadata) { defaultVarTy = FullTypeMetadataStructTy; - adjustmentIndex = MetadataAdjustmentIndex::ValueType; + if (concreteType->getClassOrBoundGenericClass() && !foreign) { + adjustmentIndex = MetadataAdjustmentIndex::Class; + } else { + adjustmentIndex = MetadataAdjustmentIndex::ValueType; + } // The symbol for other nominal type metadata is generated at the address // point. } else if (nominal) { @@ -3870,19 +3950,24 @@ ConstantReference IRGenModule::getAddrOfTypeMetadata(CanType concreteType, if (fullMetadata) { entity = LinkEntity::forTypeMetadata(concreteType, TypeMetadataAddress::FullMetadata); - DbgTy = DebugTypeInfo::getMetadata(MetatypeType::get(concreteType), - defaultVarTy->getPointerTo(), Size(0), - Alignment(1));; } else { entity = LinkEntity::forTypeMetadata(concreteType, TypeMetadataAddress::AddressPoint); - DbgTy = DebugTypeInfo::getMetadata(MetatypeType::get(concreteType), - defaultVarTy->getPointerTo(), Size(0), - Alignment(1));; + } + DbgTy = DebugTypeInfo::getMetadata(MetatypeType::get(concreteType), + defaultVarTy->getPointerTo(), Size(0), + Alignment(1)); + + ConstantReference addr; + + if (fullMetadata && !foreign) { + addr = getAddrOfLLVMVariable(*entity, ConstantInit(), DbgTy, refKind, + /*overrideDeclType=*/nullptr); + } else { + addr = getAddrOfLLVMVariable(*entity, ConstantInit(), DbgTy, refKind, + /*overrideDeclType=*/defaultVarTy); } - auto addr = getAddrOfLLVMVariable(*entity, ConstantInit(), DbgTy, refKind, - defaultVarTy); if (auto *GV = dyn_cast(addr.getValue())) GV->setComdat(nullptr); diff --git a/lib/IRGen/GenMeta.cpp b/lib/IRGen/GenMeta.cpp index 7eca86628b80c..541c627bf15f4 100644 --- a/lib/IRGen/GenMeta.cpp +++ b/lib/IRGen/GenMeta.cpp @@ -44,6 +44,7 @@ #include "Callee.h" #include "ClassLayout.h" #include "ClassMetadataVisitor.h" +#include "ClassTypeInfo.h" #include "ConstantBuilder.h" #include "EnumMetadataVisitor.h" #include "FixedTypeInfo.h" @@ -1901,6 +1902,25 @@ void irgen::emitLazyMetadataAccessor(IRGenModule &IGM, isReadNone); } +void irgen::emitLazyCanonicalSpecializedMetadataAccessor(IRGenModule &IGM, + CanType theType) { + llvm::Function *accessor = + IGM.getAddrOfCanonicalSpecializedGenericTypeMetadataAccessFunction( + theType, ForDefinition); + + if (IGM.getOptions().optimizeForSize()) { + accessor->addFnAttr(llvm::Attribute::NoInline); + } + + emitCacheAccessFunction( + IGM, accessor, /*cache=*/nullptr, CacheStrategy::None, + [&](IRGenFunction &IGF, Explosion ¶ms) { + return emitCanonicalSpecializedGenericTypeMetadataAccessFunction( + IGF, params, theType); + }, + /*isReadNone=*/true); +} + void irgen::emitLazySpecializedGenericTypeMetadata(IRGenModule &IGM, CanType type) { switch (type->getKind()) { @@ -1914,6 +1934,11 @@ void irgen::emitLazySpecializedGenericTypeMetadata(IRGenModule &IGM, emitSpecializedGenericEnumMetadata(IGM, type, *type.getEnumOrBoundGenericEnum()); break; + case TypeKind::Class: + case TypeKind::BoundGenericClass: + emitSpecializedGenericClassMetadata(IGM, type, + *type.getClassOrBoundGenericClass()); + break; default: llvm_unreachable("Cannot statically specialize types of kind other than " "struct and enum."); @@ -2662,6 +2687,8 @@ namespace { using super = ClassMetadataVisitor; protected: + using NominalDecl = ClassDecl; + using super::asImpl; using super::IGM; using super::Target; @@ -2683,26 +2710,34 @@ namespace { VTable(IGM.getSILModule().lookUpVTable(theClass)) {} public: + SILType getLoweredType() { + return IGM.getLoweredType(Target->getDeclaredTypeInContext()); + } + void noteAddressPoint() { ClassMetadataVisitor::noteAddressPoint(); AddressPoint = B.getNextOffsetFromGlobal(); } - void addClassFlags() { - B.addInt32((uint32_t) getClassFlags(Target)); - } + ClassFlags getClassFlags() { return ::getClassFlags(Target); } + + void addClassFlags() { B.addInt32((uint32_t)asImpl().getClassFlags()); } void noteResilientSuperclass() {} void noteStartOfImmediateMembers(ClassDecl *theClass) {} - void addValueWitnessTable() { + ConstantReference getValueWitnessTable(bool relativeReference) { + assert( + !relativeReference && + "Cannot get a relative reference to a class' value witness table."); switch (IGM.getClassMetadataStrategy(Target)) { case ClassMetadataStrategy::Resilient: case ClassMetadataStrategy::Singleton: // The runtime fills in the value witness table for us. - B.add(llvm::ConstantPointerNull::get(IGM.WitnessTablePtrTy)); - break; + return ConstantReference( + llvm::ConstantPointerNull::get(IGM.WitnessTablePtrTy), + swift::irgen::ConstantReference::Direct); case ClassMetadataStrategy::Update: case ClassMetadataStrategy::FixedOrUpdate: @@ -2712,12 +2747,20 @@ namespace { ? IGM.Context.getAnyObjectType() : IGM.Context.TheNativeObjectType); auto wtable = IGM.getAddrOfValueWitnessTable(type); - B.add(wtable); - break; + return ConstantReference(wtable, + swift::irgen::ConstantReference::Direct); } } } + void addValueWitnessTable() { + B.add(asImpl().getValueWitnessTable(false).getValue()); + } + + llvm::Constant *getAddrOfMetaclassObject(ForDefinition_t forDefinition) { + return IGM.getAddrOfMetaclassObject(Target, forDefinition); + } + /// The 'metadata flags' field in a class is actually a pointer to /// the metaclass object for the class. /// @@ -2730,8 +2773,7 @@ namespace { if (IGM.ObjCInterop) { // Get the metaclass pointer as an intptr_t. - auto metaclass = IGM.getAddrOfMetaclassObject(Target, - NotForDefinition); + auto metaclass = asImpl().getAddrOfMetaclassObject(NotForDefinition); auto flags = llvm::ConstantExpr::getPtrToInt(metaclass, IGM.MetadataKindTy); B.add(flags); @@ -2744,18 +2786,32 @@ namespace { } } - void addSuperclass() { + llvm::Constant *getSuperclassMetadata() { + Type type = Target->mapTypeIntoContext(Target->getSuperclass()); + auto *metadata = + tryEmitConstantHeapMetadataRef(IGM, type->getCanonicalType(), + /*allowUninit*/ false); + return metadata; + } + + bool shouldAddNullSuperclass() { // If we might have generic ancestry, leave a placeholder since - // swift_initClassMetdata() will fill in the superclass. + // swift_initClassMetadata() will fill in the superclass. switch (IGM.getClassMetadataStrategy(Target)) { case ClassMetadataStrategy::Resilient: case ClassMetadataStrategy::Singleton: - B.addNullPointer(IGM.TypeMetadataPtrTy); - return; + return true; case ClassMetadataStrategy::Update: case ClassMetadataStrategy::FixedOrUpdate: case ClassMetadataStrategy::Fixed: - break; + return false; + } + } + + void addSuperclass() { + if (asImpl().shouldAddNullSuperclass()) { + B.addNullPointer(IGM.TypeMetadataPtrTy); + return; } // If this is a root class, use SwiftObject as our formal parent. @@ -2775,10 +2831,7 @@ namespace { return; } - Type type = Target->mapTypeIntoContext(Target->getSuperclass()); - auto *metadata = tryEmitConstantHeapMetadataRef( - IGM, type->getCanonicalType(), - /*allowUninit*/ false); + auto *metadata = asImpl().getSuperclassMetadata(); assert(metadata != nullptr); B.add(metadata); } @@ -2813,8 +2866,12 @@ namespace { return ClassContextDescriptorBuilder(IGM, Target, RequireMetadata).emit(); } + llvm::Constant *getNominalTypeDescriptor() { + return emitNominalTypeDescriptor(); + } + void addNominalTypeDescriptor() { - B.addSignedPointer(emitNominalTypeDescriptor(), + B.addSignedPointer(asImpl().getNominalTypeDescriptor(), IGM.getOptions().PointerAuth.TypeDescriptors, PointerAuthEntity::Special::TypeDescriptor); } @@ -2832,9 +2889,13 @@ namespace { B.addInt32(0); } + bool hasFixedLayout() { return FieldLayout.isFixedLayout(); } + + const ClassLayout &getFieldLayout() { return FieldLayout; } + void addInstanceSize() { - if (FieldLayout.isFixedLayout()) { - B.addInt32(FieldLayout.getSize().getValue()); + if (asImpl().hasFixedLayout()) { + B.addInt32(asImpl().getFieldLayout().getSize().getValue()); } else { // Leave a zero placeholder to be filled at runtime B.addInt32(0); @@ -2842,8 +2903,8 @@ namespace { } void addInstanceAlignMask() { - if (FieldLayout.isFixedLayout()) { - B.addInt16(FieldLayout.getAlignMask().getValue()); + if (asImpl().hasFixedLayout()) { + B.addInt16(asImpl().getFieldLayout().getAlignMask().getValue()); } else { // Leave a zero placeholder to be filled at runtime B.addInt16(0); @@ -2873,6 +2934,12 @@ namespace { B.add(IGM.getObjCEmptyVTablePtr()); } + llvm::Constant *getROData() { return emitClassPrivateData(IGM, Target); } + + uint64_t getClassDataPointerHasSwiftMetadataBits() { + return IGM.UseDarwinPreStableABIBit ? 1 : 2; + } + void addClassDataPointer() { if (!IGM.ObjCInterop) { // with no Objective-C runtime, just give an empty pointer with the @@ -2883,11 +2950,11 @@ namespace { } // Derive the RO-data. - llvm::Constant *data = emitClassPrivateData(IGM, Target); + llvm::Constant *data = asImpl().getROData(); // Set a low bit to indicate this class has Swift metadata. - auto bit = llvm::ConstantInt::get(IGM.IntPtrTy, - IGM.UseDarwinPreStableABIBit ? 1 : 2); + auto bit = llvm::ConstantInt::get( + IGM.IntPtrTy, asImpl().getClassDataPointerHasSwiftMetadataBits()); // Emit data + bit. data = llvm::ConstantExpr::getPtrToInt(data, IGM.IntPtrTy); @@ -2941,6 +3008,14 @@ namespace { } }; + static void + addFixedFieldOffset(IRGenModule &IGM, ConstantStructBuilder &B, VarDecl *var, + std::function typeFromContext) { + SILType baseType = SILType::getPrimitiveObjectType( + typeFromContext(var->getDeclContext())->getCanonicalType()); + B.addInt(IGM.SizeTy, getClassFieldOffset(IGM, baseType, var).getValue()); + } + /// A builder for non-generic class metadata which does not require any /// runtime initialization, or that only requires runtime initialization /// on newer Objective-C runtimes. @@ -2957,10 +3032,9 @@ namespace { : super(IGM, theClass, builder, fieldLayout) {} void addFieldOffset(VarDecl *var) { - SILType baseType = SILType::getPrimitiveObjectType( - var->getDeclContext()->getDeclaredTypeInContext() - ->getCanonicalType()); - B.addInt(IGM.SizeTy, getClassFieldOffset(IGM, baseType, var).getValue()); + addFixedFieldOffset(IGM, B, var, [](DeclContext *dc) { + return dc->getDeclaredTypeInContext(); + }); } void addFieldOffsetPlaceholders(MissingMemberDecl *placeholder) { @@ -2982,6 +3056,7 @@ namespace { /// fields or generic ancestry. class SingletonClassMetadataBuilder : public ClassMetadataBuilderBase { + using NominalDecl = StructDecl; using super = ClassMetadataBuilderBase; using super::IGM; using super::B; @@ -3287,6 +3362,183 @@ namespace { metadata, collector); } }; + + template