[SR-14228] [AutoDiff] "Curry thunk" differentiation regression #54819

@dan-zheng

Previous ID SR-14228
Radar None
Original Reporter @dan-zheng
Type Bug
Additional Detail from JIRA
Votes 0
Component/s Compiler
Labels Bug, AutoDiff
Assignee None
Priority Medium

md5: 8635dba187654182d9292136d60a56fd

relates to:

  • TF-1030 allow serialized functions to reference implicit derivatives in some cases

Issue Description:

Curry thunks were recently rewritten as implicit AST closures instead of SILGen'd thunks: #28698

This caused regressions in curry thunk differentiation. Extracted from test/AutoDiff/downstream/generics.swift:

// TF-688: Test generic curry thunk cloning.
public struct TF_688_Struct<Scalar> {
  var x: Scalar
}
extension TF_688_Struct: Differentiable where Scalar: Differentiable {
  @differentiable
  public static func id(x: Self) -> Self {
    return x
  }
}
@differentiable(wrt: x)
public func TF_688<Scalar: Differentiable>(
  _ x: TF_688_Struct<Scalar>,
  reduction: @differentiable (TF_688_Struct<Scalar>) -> TF_688_Struct<Scalar> = TF_688_Struct.id
) -> TF_688_Struct<Scalar> {
  reduction(x)
}
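
For readers less familiar with the term, here is a rough plain-Swift sketch of what a curry thunk amounts to for an unapplied static method reference such as the `TF_688_Struct.id` default argument. It is not taken from the report: `Wrapper`, `viaReference`, `thunk`, and `viaThunk` are hypothetical names, and `@differentiable` is deliberately left out so the snippet builds on a stock toolchain. The hand-written nested closure only approximates what the compiler synthesizes.

```swift
// Hypothetical stand-in for TF_688_Struct, without any AutoDiff attributes.
struct Wrapper<Scalar> {
  var x: Scalar
  static func id(x: Wrapper) -> Wrapper {
    return x
  }
}

// Unapplied reference to a static method. The compiler has to produce a
// function value for this, which is the job of the curry thunk.
let viaReference: (Wrapper<Double>) -> Wrapper<Double> = Wrapper<Double>.id

// Hand-written approximation of the thunk: the outer closure takes the
// metatype ("self"), the inner closure takes the actual argument.
let thunk: (Wrapper<Double>.Type) -> (Wrapper<Double>) -> Wrapper<Double> = { selfType in
  { x in selfType.id(x: x) }
}
let viaThunk = thunk(Wrapper<Double>.self)

// Both function values forward to Wrapper.id(x:).
print(viaReference(Wrapper(x: 3.0)).x == viaThunk(Wrapper(x: 3.0)).x)
```

Once the thunk is expressed as closures in the AST rather than as a dedicated SILGen'd thunk, the differentiation pass sees ordinary closure bodies, which is consistent with the nested implicit closures named in the "After" SIL output.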

Before: no error.

// default argument 1 of TF_688<A>(_:reduction:)
sil non_abi [serialized] [ossa] @$s4main6TF_688_9reductionAA0B11_688_StructVyxGAF_A2FXFts14DifferentiableRzlFfA0_ : $@convention(thin) <Scalar where Scalar : Differentiable> () -> @owned @differentiable @callee_guaranteed (@in_guaranteed TF_688_Struct<Scalar>) -> @out TF_688_Struct<Scalar> {
bb0:
  %0 = metatype $@thin TF_688_Struct<Scalar>.Type // user: %2
  // function_ref curry thunk of static TF_688_Struct<A>.id(x:)
  %1 = function_ref @$s4main13TF_688_StructVAAs14DifferentiableRzlE2id1xACyxGAG_tFZTc : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@thin TF_688_Struct<τ_0_0>.Type) -> @owned @callee_guaranteed (@in_guaranteed TF_688_Struct<τ_0_0>) -> @out TF_688_Struct<τ_0_0> // user: %2
  %2 = apply %1<Scalar>(%0) : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@thin TF_688_Struct<τ_0_0>.Type) -> @owned @callee_guaranteed (@in_guaranteed TF_688_Struct<τ_0_0>) -> @out TF_688_Struct<τ_0_0> // user: %3
  %3 = differentiable_function [parameters 0] %2 : $@callee_guaranteed (@in_guaranteed TF_688_Struct<Scalar>) -> @out TF_688_Struct<Scalar> // user: %4
  return %3 : $@differentiable @callee_guaranteed (@in_guaranteed TF_688_Struct<Scalar>) -> @out TF_688_Struct<Scalar> // id: %4
} // end sil function '$s4main6TF_688_9reductionAA0B11_688_StructVyxGAF_A2FXFts14DifferentiableRzlFfA0_'

After: an error about differentiating a fragile function within a serialized function.
This error was introduced in #28582.

$ swiftc -Xllvm -debug-only=differentiation tf-688.swift
// AD__$s4main6TF_688_9reductionAA0B11_688_StructVyxGAF_A2FXFts14DifferentiableRzlFfA0_A2FcAFmcfu___differentiable_curry_thunk_src_0_wrt_0
sil shared [serialized] @AD__$s4main6TF_688_9reductionAA0B11_688_StructVyxGAF_A2FXFts14DifferentiableRzlFfA0_A2FcAFmcfu___differentiable_curry_thunk_src_0_wrt_0 : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@thin TF_688_Struct<τ_0_0>.Type) -> @owned @differentiable @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed TF_688_Struct<τ_0_0>) -> @out TF_688_Struct<τ_0_1> for <τ_0_0, τ_0_0> {
// %0                                             // users: %3, %1
bb0(%0 : $@thin TF_688_Struct<τ_0_0>.Type):
  debug_value %0 : $@thin TF_688_Struct<τ_0_0>.Type, let, name "self", argno 1 // id: %1
  // function_ref implicit closure #2 in implicit closure #1 in default argument 1 of TF_688<A>(_:reduction:)
  %2 = function_ref @$s4main6TF_688_9reductionAA0B11_688_StructVyxGAF_A2FXFts14DifferentiableRzlFfA0_A2FcAFmcfu_A2Fcfu0_ : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed TF_688_Struct<τ_0_0>, @thin TF_688_Struct<τ_0_0>.Type) -> @out TF_688_Struct<τ_0_0> // user: %3
  %3 = partial_apply [callee_guaranteed] %2<τ_0_0>(%0) : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed TF_688_Struct<τ_0_0>, @thin TF_688_Struct<τ_0_0>.Type) -> @out TF_688_Struct<τ_0_0> // user: %4
  %4 = convert_function %3 : $@callee_guaranteed (@in_guaranteed TF_688_Struct<τ_0_0>) -> @out TF_688_Struct<τ_0_0> to $@callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed TF_688_Struct<τ_0_0>) -> @out TF_688_Struct<τ_0_1> for <τ_0_0, τ_0_0> // user: %5
  %5 = differentiable_function [parameters 0] %4 : $@callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed TF_688_Struct<τ_0_0>) -> @out TF_688_Struct<τ_0_1> for <τ_0_0, τ_0_0> // user: %6
  return %5 : $@differentiable @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed TF_688_Struct<τ_0_0>) -> @out TF_688_Struct<τ_0_1> for <τ_0_0, τ_0_0> // id: %6
} // end sil function 'AD__$s4main6TF_688_9reductionAA0B11_688_StructVyxGAF_A2FXFts14DifferentiableRzlFfA0_A2FcAFmcfu___differentiable_curry_thunk_src_0_wrt_0'

[AD] Diagnosing non-differentiability.
[AD] For value:
  %4 = convert_function %3 : $@callee_guaranteed (@in_guaranteed TF_688_Struct<τ_0_0>) -> @out TF_688_Struct<τ_0_0> to $@callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed TF_688_Struct<τ_0_0>) -> @out TF_688_Struct<τ_0_1> for <τ_0_0, τ_0_0> // user: %5
[AD] With invoker:
(differentiation_invoker differentiable_function_inst=(  %5 = differentiable_function [parameters 0] %4 : $@callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed TF_688_Struct<τ_0_0>) -> @out TF_688_Struct<τ_0_1> for <τ_0_0, τ_0_0> // user: %6
))
tf-688.swift:14:95: error: function is not differentiable
  reduction: @differentiable (TF_688_Struct<Scalar>) -> TF_688_Struct<Scalar> = TF_688_Struct.id
                                                                                ~~~~~~~~~~~~~~^~
tf-688.swift:14:95: note: differentiated functions in '@inlinable' functions must be marked '@differentiable' or have a public '@derivative'; this is not possible with a closure, make a top-level function instead
  reduction: @differentiable (TF_688_Struct<Scalar>) -> TF_688_Struct<Scalar> = TF_688_Struct.id
                                                                                              ^
