diff --git a/lib/SILOptimizer/Differentiation/PullbackEmitter.cpp b/lib/SILOptimizer/Differentiation/PullbackEmitter.cpp index 32207e38f9590..098709c3e541a 100644 --- a/lib/SILOptimizer/Differentiation/PullbackEmitter.cpp +++ b/lib/SILOptimizer/Differentiation/PullbackEmitter.cpp @@ -1331,6 +1331,7 @@ void PullbackEmitter::visitSILInstruction(SILInstruction *inst) { AllocStackInst * PullbackEmitter::getArrayAdjointElementBuffer(SILValue arrayAdjoint, int eltIndex, SILLocation loc) { + auto &ctx = builder.getASTContext(); auto arrayTanType = cast(arrayAdjoint->getType().getASTType()); auto arrayType = arrayTanType->getParent()->castTo(); auto eltTanType = arrayType->getGenericArgs().front()->getCanonicalType(); @@ -1340,7 +1341,19 @@ PullbackEmitter::getArrayAdjointElementBuffer(SILValue arrayAdjoint, auto *arrayTanStructDecl = arrayTanType->getStructOrBoundGenericStruct(); auto subscriptLookup = arrayTanStructDecl->lookupDirect(DeclBaseName::createSubscript()); - auto *subscriptDecl = cast(subscriptLookup.front()); + SubscriptDecl *subscriptDecl = nullptr; + for (auto *candidate : subscriptLookup) { + auto candidateModule = candidate->getModuleContext(); + if (candidateModule->getName() == ctx.Id_Differentiation || + candidateModule->isStdlibModule()) { + assert(!subscriptDecl && "Multiple `Array.TangentVector.subscript`s"); + subscriptDecl = cast(candidate); +#ifdef NDEBUG + break; +#endif + } + } + assert(subscriptDecl && "No `Array.TangentVector.subscript`"); auto *subscriptGetterDecl = subscriptDecl->getAccessor(AccessorKind::Get); assert(subscriptGetterDecl && "No `Array.TangentVector.subscript` getter"); SILOptFunctionBuilder fb(getContext().getTransform()); @@ -1352,7 +1365,6 @@ PullbackEmitter::getArrayAdjointElementBuffer(SILValue arrayAdjoint, subscriptGetterFn->getLoweredFunctionType()->getSubstGenericSignature(); // Apply `Array.TangentVector.subscript.getter` to get array element adjoint // buffer. - auto &ctx = builder.getASTContext(); // %index_literal = integer_literal $Builtin.IntXX, auto builtinIntType = SILType::getPrimitiveObjectType(ctx.getIntDecl() diff --git a/test/AutoDiff/stdlib/array.swift b/test/AutoDiff/stdlib/array.swift index 0983060cc7ba5..06807cc9c32ee 100644 --- a/test/AutoDiff/stdlib/array.swift +++ b/test/AutoDiff/stdlib/array.swift @@ -8,6 +8,15 @@ var ArrayAutoDiffTests = TestSuite("ArrayAutoDiff") typealias FloatArrayTan = Array.TangentVector +extension Array.DifferentiableView { + /// A subscript that always fatal errors. + /// + /// The differentiation transform should never emit calls to this. + subscript(alwaysFatalError: Int) -> Element { + fatalError("wrong subscript") + } +} + ArrayAutoDiffTests.test("ArrayIdentity") { func arrayIdentity(_ x: [Float]) -> [Float] { return x