Skip to content

[AutoDiff] Check derivative function @usableFromInline consistency. #31724

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions include/swift/AST/DiagnosticsSema.def
Original file line number Diff line number Diff line change
Expand Up @@ -3076,6 +3076,15 @@ NOTE(derivative_attr_fix_access,none,
"mark the derivative function as "
"'%select{private|fileprivate|internal|@usableFromInline|@usableFromInline}0' "
"to match the original function", (AccessLevel))
ERROR(derivative_attr_usable_from_inline_mismatch,none,
"non-'@usableFromInline' original function must not have a "
"'@usableFromInline' derivative function", ())
NOTE(derivative_attr_fix_add_usable_from_inline,none,
"consider adding '@usableFromInline' to the original function %0",
(DeclName))
NOTE(derivative_attr_fix_remove_usable_from_inline,none,
"consider removing '@usableFromInline' from the derivative function %0",
(DeclName))

// @transpose
ERROR(transpose_attr_invalid_linearity_parameter_or_result,none,
Expand Down
30 changes: 29 additions & 1 deletion lib/Sema/TypeCheckAttr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4540,11 +4540,22 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D,
}
attr->setOriginalFunction(originalAFD);

bool usableFromInlineMismatch =
originalAFD->getFormalAccess() <= AccessLevel::Internal &&
derivative->getFormalAccess() <= AccessLevel::Internal &&
!originalAFD->isUsableFromInline() && derivative->isUsableFromInline();

bool sameAccessLevel =
originalAFD->getFormalAccess() == derivative->getFormalAccess();

// Returns true if:
// - Original function and derivative function have the same access level.
// - Original function and derivative function have the same access level and
// the same `@usableFromInline` status.
// - Original function is public and derivative function is internal
// `@usableFromInline`. This is the only special case.
auto compatibleAccessLevels = [&]() {
if (usableFromInlineMismatch)
return false;
if (originalAFD->getFormalAccess() == derivative->getFormalAccess())
return true;
return originalAFD->getFormalAccess() == AccessLevel::Public &&
Expand All @@ -4553,6 +4564,23 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D,

// Check access level compatibility for original and derivative functions.
if (!compatibleAccessLevels()) {
// Diagnose if access levels match, but derivative is `@usableFromInline`
// while original is not.
if (sameAccessLevel && usableFromInlineMismatch) {
diags.diagnose(originalName.Loc,
diag::derivative_attr_usable_from_inline_mismatch);
// Suggest adding `@usableFromInline` to original.
originalAFD
->diagnose(diag::derivative_attr_fix_add_usable_from_inline,
originalAFD->getName())
.fixItInsert(
originalAFD->getAttributeInsertionLoc(/*forModifier*/ false),
"@usableFromInline ");
// Suggest remove `@usableFromInline` from derivative.
derivative->diagnose(diag::derivative_attr_fix_remove_usable_from_inline,
Copy link
Contributor Author

@dan-zheng dan-zheng May 12, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This "remove attribute" note isn't entirely accurate: the derivative function may have an @inlinable or @_alwaysEmitIntoClient attribute instead of @usableFromInline.

I think the current message and source location are reasonable though. It seems more fragile to make the diagnostic exactly accurate (checking for one of the three attributes above and putting the location and a removal fix-it there).

derivative->getName());
return true;
}
auto originalAccess = originalAFD->getFormalAccess();
auto derivativeAccess =
derivative->getFormalAccessScope().accessLevelForDiagnostics();
Expand Down
32 changes: 32 additions & 0 deletions test/AutoDiff/Sema/derivative_attr_type_checking.swift
Original file line number Diff line number Diff line change
Expand Up @@ -922,3 +922,35 @@ func internal_original_fileprivate_derivative(_ x: Float) -> Float { x }
fileprivate func _internal_original_fileprivate_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
fatalError()
}

// MARK: - Original vs derivative `@usableFromInline` mismatch

// expected-note @+1 {{consider adding '@usableFromInline' to the original function 'internal_original_usablefrominline_derivative'}}
func internal_original_usablefrominline_derivative(_ x: Float) -> Float { x }
@usableFromInline
// expected-error @+1 {{non-'@usableFromInline' original function must not have a '@usableFromInline' derivative function}}
@derivative(of: internal_original_usablefrominline_derivative)
// expected-note @+1 {{consider removing '@usableFromInline' from the derivative function '_internal_original_usablefrominline_derivative'}}
func _internal_original_usablefrominline_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
fatalError()
}

// expected-note @+1 {{consider adding '@usableFromInline' to the original function 'internal_original_inlinable_derivative'}}
func internal_original_inlinable_derivative(_ x: Float) -> Float { x }
@inlinable
// expected-error @+1 {{non-'@usableFromInline' original function must not have a '@usableFromInline' derivative function}}
@derivative(of: internal_original_inlinable_derivative)
// expected-note @+1 {{consider removing '@usableFromInline' from the derivative function '_internal_original_inlinable_derivative'}}
func _internal_original_inlinable_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
fatalError()
}

// expected-note @+1 {{consider adding '@usableFromInline' to the original function 'internal_original_alwaysemitintoclient_derivative'}}
func internal_original_alwaysemitintoclient_derivative(_ x: Float) -> Float { x }
@inlinable
// expected-error @+1 {{non-'@usableFromInline' original function must not have a '@usableFromInline' derivative function}}
@derivative(of: internal_original_alwaysemitintoclient_derivative)
// expected-note @+1 {{consider removing '@usableFromInline' from the derivative function '_internal_original_alwaysemitintoclient_derivative'}}
func _internal_original_alwaysemitintoclient_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
fatalError()
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// RUN: %target-swift-frontend -c %s -verify
// REQUIRES: asserts

// TF-1160: Linker error for `@usableFromInline` derivative function but
// non-`@usableFromInline` internal original function.

import _Differentiation

// expected-note @+1 {{consider adding '@usableFromInline' to the original function 'internalOriginal'}}
func internalOriginal(_ x: Float) -> Float {
x
}

@usableFromInline
// expected-error @+1 {{non-'@usableFromInline' original function must not have a '@usableFromInline' derivative function}}
@derivative(of: internalOriginal)
// expected-note @+1 {{consider removing '@usableFromInline' from the derivative function 'usableFromInlineDerivative'}}
func usableFromInlineDerivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
(x, { $0 })
}

// Original error: type-checking passes but TBDGen is not consistent with IRGen.
// <unknown>:0: error: symbol 'AD__$s4main16internalOriginalyS2fF__vjp_src_0_wrt_0' (AD__$s4main16internalOriginalyS2fF__vjp_src_0_wrt_0) is in generated IR file, but not in TBD file
// <unknown>:0: error: please file a radar or open a bug on bugs.swift.org with this code, and add -Xfrontend -validate-tbd-against-ir=none to squash the errors