From eab465e1821c647c372e1497d6fcaefdd4cd40d1 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Sun, 3 May 2020 19:43:06 -0700 Subject: [PATCH 1/2] [AutoDiff] Fix `@differentiable` attribute derivative configurations. In `AbstractFunctionDecl::getDerivativeFunctionConfigurations`, type-check `@differentiable` attributes. This is important to populate derivative configurations for original functions in other files. Resolves TF-1271. Exposes TF-1272: fix derivative configurations for cross-file `@derivative` attributes. This is a more difficult issue. --- lib/AST/Decl.cpp | 4 +++ ...fferentiation_diagnostics_other_file.swift | 19 ++++++++++++++ ...fferentiation_diagnostics_cross_file.swift | 25 +++++++++++++++++++ 3 files changed, 48 insertions(+) create mode 100644 test/AutoDiff/SILOptimizer/Inputs/differentiation_diagnostics_other_file.swift create mode 100644 test/AutoDiff/SILOptimizer/differentiation_diagnostics_cross_file.swift diff --git a/lib/AST/Decl.cpp b/lib/AST/Decl.cpp index b28ab04ce3fd3..21b1b354ab1b9 100644 --- a/lib/AST/Decl.cpp +++ b/lib/AST/Decl.cpp @@ -7106,6 +7106,10 @@ void AbstractFunctionDecl::prepareDerivativeFunctionConfigurations() { ArrayRef AbstractFunctionDecl::getDerivativeFunctionConfigurations() { prepareDerivativeFunctionConfigurations(); + // Check `@differentiable` attributes. + for (auto *diffAttr : getAttrs().getAttributes()) + (void)diffAttr->getParameterIndices(); + // Load derivative configurations from imported modules. auto &ctx = getASTContext(); if (ctx.getCurrentGeneration() > DerivativeFunctionConfigGeneration) { unsigned previousGeneration = DerivativeFunctionConfigGeneration; diff --git a/test/AutoDiff/SILOptimizer/Inputs/differentiation_diagnostics_other_file.swift b/test/AutoDiff/SILOptimizer/Inputs/differentiation_diagnostics_other_file.swift new file mode 100644 index 0000000000000..818d07a5a446a --- /dev/null +++ b/test/AutoDiff/SILOptimizer/Inputs/differentiation_diagnostics_other_file.swift @@ -0,0 +1,19 @@ +import _Differentiation + +protocol Protocol: Differentiable { + // Test cross-file `@differentiable` attribute. + @differentiable(wrt: self) + func identityDifferentiableAttr() -> Self +} + +extension Protocol { + func identityDerivativeAttr() -> Self { self } + + // Test cross-file `@derivative` attribute. + @derivative(of: identityDerivativeAttr) + func vjpIdentityDerivativeAttr() -> ( + value: Self, pullback: (TangentVector) -> TangentVector + ) { + fatalError() + } +} diff --git a/test/AutoDiff/SILOptimizer/differentiation_diagnostics_cross_file.swift b/test/AutoDiff/SILOptimizer/differentiation_diagnostics_cross_file.swift new file mode 100644 index 0000000000000..95e048e1864f4 --- /dev/null +++ b/test/AutoDiff/SILOptimizer/differentiation_diagnostics_cross_file.swift @@ -0,0 +1,25 @@ +// RUN: %target-swift-frontend -emit-sil -verify -primary-file %s %S/Inputs/differentiation_diagnostics_other_file.swift -module-name main -o /dev/null + +// Test differentiation transform cross-file diagnostics. + +import _Differentiation + +// TF-1271: Test `@differentiable` original function in other file. +@differentiable +func crossFileDifferentiableAttr( + _ input: T +) -> T { + return input.identityDifferentiableAttr() +} + +// TF-1272: Test original function with registered derivatives in other files. +// FIXME(TF-1272): Find a way to type-check `@derivative` attributes in other +// files. +@differentiable +func crossFileDerivativeAttr( + _ input: T +) -> T { + // expected-error @+2 {{expression is not differentiable}} + // expected-note @+1 {{cannot differentiate functions that have not been marked '@differentiable' and that are defined in other files}} + return input.identityDerivativeAttr() +} From 80b0258621849f5ab4c8e33c97a4567420f55cdd Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Sun, 3 May 2020 22:32:52 -0700 Subject: [PATCH 2/2] Improve explanatory comment. --- lib/AST/Decl.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lib/AST/Decl.cpp b/lib/AST/Decl.cpp index 21b1b354ab1b9..ea072f6537180 100644 --- a/lib/AST/Decl.cpp +++ b/lib/AST/Decl.cpp @@ -7106,7 +7106,8 @@ void AbstractFunctionDecl::prepareDerivativeFunctionConfigurations() { ArrayRef AbstractFunctionDecl::getDerivativeFunctionConfigurations() { prepareDerivativeFunctionConfigurations(); - // Check `@differentiable` attributes. + // Resolve derivative function configurations from `@differentiable` + // attributes by type-checking them. for (auto *diffAttr : getAttrs().getAttributes()) (void)diffAttr->getParameterIndices(); // Load derivative configurations from imported modules.