Skip to content

[AutoDiff] Support differentiation of wrapped properties. #31173

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions include/swift/AST/DiagnosticsSema.def
Original file line number Diff line number Diff line change
Expand Up @@ -2734,13 +2734,21 @@ WARNING(differentiable_nondiff_type_implicit_noderivative_fixit,none,
"stored property %0 has no derivative because %1 does not conform to "
"'Differentiable'; add an explicit '@noDerivative' attribute"
"%select{|, or conform %2 to 'AdditiveArithmetic'}3",
(Identifier, Type, Identifier, bool))
WARNING(differentiable_let_property_implicit_noderivative_fixit,none,
(/*propName*/ Identifier, /*propType*/ Type, /*nominalName*/ Identifier,
/*nominalCanDeriveAdditiveArithmetic*/ bool))
WARNING(differentiable_immutable_wrapper_implicit_noderivative_fixit,none,
"synthesis of the 'Differentiable.move(along:)' requirement for %1 "
"requires all stored properties to be mutable; use 'var' instead, or add "
"an explicit '@noDerivative' attribute"
"requires all stored properties not marked with `@noDerivative` to be "
"mutable; add an explicit '@noDerivative' attribute"
"%select{|, or conform %1 to 'AdditiveArithmetic'}2",
(Identifier, Identifier, bool))
(/*wrapperType*/ StringRef, /*nominalName*/ Identifier,
/*nominalCanDeriveAdditiveArithmetic*/ bool))
WARNING(differentiable_let_property_implicit_noderivative_fixit,none,
"synthesis of the 'Differentiable.move(along:)' requirement for %0 "
"requires all stored properties not marked with `@noDerivative` to be "
"mutable; use 'var' instead, or add an explicit '@noDerivative' attribute"
"%select{|, or conform %0 to 'AdditiveArithmetic'}1",
(/*nominalName*/ Identifier, /*nominalCanDeriveAdditiveArithmetic*/ bool))

NOTE(codable_extraneous_codingkey_case_here,none,
"CodingKey case %0 does not match any stored properties", (Identifier))
Expand Down
1 change: 1 addition & 0 deletions include/swift/SILOptimizer/Differentiation/AdjointValue.h
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ class AdjointValue final {
break;
}
}
SWIFT_DEBUG_DUMP { print(llvm::dbgs()); };
};

inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
Expand Down
17 changes: 17 additions & 0 deletions include/swift/SILOptimizer/Differentiation/Common.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,23 @@ ApplyInst *getAllocateUninitializedArrayIntrinsicElementAddress(SILValue v);
/// tuple-typed and such a user exists.
DestructureTupleInst *getSingleDestructureTupleUser(SILValue value);

/// Returns true if the given original function is a "semantic member accessor".
///
/// "Semantic member accessors" are attached to member properties that have a
/// corresponding tangent stored property in the parent `TangentVector` type.
/// These accessors have special-case pullback generation based on their
/// semantic behavior.
///
/// "Semantic member accessors" currently include:
/// - Stored property accessors. These are implicitly generated.
/// - Property wrapper wrapped value accessors. These are implicitly generated
/// and internally call `var wrappedValue`.
bool isSemanticMemberAccessor(SILFunction *original);

/// Returns true if the given apply site has a "semantic member accessor"
/// callee.
bool hasSemanticMemberAccessorCallee(ApplySite applySite);

/// Given a full apply site, apply the given callback to each of its
/// "direct results".
///
Expand Down
15 changes: 14 additions & 1 deletion include/swift/SILOptimizer/Differentiation/PullbackEmitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
SILBuilder localAllocBuilder;

/// Stack buffers allocated for storing local adjoint values.
SmallVector<SILValue, 64> functionLocalAllocations;
SmallVector<AllocStackInst *, 64> functionLocalAllocations;

/// A set used to remember local allocations that were destroyed.
llvm::SmallDenseSet<SILValue> destroyedLocalAllocations;
Expand Down Expand Up @@ -316,6 +316,19 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
/// if any error occurs.
bool run();

/// Performs pullback generation on the empty pullback function, given that
/// the original function is a "semantic member accessor".
///
/// "Semantic member accessors" are attached to member properties that have a
/// corresponding tangent stored property in the parent `TangentVector` type.
/// These accessors have special-case pullback generation based on their
/// semantic behavior.
///
/// Returns true if any error occurs.
bool runForSemanticMemberAccessor();
bool runForSemanticMemberGetter();
bool runForSemanticMemberSetter();

/// If original result is non-varied, it will always have a zero derivative.
/// Skip full pullback generation and simply emit zero derivatives for wrt
/// parameters.
Expand Down
38 changes: 38 additions & 0 deletions lib/SILOptimizer/Differentiation/Common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,44 @@ DestructureTupleInst *getSingleDestructureTupleUser(SILValue value) {
return result;
}

bool isSemanticMemberAccessor(SILFunction *original) {
auto *dc = original->getDeclContext();
if (!dc)
return false;
auto *decl = dc->getAsDecl();
if (!decl)
return false;
auto *accessor = dyn_cast<AccessorDecl>(decl);
if (!accessor)
return false;
// Currently, only getters and setters are supported.
// TODO(SR-12640): Support `modify` accessors.
if (accessor->getAccessorKind() != AccessorKind::Get &&
accessor->getAccessorKind() != AccessorKind::Set)
return false;
// Accessor must come from a `var` declaration.
auto *varDecl = dyn_cast<VarDecl>(accessor->getStorage());
if (!varDecl)
return false;
// Return true for stored property accessors.
if (varDecl->hasStorage() && varDecl->isInstanceMember())
return true;
// Return true for properties that have attached property wrappers.
if (varDecl->hasAttachedPropertyWrapper())
return true;
// Otherwise, return false.
// User-defined accessors can never be supported because they may use custom
// logic that does not semantically perform a member access.
return false;
}

bool hasSemanticMemberAccessorCallee(ApplySite applySite) {
if (auto *FRI = dyn_cast<FunctionRefBaseInst>(applySite.getCallee()))
if (auto *F = FRI->getReferencedFunctionOrNull())
return isSemanticMemberAccessor(F);
return false;
}

void forEachApplyDirectResult(
FullApplySite applySite,
llvm::function_ref<void(SILValue)> resultCallback) {
Expand Down
Loading