-
Notifications
You must be signed in to change notification settings - Fork 10.5k
Closed
Labels
AutoDiff · bug (A deviation from expected or documented behavior. Also: expected but undesirable behavior.) · compiler (The Swift compiler itself)
Description
Previous ID | SR-14228 |
Radar | None |
Original Reporter | @dan-zheng |
Type | Bug |
Additional Detail from JIRA
Votes | 0 |
Component/s | Compiler |
Labels | Bug, AutoDiff |
Assignee | None |
Priority | Medium |
md5: 8635dba187654182d9292136d60a56fd
relates to:
- TF-1030 allow serialized functions to reference implicit derivatives in some cases
Issue Description:
Curry thunks were recently rewritten as implicit AST closures instead of SILGen'd thunks: #28698
This caused regressions in curry thunk differentiation. Extracted from test/AutoDiff/downstream/generics.swift:
// TF-688: Test generic curry thunk cloning.
/// Generic struct with a single stored property `x` of type `Scalar`.
/// `Scalar` is unconstrained here.
public struct TF_688_Struct<Scalar> {
// The only stored property.
var x: Scalar
}
/// Conditional conformance: `TF_688_Struct` is `Differentiable` whenever `Scalar` is.
extension TF_688_Struct: Differentiable where Scalar: Differentiable {
/// Identity function: returns `x` unchanged.
///
/// Marked `@differentiable`. When this static method is referenced unapplied
/// (as `TF_688_Struct.id`), the compiler emits a curry thunk for it.
@differentiable
public static func id(x: Self) -> Self {
return x
}
}
/// Applies `reduction` to `x` and returns the result.
///
/// Declared differentiable with respect to `x` only.
///
/// - Parameters:
///   - x: The value passed to `reduction`.
///   - reduction: A `@differentiable` function from `TF_688_Struct<Scalar>` to
///     `TF_688_Struct<Scalar>`. Defaults to the unapplied static method
///     `TF_688_Struct.id`; this default-argument expression is the location the
///     "function is not differentiable" diagnostic is reported at.
/// - Returns: `reduction(x)`.
@differentiable(wrt: x)
public func TF_688<Scalar: Differentiable>(
_ x: TF_688_Struct<Scalar>,
reduction: @differentiable (TF_688_Struct<Scalar>) -> TF_688_Struct<Scalar> = TF_688_Struct.id
) -> TF_688_Struct<Scalar> {
reduction(x)
}
Before: no error.
// default argument 1 of TF_688<A>(_:reduction:)
sil non_abi [serialized] [ossa] @$s4main6TF_688_9reductionAA0B11_688_StructVyxGAF_A2FXFts14DifferentiableRzlFfA0_ : $@convention(thin) <Scalar where Scalar : Differentiable> () -> @owned @differentiable @callee_guaranteed (@in_guaranteed TF_688_Struct<Scalar>) -> @out TF_688_Struct<Scalar> {
bb0:
%0 = metatype $@thin TF_688_Struct<Scalar>.Type // user: %2
// function_ref curry thunk of static TF_688_Struct<A>.id(x:)
%1 = function_ref @$s4main13TF_688_StructVAAs14DifferentiableRzlE2id1xACyxGAG_tFZTc : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@thin TF_688_Struct<τ_0_0>.Type) -> @owned @callee_guaranteed (@in_guaranteed TF_688_Struct<τ_0_0>) -> @out TF_688_Struct<τ_0_0> // user: %2
%2 = apply %1<Scalar>(%0) : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@thin TF_688_Struct<τ_0_0>.Type) -> @owned @callee_guaranteed (@in_guaranteed TF_688_Struct<τ_0_0>) -> @out TF_688_Struct<τ_0_0> // user: %3
%3 = differentiable_function [parameters 0] %2 : $@callee_guaranteed (@in_guaranteed TF_688_Struct<Scalar>) -> @out TF_688_Struct<Scalar> // user: %4
return %3 : $@differentiable @callee_guaranteed (@in_guaranteed TF_688_Struct<Scalar>) -> @out TF_688_Struct<Scalar> // id: %4
} // end sil function '$s4main6TF_688_9reductionAA0B11_688_StructVyxGAF_A2FXFts14DifferentiableRzlFfA0_'
After: error regarding differentiating fragile function in serialized function.
This error was introduced in #28582
$ swiftc -Xllvm -debug-only=differentiation tf-688.swift
// AD__$s4main6TF_688_9reductionAA0B11_688_StructVyxGAF_A2FXFts14DifferentiableRzlFfA0_A2FcAFmcfu___differentiable_curry_thunk_src_0_wrt_0
sil shared [serialized] @AD__$s4main6TF_688_9reductionAA0B11_688_StructVyxGAF_A2FXFts14DifferentiableRzlFfA0_A2FcAFmcfu___differentiable_curry_thunk_src_0_wrt_0 : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@thin TF_688_Struct<τ_0_0>.Type) -> @owned @differentiable @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed TF_688_Struct<τ_0_0>) -> @out TF_688_Struct<τ_0_1> for <τ_0_0, τ_0_0> {
// %0 // users: %3, %1
bb0(%0 : $@thin TF_688_Struct<τ_0_0>.Type):
debug_value %0 : $@thin TF_688_Struct<τ_0_0>.Type, let, name "self", argno 1 // id: %1
// function_ref implicit closure #2 in implicit closure #1 in default argument 1 of TF_688<A>(_:reduction:)
%2 = function_ref @$s4main6TF_688_9reductionAA0B11_688_StructVyxGAF_A2FXFts14DifferentiableRzlFfA0_A2FcAFmcfu_A2Fcfu0_ : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed TF_688_Struct<τ_0_0>, @thin TF_688_Struct<τ_0_0>.Type) -> @out TF_688_Struct<τ_0_0> // user: %3
%3 = partial_apply [callee_guaranteed] %2<τ_0_0>(%0) : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed TF_688_Struct<τ_0_0>, @thin TF_688_Struct<τ_0_0>.Type) -> @out TF_688_Struct<τ_0_0> // user: %4
%4 = convert_function %3 : $@callee_guaranteed (@in_guaranteed TF_688_Struct<τ_0_0>) -> @out TF_688_Struct<τ_0_0> to $@callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed TF_688_Struct<τ_0_0>) -> @out TF_688_Struct<τ_0_1> for <τ_0_0, τ_0_0> // user: %5
%5 = differentiable_function [parameters 0] %4 : $@callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed TF_688_Struct<τ_0_0>) -> @out TF_688_Struct<τ_0_1> for <τ_0_0, τ_0_0> // user: %6
return %5 : $@differentiable @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed TF_688_Struct<τ_0_0>) -> @out TF_688_Struct<τ_0_1> for <τ_0_0, τ_0_0> // id: %6
} // end sil function 'AD__$s4main6TF_688_9reductionAA0B11_688_StructVyxGAF_A2FXFts14DifferentiableRzlFfA0_A2FcAFmcfu___differentiable_curry_thunk_src_0_wrt_0'
[AD] Diagnosing non-differentiability.
[AD] For value:
%4 = convert_function %3 : $@callee_guaranteed (@in_guaranteed TF_688_Struct<τ_0_0>) -> @out TF_688_Struct<τ_0_0> to $@callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed TF_688_Struct<τ_0_0>) -> @out TF_688_Struct<τ_0_1> for <τ_0_0, τ_0_0> // user: %5
[AD] With invoker:
(differentiation_invoker differentiable_function_inst=( %5 = differentiable_function [parameters 0] %4 : $@callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed TF_688_Struct<τ_0_0>) -> @out TF_688_Struct<τ_0_1> for <τ_0_0, τ_0_0> // user: %6
))
tf-688.swift:14:95: error: function is not differentiable
  reduction: @differentiable (TF_688_Struct<Scalar>) -> TF_688_Struct<Scalar> = TF_688_Struct.id
                                                                                ~~~~~~~~~~~~~~^~
tf-688.swift:14:95: note: differentiated functions in '@inlinable' functions must be marked '@differentiable' or have a public '@derivative'; this is not possible with a closure, make a top-level function instead
  reduction: @differentiable (TF_688_Struct<Scalar>) -> TF_688_Struct<Scalar> = TF_688_Struct.id
                                                                                              ^
Metadata
Metadata
Assignees
Labels
AutoDiff · bug (A deviation from expected or documented behavior. Also: expected but undesirable behavior.) · compiler (The Swift compiler itself)