Skip to content

[AutoDiff] Require same access level for original/derivative functions. #31527

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 6, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions include/swift/AST/DiagnosticsSema.def
Original file line number Diff line number Diff line change
Expand Up @@ -3054,6 +3054,18 @@ ERROR(derivative_attr_original_already_has_derivative,none,
"a derivative already exists for %0", (DeclName))
NOTE(derivative_attr_duplicate_note,none,
"other attribute declared here", ())
ERROR(derivative_attr_access_level_mismatch,none,
"derivative function must have same access level as original function; "
"derivative function %2 is "
"%select{private|fileprivate|internal|public|open}3, "
"but original function %0 is "
"%select{private|fileprivate|internal|public|open}1",
(/*original*/ DeclName, /*original*/ AccessLevel,
/*derivative*/ DeclName, /*derivative*/ AccessLevel))
NOTE(derivative_attr_fix_access,none,
"mark the derivative function as "
"'%select{private|fileprivate|internal|@usableFromInline|@usableFromInline}0' "
"to match the original function", (AccessLevel))

// @transpose
ERROR(transpose_attr_invalid_linearity_parameter_or_result,none,
Expand Down
36 changes: 36 additions & 0 deletions lib/Sema/TypeCheckAttr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4540,6 +4540,42 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D,
}
attr->setOriginalFunction(originalAFD);

// Returns true if:
// - Original function and derivative function have the same access level.
// - Original function is public and derivative function is internal
// `@usableFromInline`. This is the only special case.
auto compatibleAccessLevels = [&]() {
if (originalAFD->getFormalAccess() == derivative->getFormalAccess())
return true;
return originalAFD->getFormalAccess() == AccessLevel::Public &&
derivative->getEffectiveAccess() == AccessLevel::Public;
};

// Check access level compatibility for original and derivative functions.
if (!compatibleAccessLevels()) {
auto originalAccess = originalAFD->getFormalAccess();
auto derivativeAccess =
derivative->getFormalAccessScope().accessLevelForDiagnostics();
diags.diagnose(originalName.Loc,
diag::derivative_attr_access_level_mismatch,
originalAFD->getName(), originalAccess,
derivative->getName(), derivativeAccess);
auto fixItDiag =
derivative->diagnose(diag::derivative_attr_fix_access, originalAccess);
// If original access is public, suggest adding `@usableFromInline` to
// derivative.
if (originalAccess == AccessLevel::Public) {
fixItDiag.fixItInsert(
derivative->getAttributeInsertionLoc(/*forModifier*/ false),
"@usableFromInline ");
}
// Otherwise, suggest changing derivative access level.
else {
fixItAccess(fixItDiag, derivative, originalAccess);
}
return true;
}

// Get the resolved differentiability parameter indices.
auto *resolvedDiffParamIndices = attr->getParameterIndices();

Expand Down
2 changes: 1 addition & 1 deletion test/AutoDiff/SILGen/sil_differentiability_witness.swift
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ public func foo_vjp(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
func bar<T>(_ x: Float, _ y: T) -> Float { x }

@derivative(of: bar)
public func bar_jvp<T>(_ x: Float, _ y: T) -> (value: Float, differential: (Float) -> Float) {
func bar_jvp<T>(_ x: Float, _ y: T) -> (value: Float, differential: (Float) -> Float) {
(x, { $0 })
}

Expand Down
120 changes: 120 additions & 0 deletions test/AutoDiff/Sema/derivative_attr_type_checking.swift
Original file line number Diff line number Diff line change
Expand Up @@ -746,6 +746,7 @@ extension InoutParameters {
// Test cross-file derivative registration.

extension FloatingPoint where Self: Differentiable {
@usableFromInline
@derivative(of: rounded)
func vjpRounded() -> (
value: Self,
Expand Down Expand Up @@ -802,3 +803,122 @@ extension HasADefaultDerivative {
(x, { 10 * $0 })
}
}

// MARK: - Original function visibility = derivative function visibility

public func public_original_public_derivative(_ x: Float) -> Float { x }
@derivative(of: public_original_public_derivative)
public func _public_original_public_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
fatalError()
}

public func public_original_usablefrominline_derivative(_ x: Float) -> Float { x }
@usableFromInline
@derivative(of: public_original_usablefrominline_derivative)
func _public_original_usablefrominline_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
fatalError()
}

func internal_original_internal_derivative(_ x: Float) -> Float { x }
@derivative(of: internal_original_internal_derivative)
func _internal_original_internal_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
fatalError()
}

private func private_original_private_derivative(_ x: Float) -> Float { x }
@derivative(of: private_original_private_derivative)
private func _private_original_private_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
fatalError()
}

fileprivate func fileprivate_original_fileprivate_derivative(_ x: Float) -> Float { x }
@derivative(of: fileprivate_original_fileprivate_derivative)
fileprivate func _fileprivate_original_fileprivate_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
fatalError()
}

// MARK: - Original function visibility < derivative function visibility

@usableFromInline
func usablefrominline_original_public_derivative(_ x: Float) -> Float { x }
// expected-error @+1 {{derivative function must have same access level as original function; derivative function '_usablefrominline_original_public_derivative' is public, but original function 'usablefrominline_original_public_derivative' is internal}}
@derivative(of: usablefrominline_original_public_derivative)
// expected-note @+1 {{mark the derivative function as 'internal' to match the original function}} {{1-7=internal}}
public func _usablefrominline_original_public_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
fatalError()
}

func internal_original_public_derivative(_ x: Float) -> Float { x }
// expected-error @+1 {{derivative function must have same access level as original function; derivative function '_internal_original_public_derivative' is public, but original function 'internal_original_public_derivative' is internal}}
@derivative(of: internal_original_public_derivative)
// expected-note @+1 {{mark the derivative function as 'internal' to match the original function}} {{1-7=internal}}
public func _internal_original_public_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
fatalError()
}

private func private_original_usablefrominline_derivative(_ x: Float) -> Float { x }
// expected-error @+1 {{derivative function must have same access level as original function; derivative function '_private_original_usablefrominline_derivative' is internal, but original function 'private_original_usablefrominline_derivative' is private}}
@derivative(of: private_original_usablefrominline_derivative)
@usableFromInline
// expected-note @+1 {{mark the derivative function as 'private' to match the original function}} {{1-1=private }}
func _private_original_usablefrominline_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
fatalError()
}

private func private_original_public_derivative(_ x: Float) -> Float { x }
// expected-error @+1 {{derivative function must have same access level as original function; derivative function '_private_original_public_derivative' is public, but original function 'private_original_public_derivative' is private}}
@derivative(of: private_original_public_derivative)
// expected-note @+1 {{mark the derivative function as 'private' to match the original function}} {{1-7=private}}
public func _private_original_public_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
fatalError()
}

private func private_original_internal_derivative(_ x: Float) -> Float { x }
// expected-error @+1 {{derivative function must have same access level as original function; derivative function '_private_original_internal_derivative' is internal, but original function 'private_original_internal_derivative' is private}}
@derivative(of: private_original_internal_derivative)
// expected-note @+1 {{mark the derivative function as 'private' to match the original function}}
func _private_original_internal_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
fatalError()
}

fileprivate func fileprivate_original_private_derivative(_ x: Float) -> Float { x }
// expected-error @+1 {{derivative function must have same access level as original function; derivative function '_fileprivate_original_private_derivative' is private, but original function 'fileprivate_original_private_derivative' is fileprivate}}
@derivative(of: fileprivate_original_private_derivative)
// expected-note @+1 {{mark the derivative function as 'fileprivate' to match the original function}} {{1-8=fileprivate}}
private func _fileprivate_original_private_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
fatalError()
}

private func private_original_fileprivate_derivative(_ x: Float) -> Float { x }
// expected-error @+1 {{derivative function must have same access level as original function; derivative function '_private_original_fileprivate_derivative' is fileprivate, but original function 'private_original_fileprivate_derivative' is private}}
@derivative(of: private_original_fileprivate_derivative)
// expected-note @+1 {{mark the derivative function as 'private' to match the original function}} {{1-12=private}}
fileprivate func _private_original_fileprivate_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
fatalError()
}

// MARK: - Original function visibility > derivative function visibility

public func public_original_private_derivative(_ x: Float) -> Float { x }
// expected-error @+1 {{derivative function must have same access level as original function; derivative function '_public_original_private_derivative' is fileprivate, but original function 'public_original_private_derivative' is public}}
@derivative(of: public_original_private_derivative)
// expected-note @+1 {{mark the derivative function as '@usableFromInline' to match the original function}} {{1-1=@usableFromInline }}
fileprivate func _public_original_private_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
fatalError()
}

public func public_original_internal_derivative(_ x: Float) -> Float { x }
// expected-error @+1 {{derivative function must have same access level as original function; derivative function '_public_original_internal_derivative' is internal, but original function 'public_original_internal_derivative' is public}}
@derivative(of: public_original_internal_derivative)
// expected-note @+1 {{mark the derivative function as '@usableFromInline' to match the original function}} {{1-1=@usableFromInline }}
func _public_original_internal_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
fatalError()
}

func internal_original_fileprivate_derivative(_ x: Float) -> Float { x }
// expected-error @+1 {{derivative function must have same access level as original function; derivative function '_internal_original_fileprivate_derivative' is fileprivate, but original function 'internal_original_fileprivate_derivative' is internal}}
@derivative(of: internal_original_fileprivate_derivative)
// expected-note @+1 {{mark the derivative function as 'internal' to match the original function}} {{1-12=internal}}
fileprivate func _internal_original_fileprivate_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
fatalError()
}