Skip to content

Commit c2cb48d

Browse files
committed
[TF] Fix TensorArrayProtocol and TensorGroup derived conformances.
Remove `LoadExpr` and `InjectIntoOptionalExpr`: creating instances of these `ImplicitConversionExpr` subclasses causes an assertion failure in `ConstraintGenerator::visitImplicitConversionExpr`. This changes seem necessary after changes to `SanitizeExpr` in 0232cd0.
1 parent 2b8c2cf commit c2cb48d

File tree

2 files changed

+33
-53
lines changed

2 files changed

+33
-53
lines changed

lib/Sema/DerivedConformanceTensorArrayProtocol.cpp

Lines changed: 21 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ deriveBodyTensorArrayProtocol_unpackTensorHandles(
111111
tensorArrayProto, C.Id_tensorHandleCount);
112112

113113
Type intType = C.getIntDecl()->getDeclaredType();
114-
TypeExpr *intTE = TypeExpr::createImplicit(intType, C);
114+
TypeExpr *intTypeExpr = TypeExpr::createImplicit(intType, C);
115115

116116
// Iterate through the `TensorArrayProtocol`-conforming members and call
117117
// `member._unpackTensorHandles(into:)`.
@@ -144,11 +144,8 @@ deriveBodyTensorArrayProtocol_unpackTensorHandles(
144144
// Obtain the method call argument.
145145
auto *addressDRE = new (C) DeclRefExpr(
146146
currAddressDecl, DeclNameLoc(), /*implicit*/ true);
147-
auto *loadExpr = new (C) LoadExpr(addressDRE, baseAddressType);
148-
auto *injectExpr = new (C) InjectIntoOptionalExpr(loadExpr, addressType);
149-
150-
auto *callExpr = CallExpr::createImplicit(
151-
C, memberMethodExpr, {injectExpr}, {C.getIdentifier("into")});
147+
auto *callExpr = CallExpr::createImplicit(C, memberMethodExpr, {addressDRE},
148+
{C.getIdentifier("into")});
152149

153150
// Advance the current address.
154151
DeclName advancedName(C, C.getIdentifier("advanced"),
@@ -170,9 +167,9 @@ deriveBodyTensorArrayProtocol_unpackTensorHandles(
170167
// Cast the tensor handle count to Int.
171168
auto intInitName = DeclName(C, DeclBaseName::createConstructor(),
172169
{Identifier()});
173-
auto *intInitExpr =
174-
new (C) UnresolvedDotExpr(intTE, SourceLoc(), DeclNameRef(intInitName),
175-
DeclNameLoc(), /*Implicit*/ true);
170+
auto *intInitExpr = new (C)
171+
UnresolvedDotExpr(intTypeExpr, SourceLoc(), DeclNameRef(intInitName),
172+
DeclNameLoc(), /*Implicit*/ true);
176173
auto *intInitCallExpr = CallExpr::createImplicit(
177174
C, intInitExpr, {memberCountMRE}, {Identifier()});
178175

@@ -276,16 +273,15 @@ deriveBodyTensorArrayProtocol_tensorHandleCount(AbstractFunctionDecl *funcDecl,
276273

277274
// Concatenate all member `_tensorHandleCount`s.
278275
Type intType = C.getInt32Decl()->getDeclaredType();
279-
TypeExpr *intTE = TypeExpr::createImplicit(intType, C);
276+
TypeExpr *intTypeExpr = TypeExpr::createImplicit(intType, C);
280277
auto plusOpLookup = C.getInt32Decl()->lookupDirect(C.getIdentifier("+"));
281278
assert(plusOpLookup.size() == 1 && "Ambiguous 'Int32.+' operator.");
282279
ValueDecl *plusOpDecl = plusOpLookup.front();
283-
auto plusOpDRE = new (C)
284-
DeclRefExpr(plusOpDecl, DeclNameLoc(), /*Implicit*/ true);
285-
auto plusOpExpr = new (C) DotSyntaxCallExpr(plusOpDRE, SourceLoc(), intTE);
286280
Expr *tensorHandleCountExpr = new (C)
287281
IntegerLiteralExpr("0", SourceLoc(), /*implicit*/ true);
288282
for (auto member : nominal->getStoredProperties()) {
283+
auto plusOpExpr = new (C) MemberRefExpr(
284+
intTypeExpr, SourceLoc(), plusOpDecl, DeclNameLoc(), /*Implicit*/ true);
289285
auto *memberDRE = new (C) MemberRefExpr(
290286
selfDRE, SourceLoc(), member, DeclNameLoc(), /*Implicit*/ true);
291287
auto *memberTensorHandleCountExpr = new (C)
@@ -360,12 +356,11 @@ deriveBodyTensorArrayProtocol_typeList(AbstractFunctionDecl *funcDecl, void *) {
360356
auto plusOpLookup = C.getArrayDecl()->lookupDirect(C.getIdentifier("+"));
361357
assert(plusOpLookup.size() == 1 && "Ambiguous 'Array.+' operator.");
362358
ValueDecl *plusOpDecl = plusOpLookup.front();
363-
auto plusOpDRE = new (C)
364-
DeclRefExpr(plusOpDecl, DeclNameLoc(), /*Implicit*/ true);
365-
auto plusOpExpr = new (C)
366-
DotSyntaxCallExpr(plusOpDRE, SourceLoc(), arrayTypeExpr);
367359
Expr *typeListExpr = ArrayExpr::create(C, SourceLoc(), {}, {}, SourceLoc());
368360
for (auto member : nominal->getStoredProperties()) {
361+
auto *plusOpExpr =
362+
new (C) MemberRefExpr(arrayTypeExpr, SourceLoc(), plusOpDecl,
363+
DeclNameLoc(), /*Implicit*/ true);
369364
auto memberType =
370365
parentDC->mapTypeIntoContext(member->getValueInterfaceType());
371366
auto *memberTypeExpr = TypeExpr::createImplicit(memberType, C);
@@ -436,7 +431,7 @@ deriveBodyTensorArrayProtocol_init(AbstractFunctionDecl *funcDecl, void *) {
436431
C.getUnsafePointerDecl(), Type(), {cTensorHandleType});
437432
auto addressType = BoundGenericType::get(
438433
C.getOptionalDecl(), Type(), {baseAddressType});
439-
auto *addressTE = TypeExpr::createImplicit(addressType, C);
434+
auto *addressTypeExpr = TypeExpr::createImplicit(addressType, C);
440435

441436
// Get references to `self` and parameter declarations.
442437
auto *selfDecl = funcDecl->getImplicitSelfDecl();
@@ -474,7 +469,7 @@ deriveBodyTensorArrayProtocol_init(AbstractFunctionDecl *funcDecl, void *) {
474469
tensorArrayProto, C.Id_tensorHandleCount);
475470

476471
Type intType = C.getIntDecl()->getDeclaredType();
477-
TypeExpr *intTE = TypeExpr::createImplicit(intType, C);
472+
TypeExpr *intTypeExpr = TypeExpr::createImplicit(intType, C);
478473

479474
// Iterate over members and call `self.member = MemberType(_owning:)`.
480475
llvm::SmallVector<ASTNode, 2> thenMemberExprs;
@@ -507,22 +502,15 @@ deriveBodyTensorArrayProtocol_init(AbstractFunctionDecl *funcDecl, void *) {
507502

508503
auto *addressDRE = new (C) DeclRefExpr(
509504
currAddressDecl, DeclNameLoc(), /*implicit*/ true);
510-
auto *loadExpr = new (C) LoadExpr(addressDRE, baseAddressType);
511-
512-
// Initialize the member using its `TensorGroup` constructor.
513-
// Note that, initialization is dependent on the branch of the
514-
// if-statement taken.
515-
auto *thenInitExpr = new (C) InjectIntoOptionalExpr(loadExpr, addressType);
516505
auto *thenInitCallExpr = CallExpr::createImplicit(
517-
C, memberInitExpr, {thenInitExpr}, {C.getIdentifier("_owning")});
506+
C, memberInitExpr, {addressDRE}, {C.getIdentifier("_owning")});
518507

519508
// Create a nil expression with type `UnsafePointer<CTensorHandle>?` for the
520509
// `else` branch.
521510
auto *nilDecl = C.getOptionalNoneDecl();
522-
auto *nilDRE = new (C) DeclRefExpr(
523-
nilDecl, DeclNameLoc(), /*implicit*/ true);
524-
auto *elseInitExpr = new (C) DotSyntaxCallExpr(
525-
nilDRE, SourceLoc(), addressTE);
511+
auto *elseInitExpr =
512+
new (C) MemberRefExpr(addressTypeExpr, SourceLoc(), nilDecl,
513+
DeclNameLoc(), /*Implicit*/ true);
526514
auto *elseInitCallExpr = CallExpr::createImplicit(
527515
C, memberInitExpr, {elseInitExpr}, {C.getIdentifier("_owning")});
528516

@@ -558,9 +546,9 @@ deriveBodyTensorArrayProtocol_init(AbstractFunctionDecl *funcDecl, void *) {
558546
// Cast the tensor handle count to Int.
559547
auto intInitName = DeclName(C, DeclBaseName::createConstructor(),
560548
{Identifier()});
561-
auto *intInitExpr =
562-
new (C) UnresolvedDotExpr(intTE, SourceLoc(), DeclNameRef(intInitName),
563-
DeclNameLoc(), /*Implicit*/ true);
549+
auto *intInitExpr = new (C)
550+
UnresolvedDotExpr(intTypeExpr, SourceLoc(), DeclNameRef(intInitName),
551+
DeclNameLoc(), /*Implicit*/ true);
564552
auto *intInitCallExpr = CallExpr::createImplicit(
565553
C, intInitExpr, {memberCountMRE}, {Identifier()});
566554

lib/Sema/DerivedConformanceTensorGroup.cpp

Lines changed: 12 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -80,12 +80,11 @@ deriveBodyTensorGroup_typeList(AbstractFunctionDecl *funcDecl, void *) {
8080
auto plusOpLookup = C.getArrayDecl()->lookupDirect(C.getIdentifier("+"));
8181
assert(plusOpLookup.size() == 1 && "Ambiguous 'Array.+' operator.");
8282
ValueDecl *plusOpDecl = plusOpLookup.front();
83-
auto plusOpDRE = new (C)
84-
DeclRefExpr(plusOpDecl, DeclNameLoc(), /*Implicit*/ true);
85-
auto plusOpExpr = new (C)
86-
DotSyntaxCallExpr(plusOpDRE, SourceLoc(), arrayTypeExpr);
8783
Expr *typeListExpr = ArrayExpr::create(C, SourceLoc(), {}, {}, SourceLoc());
8884
for (auto member : nominal->getStoredProperties()) {
85+
auto plusOpExpr =
86+
new (C) MemberRefExpr(arrayTypeExpr, SourceLoc(), plusOpDecl,
87+
DeclNameLoc(), /*Implicit*/ true);
8988
auto memberType =
9089
parentDC->mapTypeIntoContext(member->getValueInterfaceType());
9190
auto *memberTypeExpr = TypeExpr::createImplicit(memberType, C);
@@ -154,7 +153,7 @@ deriveBodyTensorGroup_init(AbstractFunctionDecl *funcDecl, void *) {
154153
C.getUnsafePointerDecl(), Type(), {cTensorHandleType});
155154
auto addressType = BoundGenericType::get(
156155
C.getOptionalDecl(), Type(), {baseAddressType});
157-
auto *addressTE = TypeExpr::createImplicit(addressType, C);
156+
auto *addressTypeExpr = TypeExpr::createImplicit(addressType, C);
158157

159158
// Get references to `self` and parameter declarations.
160159
auto *selfDecl = funcDecl->getImplicitSelfDecl();
@@ -192,7 +191,7 @@ deriveBodyTensorGroup_init(AbstractFunctionDecl *funcDecl, void *) {
192191
tensorArrayProto, C.Id_tensorHandleCount);
193192

194193
Type intType = C.getIntDecl()->getDeclaredType();
195-
TypeExpr *intTE = TypeExpr::createImplicit(intType, C);
194+
TypeExpr *intTypeExpr = TypeExpr::createImplicit(intType, C);
196195

197196
// Iterate through the `TensorGroup`-conforming members and call
198197
// `self.member = MemberType(_owning:)`.
@@ -226,22 +225,15 @@ deriveBodyTensorGroup_init(AbstractFunctionDecl *funcDecl, void *) {
226225

227226
auto *addressDRE = new (C) DeclRefExpr(
228227
currAddressDecl, DeclNameLoc(), /*implicit*/ true);
229-
auto *loadExpr = new (C) LoadExpr(addressDRE, baseAddressType);
230-
231-
// Initialize the member using its `TensorGroup` constructor.
232-
// Note that, initialization is dependent on the branch of the
233-
// if-statement taken.
234-
auto *thenInitExpr = new (C) InjectIntoOptionalExpr(loadExpr, addressType);
235228
auto *thenInitCallExpr = CallExpr::createImplicit(
236-
C, memberInitExpr, {thenInitExpr}, {C.getIdentifier("_owning")});
229+
C, memberInitExpr, {addressDRE}, {C.getIdentifier("_owning")});
237230

238231
// Create a nil expression with type `UnsafePointer<CTensorHandle>?` for the
239232
// `else` branch.
240233
auto *nilDecl = C.getOptionalNoneDecl();
241-
auto *nilDRE = new (C) DeclRefExpr(
242-
nilDecl, DeclNameLoc(), /*implicit*/ true);
243-
auto *elseInitExpr = new (C) DotSyntaxCallExpr(
244-
nilDRE, SourceLoc(), addressTE);
234+
auto *elseInitExpr =
235+
new (C) MemberRefExpr(addressTypeExpr, SourceLoc(), nilDecl,
236+
DeclNameLoc(), /*Implicit*/ true);
245237
auto *elseInitCallExpr = CallExpr::createImplicit(
246238
C, memberInitExpr, {elseInitExpr}, {C.getIdentifier("_owning")});
247239

@@ -277,9 +269,9 @@ deriveBodyTensorGroup_init(AbstractFunctionDecl *funcDecl, void *) {
277269
// Cast the tensor handle count to Int.
278270
auto intInitName = DeclName(C, DeclBaseName::createConstructor(),
279271
{Identifier()});
280-
auto *intInitExpr =
281-
new (C) UnresolvedDotExpr(intTE, SourceLoc(), DeclNameRef(intInitName),
282-
DeclNameLoc(), /*Implicit*/ true);
272+
auto *intInitExpr = new (C)
273+
UnresolvedDotExpr(intTypeExpr, SourceLoc(), DeclNameRef(intInitName),
274+
DeclNameLoc(), /*Implicit*/ true);
283275
auto *intInitCallExpr = CallExpr::createImplicit(
284276
C, intInitExpr, {memberCountMRE}, {Identifier()});
285277

0 commit comments

Comments
 (0)