diff --git a/lib/SILOptimizer/Differentiation/LinearMapInfo.cpp b/lib/SILOptimizer/Differentiation/LinearMapInfo.cpp index 9797c3e982381..3585697e96875 100644 --- a/lib/SILOptimizer/Differentiation/LinearMapInfo.cpp +++ b/lib/SILOptimizer/Differentiation/LinearMapInfo.cpp @@ -142,10 +142,12 @@ void LinearMapInfo::populateBranchingTraceDecl(SILBasicBlock *originalBB, heapAllocatedContext = true; decl->setInterfaceType(astCtx.TheRawPointerType); } else { // Otherwise the payload is the linear map tuple. - auto linearMapStructTy = getLinearMapTupleType(predBB)->getCanonicalType(); + auto *linearMapStructTy = getLinearMapTupleType(predBB); + assert(linearMapStructTy && "must have linear map struct type for predecessor BB"); + auto canLinearMapStructTy = linearMapStructTy->getCanonicalType(); decl->setInterfaceType( - linearMapStructTy->hasArchetype() - ? linearMapStructTy->mapTypeOutOfContext() : linearMapStructTy); + canLinearMapStructTy->hasArchetype() + ? canLinearMapStructTy->mapTypeOutOfContext() : canLinearMapStructTy); } // Create enum element and enum case declarations. auto *paramList = ParameterList::create(astCtx, {decl}); @@ -331,10 +333,28 @@ void LinearMapInfo::generateDifferentiationDataStructures( } // Add linear map fields to the linear map tuples. - for (auto &origBB : *original) { + // + // Now we need to be very careful as we're having a very subtle + // chicken-and-egg problem. We need lowered branch trace enum type for the + // linear map typle type. However branch trace enum type lowering depends on + // the lowering of its elements (at very least, the type classification of + // being trivial / non-trivial). As the lowering is cached we need to ensure + // we compute lowered type for the branch trace enum when the corresponding + // EnumDecl is fully complete: we cannot add more entries without causing some + // very subtle issues later on. However, the elements of the enum are linear + // map tuples of predecessors, that correspondingly may contain branch trace + // enums of corresponding predecessor BBs. + // + // Traverse all BBs in reverse post-order traversal order to ensure we process + // each BB before its predecessors. + llvm::ReversePostOrderTraversal RPOT(original); + for (auto Iter = RPOT.begin(), E = RPOT.end(); Iter != E; ++Iter) { + auto *origBB = *Iter; SmallVector linearTupleTypes; - if (!origBB.isEntry()) { - CanType traceEnumType = getBranchingTraceEnumLoweredType(&origBB).getASTType(); + if (!origBB->isEntry()) { + populateBranchingTraceDecl(origBB, loopInfo); + + CanType traceEnumType = getBranchingTraceEnumLoweredType(origBB).getASTType(); linearTupleTypes.emplace_back(traceEnumType, astCtx.getIdentifier(traceEnumFieldName)); } @@ -343,7 +363,7 @@ void LinearMapInfo::generateDifferentiationDataStructures( // Do not add linear map fields for semantic member accessors, which have // special-case pullback generation. Linear map tuples should be empty. } else { - for (auto &inst : origBB) { + for (auto &inst : *origBB) { if (auto *ai = dyn_cast(&inst)) { // Add linear map field to struct for active `apply` instructions. // Skip array literal intrinsic applications since array literal @@ -363,12 +383,9 @@ void LinearMapInfo::generateDifferentiationDataStructures( } } - linearMapTuples.insert({&origBB, TupleType::get(linearTupleTypes, astCtx)}); + linearMapTuples.insert({origBB, TupleType::get(linearTupleTypes, astCtx)}); } - for (auto &origBB : *original) - populateBranchingTraceDecl(&origBB, loopInfo); - // Print generated linear map structs and branching trace enums. // These declarations do not show up with `-emit-sil` because they are // implicit. Instead, use `-Xllvm -debug-only=differentiation` to test diff --git a/test/AutoDiff/SILOptimizer/differentiation_control_flow_sil.swift b/test/AutoDiff/SILOptimizer/differentiation_control_flow_sil.swift index 31c137a502d49..3884c11420e62 100644 --- a/test/AutoDiff/SILOptimizer/differentiation_control_flow_sil.swift +++ b/test/AutoDiff/SILOptimizer/differentiation_control_flow_sil.swift @@ -56,7 +56,7 @@ func cond(_ x: Float) -> Float { // CHECK-SIL: [[BB3_PRED_PRED2:%.*]] = enum $_AD__cond_bb3__Pred__src_0_wrt_0, #_AD__cond_bb3__Pred__src_0_wrt_0.bb2!enumelt, [[BB2_PB_STRUCT]] // CHECK-SIL: br bb3({{.*}} : $Float, [[BB3_PRED_PRED2]] : $_AD__cond_bb3__Pred__src_0_wrt_0) -// CHECK-SIL: bb3([[ORIG_RES:%.*]] : $Float, [[BB3_PRED_ARG:%.*]] : $_AD__cond_bb3__Pred__src_0_wrt_0) +// CHECK-SIL: bb3([[ORIG_RES:%.*]] : $Float, [[BB3_PRED_ARG:%.*]] : @owned $_AD__cond_bb3__Pred__src_0_wrt_0) // CHECK-SIL: [[PULLBACK_REF:%.*]] = function_ref @condTJpSpSr // CHECK-SIL: [[PB:%.*]] = partial_apply [callee_guaranteed] [[PULLBACK_REF]]([[BB3_PRED_ARG]]) // CHECK-SIL: [[VJP_RESULT:%.*]] = tuple ([[ORIG_RES]] : $Float, [[PB]] : $@callee_guaranteed (Float) -> Float) @@ -64,7 +64,7 @@ func cond(_ x: Float) -> Float { // CHECK-SIL-LABEL: sil private [ossa] @condTJpSpSr : $@convention(thin) (Float, @owned _AD__cond_bb3__Pred__src_0_wrt_0) -> Float { -// CHECK-SIL: bb0([[SEED:%.*]] : $Float, [[BB3_PRED:%.*]] : $_AD__cond_bb3__Pred__src_0_wrt_0): +// CHECK-SIL: bb0([[SEED:%.*]] : $Float, [[BB3_PRED:%.*]] : @owned $_AD__cond_bb3__Pred__src_0_wrt_0): // CHECK-SIL: switch_enum [[BB3_PRED]] : $_AD__cond_bb3__Pred__src_0_wrt_0, case #_AD__cond_bb3__Pred__src_0_wrt_0.bb2!enumelt: bb1, case #_AD__cond_bb3__Pred__src_0_wrt_0.bb1!enumelt: bb3 // CHECK-SIL: bb1([[BB3_PRED2_TRAMP_PB_STRUCT:%.*]] : @owned $(predecessor: _AD__cond_bb2__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> (Float, Float))): @@ -132,6 +132,39 @@ func loop_generic(_ x: T) -> T { return result } +@differentiable(reverse) +@_silgen_name("loop_context") +func loop_context(x: Float) -> Float { + let y = x + 1 + for _ in 0 ..< 1 {} + return y +} + +// CHECK-DATA-STRUCTURES-LABEL: Generated linear map tuples and branching trace enums for @loop_context: +// CHECK-DATA-STRUCTURES: (_: (Float) -> Float) +// CHECK-DATA-STRUCTURES: (predecessor: _AD__loop_context_bb1__Pred__src_0_wrt_0) +// CHECK-DATA-STRUCTURES: (predecessor: _AD__loop_context_bb2__Pred__src_0_wrt_0) +// CHECK-DATA-STRUCTURES: (predecessor: _AD__loop_context_bb3__Pred__src_0_wrt_0) +// CHECK-DATA-STRUCTURES: enum _AD__loop_context_bb0__Pred__src_0_wrt_0 { +// CHECK-DATA-STRUCTURES: } +// CHECK-DATA-STRUCTURES: enum _AD__loop_context_bb1__Pred__src_0_wrt_0 { +// CHECK-DATA-STRUCTURES: case bb2(Builtin.RawPointer) +// CHECK-DATA-STRUCTURES: case bb0((_: (Float) -> Float)) +// CHECK-DATA-STRUCTURES: } +// CHECK-DATA-STRUCTURES: enum _AD__loop_context_bb2__Pred__src_0_wrt_0 { +// CHECK-DATA-STRUCTURES: case bb1(Builtin.RawPointer) +// CHECK-DATA-STRUCTURES: } +// CHECK-DATA-STRUCTURES: enum _AD__loop_context_bb3__Pred__src_0_wrt_0 { +// CHECK-DATA-STRUCTURES: case bb1(Builtin.RawPointer) +// CHECK-DATA-STRUCTURES: } + +// CHECK-SIL-LABEL: sil private [ossa] @loop_contextTJpSpSr : $@convention(thin) (Float, @guaranteed Builtin.NativeObject) -> Float { +// CHECK-SIL: bb1([[LOOP_CONTEXT:%.*]] : $Builtin.RawPointer): +// CHECK-SIL: [[PB_TUPLE_ADDR:%.*]] = pointer_to_address [[LOOP_CONTEXT]] : $Builtin.RawPointer to [strict] $*(predecessor: _AD__loop_context_bb1__Pred__src_0_wrt_0) +// CHECK-SIL: [[PB_TUPLE_CPY:%.*]] = load [copy] [[PB_TUPLE_ADDR]] : $*(predecessor: _AD__loop_context_bb1__Pred__src_0_wrt_0) +// CHECK-SIL: br bb3({{.*}} : $Float, {{.*}} : $Float, [[PB_TUPLE_CPY]] : $(predecessor: _AD__loop_context_bb1__Pred__src_0_wrt_0)) +// CHECK-SIL: bb3({{.*}} : $Float, {{.*}} : $Float, {{.*}} : @owned $(predecessor: _AD__loop_context_bb1__Pred__src_0_wrt_0)): + // Test `switch_enum`. enum Enum { @@ -164,7 +197,7 @@ func enum_notactive(_ e: Enum, _ x: Float) -> Float { // CHECK-SIL: [[BB3_PRED_PRED2:%.*]] = enum $_AD__enum_notactive_bb3__Pred__src_0_wrt_1, #_AD__enum_notactive_bb3__Pred__src_0_wrt_1.bb2!enumelt, [[BB2_PB_STRUCT]] : $(predecessor: _AD__enum_notactive_bb2__Pred__src_0_wrt_1, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> Float) // CHECK-SIL: br bb3({{.*}} : $Float, [[BB3_PRED_PRED2]] : $_AD__enum_notactive_bb3__Pred__src_0_wrt_1) -// CHECK-SIL: bb3([[ORIG_RES:%.*]] : $Float, [[BB3_PRED_ARG:%.*]] : $_AD__enum_notactive_bb3__Pred__src_0_wrt_1) +// CHECK-SIL: bb3([[ORIG_RES:%.*]] : $Float, [[BB3_PRED_ARG:%.*]] : @owned $_AD__enum_notactive_bb3__Pred__src_0_wrt_1) // CHECK-SIL: [[PULLBACK_REF:%.*]] = function_ref @enum_notactiveTJpUSpSr // CHECK-SIL: [[PB:%.*]] = partial_apply [callee_guaranteed] [[PULLBACK_REF]]([[BB3_PRED_ARG]]) // CHECK-SIL: [[VJP_RESULT:%.*]] = tuple ([[ORIG_RES]] : $Float, [[PB]] : $@callee_guaranteed (Float) -> Float)