Skip to content

Commit c75bc66

Browse files
committed
[AutoDiff] Support custom derivatives for @_alwaysEmitIntoClient functions
Fixes #54445
1 parent 09d122a commit c75bc66

File tree

30 files changed

+464
-28
lines changed

30 files changed

+464
-28
lines changed

include/swift/AST/DiagnosticsSema.def

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4290,6 +4290,9 @@ NOTE(derivative_attr_fix_access,none,
42904290
"mark the derivative function as "
42914291
"'%select{private|fileprivate|internal|package|@usableFromInline|@usableFromInline}0' "
42924292
"to match the original function", (AccessLevel))
4293+
ERROR(derivative_attr_always_emit_into_client_mismatch,none,
4294+
"either both or none of derivative and original function must have "
4295+
"@alwaysEmitIntoClient attribute", ())
42934296
ERROR(derivative_attr_static_method_mismatch_original,none,
42944297
"unexpected derivative function declaration; "
42954298
"%0 requires the derivative function %1 to be %select{an instance|a 'static'}2 method",

lib/SIL/IR/Linker.cpp

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -159,9 +159,29 @@ void SILLinkerVisitor::maybeAddFunctionToWorklist(
159159
// HiddenExternal linkage when they are declarations, then they
160160
// become Shared after the body has been deserialized.
161161
// So try deserializing HiddenExternal functions too.
162-
if (linkage == SILLinkage::HiddenExternal)
163-
return deserializeAndPushToWorklist(F);
164-
162+
if (linkage == SILLinkage::HiddenExternal) {
163+
deserializeAndPushToWorklist(F);
164+
if (!F->markedAsAlwaysEmitIntoClient())
165+
return;
166+
for (SILDifferentiabilityWitness &witness :
167+
F->getModule().getDifferentiabilityWitnesses()) {
168+
if (witness.getOriginalFunction() != F)
169+
continue;
170+
SILDifferentiabilityWitness *loadedWitness =
171+
F->getModule().getSILLoader()->lookupDifferentiabilityWitness(
172+
witness.getKey());
173+
if (loadedWitness == nullptr)
174+
continue;
175+
assert(loadedWitness == &witness);
176+
SILFunction *jvp = loadedWitness->getJVP();
177+
SILFunction *vjp = loadedWitness->getVJP();
178+
assert(jvp && vjp);
179+
deserializeAndPushToWorklist(jvp);
180+
deserializeAndPushToWorklist(vjp);
181+
}
182+
return;
183+
}
184+
165185
// Update the linkage of the function in case it's different in the serialized
166186
// SIL than derived from the AST. This can be the case with cross-module-
167187
// optimizations.

lib/SILGen/SILGen.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1401,13 +1401,15 @@ void SILGenModule::emitDifferentiabilityWitness(
14011401
if (!diffWitness) {
14021402
// Differentiability witnesses have the same linkage as the original
14031403
// function, stripping external.
1404-
auto linkage = stripExternalFromLinkage(originalFunction->getLinkage());
1404+
auto linkage =
1405+
originalFunction->markedAsAlwaysEmitIntoClient()
1406+
? SILLinkage::PublicNonABI
1407+
: stripExternalFromLinkage(originalFunction->getLinkage());
14051408
diffWitness = SILDifferentiabilityWitness::createDefinition(
14061409
M, linkage, originalFunction, diffKind, silConfig.parameterIndices,
14071410
silConfig.resultIndices, config.derivativeGenericSignature,
14081411
/*jvp*/ nullptr, /*vjp*/ nullptr,
1409-
/*isSerialized*/ hasPublicVisibility(originalFunction->getLinkage()),
1410-
attr);
1412+
/*isSerialized*/ hasPublicVisibility(linkage), attr);
14111413
}
14121414

14131415
// Set derivative function in differentiability witness.

lib/SILGen/SILGenPoly.cpp

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6282,10 +6282,19 @@ SILFunction *SILGenModule::getOrCreateCustomDerivativeThunk(
62826282
SILGenFunctionBuilder fb(*this);
62836283
// Derivative thunks have the same linkage as the original function, stripping
62846284
// external.
6285-
auto linkage = stripExternalFromLinkage(originalFn->getLinkage());
6285+
auto linkage = originalFn->markedAsAlwaysEmitIntoClient()
6286+
? SILLinkage::PublicNonABI
6287+
: stripExternalFromLinkage(originalFn->getLinkage());
6288+
6289+
auto serializedKind = customDerivativeFn->getSerializedKind();
6290+
// See comment for an identical if statement in
6291+
// DifferentiationTransformer::canonicalizeDifferentiabilityWitness.
6292+
if (originalFn->getLinkage() == SILLinkage::HiddenExternal &&
6293+
!originalFn->markedAsAlwaysEmitIntoClient())
6294+
serializedKind = IsNotSerialized;
6295+
62866296
auto *thunk = fb.getOrCreateFunction(
6287-
loc, name, linkage, thunkFnTy, IsBare, IsNotTransparent,
6288-
customDerivativeFn->getSerializedKind(),
6297+
loc, name, linkage, thunkFnTy, IsBare, IsNotTransparent, serializedKind,
62896298
customDerivativeFn->isDynamicallyReplaceable(),
62906299
customDerivativeFn->isDistributed(),
62916300
customDerivativeFn->isRuntimeAccessible(),

lib/SILOptimizer/Differentiation/Common.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -538,9 +538,11 @@ SILDifferentiabilityWitness *getOrCreateMinimalASTDifferentiabilityWitness(
538538
"definitions with explicit differentiable attributes");
539539

540540
return SILDifferentiabilityWitness::createDeclaration(
541-
module, SILLinkage::PublicExternal, original, kind,
542-
minimalConfig->parameterIndices, minimalConfig->resultIndices,
543-
minimalConfig->derivativeGenericSignature);
541+
module,
542+
original->markedAsAlwaysEmitIntoClient() ? SILLinkage::PublicNonABI
543+
: SILLinkage::PublicExternal,
544+
original, kind, minimalConfig->parameterIndices,
545+
minimalConfig->resultIndices, minimalConfig->derivativeGenericSignature);
544546
}
545547

546548
} // end namespace autodiff

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -914,7 +914,8 @@ bool DifferentiationTransformer::canonicalizeDifferentiabilityWitness(
914914
// flag. Important exception here hidden_external functions as they are
915915
// serializable but corresponding hidden ones would be not and the SIL
916916
// verifier will fail. Patch `serializeFunctions` for this case.
917-
if (orig->getLinkage() == SILLinkage::HiddenExternal)
917+
if (orig->getLinkage() == SILLinkage::HiddenExternal &&
918+
!orig->markedAsAlwaysEmitIntoClient())
918919
serializeFunctions = IsNotSerialized;
919920

920921
// If the JVP doesn't exist, need to synthesize it.

lib/Sema/TypeCheckAttr.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6743,6 +6743,13 @@ static bool typeCheckDerivativeAttr(DerivativeAttr *attr) {
67436743
return true;
67446744
}
67456745

6746+
if (originalAFD->getAttrs().hasAttribute<AlwaysEmitIntoClientAttr>() !=
6747+
derivative->getAttrs().hasAttribute<AlwaysEmitIntoClientAttr>()) {
6748+
diags.diagnose(derivative->getLoc(),
6749+
diag::derivative_attr_always_emit_into_client_mismatch);
6750+
return true;
6751+
}
6752+
67466753
// Get the resolved differentiability parameter indices.
67476754
auto *resolvedDiffParamIndices = attr->getParameterIndices();
67486755

stdlib/public/Differentiation/SIMDDifferentiation.swift.gyb

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -405,9 +405,6 @@ where
405405
}
406406
}
407407

408-
// FIXME(TF-1103): Derivative registration does not yet support
409-
// `@_alwaysEmitIntoClient` original functions like `SIMD.sum()`.
410-
/*
411408
extension SIMD
412409
where
413410
Self: Differentiable,
@@ -417,6 +414,7 @@ where
417414
TangentVector == Self
418415
{
419416
@inlinable
417+
@_alwaysEmitIntoClient
420418
@derivative(of: sum)
421419
func _vjpSum() -> (
422420
value: Scalar, pullback: (Scalar.TangentVector) -> TangentVector
@@ -425,14 +423,14 @@ where
425423
}
426424

427425
@inlinable
426+
@_alwaysEmitIntoClient
428427
@derivative(of: sum)
429428
func _jvpSum() -> (
430429
value: Scalar, differential: (TangentVector) -> Scalar.TangentVector
431430
) {
432431
return (sum(), { v in Scalar.TangentVector(v.sum()) })
433432
}
434433
}
435-
*/
436434

437435
extension SIMD
438436
where

test/AutoDiff/SILGen/nil_coalescing.swift

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
// RUN: %target-swift-frontend -Xllvm -sil-print-types -emit-sil -verify %s | %FileCheck %s
1+
/// Note: -primary-file prevents non_abi->shared linkage change in `removeSerializedFlagFromAllFunctions`
2+
// RUN: %target-swift-frontend -Xllvm -sil-print-types -emit-sil -verify -primary-file %s | %FileCheck %s
23

34
import _Differentiation
45

5-
// CHECK: sil @test_nil_coalescing
6+
// CHECK: sil non_abi @test_nil_coalescing
67
// CHECK: bb0(%{{.*}} : $*T, %[[ARG_OPT:.*]] : $*Optional<T>, %[[ARG_PB:.*]] :
78
// CHECK: $@noescape @callee_guaranteed @substituted <τ_0_0> () -> (@out τ_0_0, @error any Error) for <T>):
89
// CHECK: %[[ALLOC_OPT:.*]] = alloc_stack [lexical] $Optional<T>
@@ -15,7 +16,7 @@ import _Differentiation
1516
//
1617
@_silgen_name("test_nil_coalescing")
1718
@derivative(of: ??)
18-
@usableFromInline
19+
@_alwaysEmitIntoClient
1920
func nilCoalescing<T: Differentiable>(optional: T?, defaultValue: @autoclosure () throws -> T)
2021
rethrows -> (value: T, pullback: (T.TangentVector) -> Optional<T>.TangentVector)
2122
{

test/AutoDiff/Sema/derivative_attr_type_checking.swift

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1062,6 +1062,15 @@ func _internal_original_inlinable_derivative(_ x: Float) -> (value: Float, pullb
10621062
fatalError()
10631063
}
10641064

1065+
func internal_original_alwaysemitintoclient_derivative_error(_ x: Float) -> Float { x }
1066+
@_alwaysEmitIntoClient
1067+
@derivative(of: internal_original_alwaysemitintoclient_derivative_error)
1068+
// expected-error @+1 {{either both or none of derivative and original function must have @alwaysEmitIntoClient attribute}}
1069+
func _internal_original_alwaysemitintoclient_derivative_error(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
1070+
fatalError()
1071+
}
1072+
1073+
@_alwaysEmitIntoClient
10651074
func internal_original_alwaysemitintoclient_derivative(_ x: Float) -> Float { x }
10661075
@_alwaysEmitIntoClient
10671076
@derivative(of: internal_original_alwaysemitintoclient_derivative)
@@ -1084,6 +1093,15 @@ package func _package_original_inlinable_derivative(_ x: Float) -> (value: Float
10841093
fatalError()
10851094
}
10861095

1096+
@_alwaysEmitIntoClient
1097+
package func package_original_alwaysemitintoclient_derivative_error(_ x: Float) -> Float { x }
1098+
@derivative(of: package_original_alwaysemitintoclient_derivative_error)
1099+
// expected-error @+1 {{either both or none of derivative and original function must have @alwaysEmitIntoClient attribute}}
1100+
package func _package_original_alwaysemitintoclient_derivative_error(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
1101+
fatalError()
1102+
}
1103+
1104+
@_alwaysEmitIntoClient
10871105
package func package_original_alwaysemitintoclient_derivative(_ x: Float) -> Float { x }
10881106
@_alwaysEmitIntoClient
10891107
@derivative(of: package_original_alwaysemitintoclient_derivative)

0 commit comments

Comments
 (0)