Skip to content

[cxx-interop] Import nullability of templated function parameters correctly #82161

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion lib/AST/ClangTypeConverter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -881,8 +881,15 @@ ClangTypeConverter::convertClangDecl(Type type, const clang::Decl *clangDecl) {

if (auto clangTypeDecl = dyn_cast<clang::TypeDecl>(clangDecl)) {
auto qualType = ctx.getTypeDeclType(clangTypeDecl);
if (type->isForeignReferenceType())
if (type->isForeignReferenceType()) {
qualType = ctx.getPointerType(qualType);
auto nonNullAttr = new (ctx) clang::TypeNonNullAttr(
ctx,
clang::AttributeCommonInfo(
clang::SourceRange(), clang::AttributeCommonInfo::AT_TypeNonNull,
clang::AttributeCommonInfo::Form::Implicit()));
qualType = ctx.getAttributedType(nonNullAttr, qualType, qualType);
}

return qualType.getUnqualifiedType();
}
Expand Down
20 changes: 18 additions & 2 deletions lib/ClangImporter/ImportType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2458,6 +2458,13 @@ ClangImporter::Implementation::importParameterType(
auto paramTy = desugarIfElaborated(param->getType());
paramTy = desugarIfBoundsAttributed(paramTy);

// If this type has a _Nullable/_Nonnull attribute, drop it, since we already
// have that information in optionalityOfParam.
if (auto attributedTy = dyn_cast<clang::AttributedType>(paramTy)) {
if (attributedTy->getImmediateNullability())
clang::AttributedType::stripOuterNullability(paramTy);
}

ImportTypeKind importKind = paramIsCompletionHandler
? ImportTypeKind::CompletionHandlerParameter
: ImportTypeKind::Parameter;
Expand Down Expand Up @@ -2486,6 +2493,17 @@ ClangImporter::Implementation::importParameterType(
pointerKind));
return std::nullopt;
}
switch (optionalityOfParam) {
case OTK_Optional:
swiftParamTy = OptionalType::get(swiftParamTy);
break;
case OTK_ImplicitlyUnwrappedOptional:
swiftParamTy = OptionalType::get(swiftParamTy);
isParamTypeImplicitlyUnwrapped = true;
break;
case OTK_None:
break;
}
} else if (isa<clang::ReferenceType>(paramTy) &&
isa<clang::TemplateTypeParmType>(paramTy->getPointeeType())) {
// We don't support universal reference, bail.
Expand Down Expand Up @@ -2734,8 +2752,6 @@ ParameterList *ClangImporter::Implementation::importFunctionParameterList(
}

bool knownNonNull = !nonNullArgs.empty() && nonNullArgs[index];
// Specialized templates need to match the args/result exactly.
knownNonNull |= clangDecl->isFunctionTemplateSpecialization();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With this removed, you might have fixed some of the crashes for the commented-out cases in template-instantiation-irgen.swift, e.g.,

// FIXME: this crashes because this round-trips to UnsafeMutablePointer<FRT?>
// func takesMutPtrToFRT(x: UnsafeMutablePointer<FRT>) { takesValue(x) }

Could you see if those are fixed by your patch?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah great point, thanks! Let me uncomment some of these.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(related: rdar://142850017)


// Check nullability of the parameter.
OptionalTypeKind optionalityOfParam =
Expand Down
4 changes: 2 additions & 2 deletions stdlib/public/Cxx/CxxSpan.swift
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ public protocol CxxSpan<Element> {
associatedtype Size: BinaryInteger

init()
init(_ unsafePointer : UnsafePointer<Element>, _ count: Size)
init(_ unsafePointer: UnsafePointer<Element>!, _ count: Size)

func size() -> Size
func __dataUnsafe() -> UnsafePointer<Element>?
Expand Down Expand Up @@ -136,7 +136,7 @@ public protocol CxxMutableSpan<Element> {
associatedtype Size: BinaryInteger

init()
init(_ unsafeMutablePointer : UnsafeMutablePointer<Element>, _ count: Size)
init(_ unsafeMutablePointer: UnsafeMutablePointer<Element>!, _ count: Size)

func size() -> Size
func __dataUnsafe() -> UnsafeMutablePointer<Element>?
Expand Down
4 changes: 3 additions & 1 deletion stdlib/public/Cxx/cxxshim/libcxxshim.h
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
template <class From, class To>
To __swift_interopStaticCast(From from) { return static_cast<To>(from); }
To _Nonnull __swift_interopStaticCast(From _Nonnull from) {
return static_cast<To>(from);
}
4 changes: 1 addition & 3 deletions test/Interop/Cxx/stdlib/use-std-span-typechecker.swift
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,5 @@ arr.withUnsafeBufferPointer { ubpointer in
arr.withUnsafeBufferPointer { ubpointer in
// FIXME: this crashes the compiler once we import span's templated ctors as Swift generics.
let _ = ConstSpanOfInt(ubpointer.baseAddress, ubpointer.count)
// expected-error@-1 {{value of optional type 'UnsafePointer<Int32>?' must be unwrapped to a value of type 'UnsafePointer<Int32>'}}
// expected-note@-2 {{coalesce using '??' to provide a default when the optional value contains 'nil'}}
// expected-note@-3 {{force-unwrap using '!' to abort execution if the optional value contains 'nil'}}
// expected-warning@-1 {{'init(_:_:)' is deprecated: use 'init(_:)' instead.}}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
// RUN: %target-swift-frontend -plugin-path %swift-plugin-dir -I %t/Inputs -cxx-interoperability-mode=default -enable-experimental-feature SafeInteropWrappers %t/template.swift -dump-macro-expansions -emit-ir -o %t/out -verify
// RUN: %target-swift-ide-test -plugin-path %swift-plugin-dir -I %t/Inputs -cxx-interoperability-mode=default -enable-experimental-feature SafeInteropWrappers -print-module -module-to-print=Template -source-filename=x | %FileCheck %s

// CHECK: func cb_template<T>(_ p: UnsafePointer<T>, _ size: Int{{.*}}) -> UnsafePointer<T>
// CHECK: func eb_template<T>(_ p: UnsafePointer<T>, _ end: UnsafePointer<T>) -> UnsafePointer<T>
// CHECK: func s_template<T>(_ p: UnsafePointer<T>) -> UnsafePointer<T>
// CHECK: func ui_template<T>(_ p: UnsafePointer<T>) -> UnsafePointer<T>
// CHECK: func cb_template<T>(_ p: UnsafePointer<T>!, _ size: Int{{.*}}) -> UnsafePointer<T>
// CHECK: func eb_template<T>(_ p: UnsafePointer<T>!, _ end: UnsafePointer<T>!) -> UnsafePointer<T>
// CHECK: func s_template<T>(_ p: UnsafePointer<T>!) -> UnsafePointer<T>
// CHECK: func ui_template<T>(_ p: UnsafePointer<T>!) -> UnsafePointer<T>

//--- Inputs/module.modulemap
module Template {
Expand Down
4 changes: 3 additions & 1 deletion test/Interop/Cxx/templates/Inputs/function-templates.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,9 @@ template <class T> bool constLvalueReferenceToBool(const T &t) { return t; }

template <class T> void forwardingReference(T &&) {}

template <class T> void PointerTemplateParameter(T*){}
template <class T> bool pointerTemplateParameter(T *t) { return t; }
template <class T> bool pointerTemplateParameterNonnull(T *_Nonnull t) { return t; }
template <class T> bool pointerTemplateParameterNullable(T *_Nullable t) { return t; }

template <typename F> void callFunction(F f) { f(); }
template <typename F, typename T> void callFunctionWithParam(F f, T t) { f(t); }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,6 @@
// TODO-CHECK: func defaultedTemplatePointerTypeParam<T>(_ t: UnsafeMutablePointer<T>)
// TODO-CHECK: func defaultedTemplatePointerPointerTypeParam<T>(_ t: UnsafeMutablePointer<OpaquePointer?>!)

// CHECK: func defaultedTemplatePointerTypeParam<T>(_ t: UnsafeMutablePointer<T>)
// CHECK: func defaultedTemplatePointerTypeParam<T>(_ t: UnsafeMutablePointer<T>!)
// We don't support references to dependent types (rdar://89034440).
// CHECK-NOT: defaultedTemplatePointerReferenceTypeParam
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@

// CHECK: func lvalueReference<T>(_ ref: inout T)
// CHECK: func constLvalueReference<T>(_: T)
// CHECK: func PointerTemplateParameter<T>(_: UnsafeMutablePointer<T>)

// CHECK: func pointerTemplateParameter<T>(_ t: UnsafeMutablePointer<T>!) -> Bool
// CHECK: func pointerTemplateParameterNonnull<T>(_ t: UnsafeMutablePointer<T>) -> Bool
// CHECK: func pointerTemplateParameterNullable<T>(_ t: UnsafeMutablePointer<T>?) -> Bool

// CHECK: enum Orbiters {
// CHECK: static func galileo<T>(_: T)
Expand Down
20 changes: 20 additions & 0 deletions test/Interop/Cxx/templates/function-template.swift
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,26 @@ FunctionTemplateTestSuite.test("constLvalueReferenceToBool<T> where T == Bool")
expectFalse(constLvalueReferenceToBool(false))
}

var nilPtr: UnsafeMutablePointer<CInt>? = nil
var nilPtrIOU: UnsafeMutablePointer<CInt>! = nil
var nonNilPtr: UnsafeMutablePointer<CInt> = .init(bitPattern: 123)!

FunctionTemplateTestSuite.test("pointerTemplateParameter<T>") {
expectFalse(pointerTemplateParameter(nilPtr))
expectFalse(pointerTemplateParameter(nilPtrIOU))
expectTrue(pointerTemplateParameter(nonNilPtr))
}

FunctionTemplateTestSuite.test("pointerTemplateParameterNonnull<T>") {
expectTrue(pointerTemplateParameterNonnull(nonNilPtr))
}

FunctionTemplateTestSuite.test("pointerTemplateParameterNullable<T>") {
expectFalse(pointerTemplateParameterNullable(nilPtr))
expectFalse(pointerTemplateParameterNullable(nilPtrIOU))
expectTrue(pointerTemplateParameterNullable(nonNilPtr))
}

// TODO: Generics, Any, and Protocols should be tested here but need to be
// better supported in ClangTypeConverter first.

Expand Down
18 changes: 7 additions & 11 deletions test/Interop/Cxx/templates/template-instantiation-irgen.swift
Original file line number Diff line number Diff line change
Expand Up @@ -46,21 +46,20 @@ func takesPtrToStruct(x: UnsafePointer<PlainStruct>) { takesValue(x) }
func takesPtrToClass(x: UnsafePointer<CxxClass>) { takesValue(x) }
// CHECK: define {{.*}} void @{{.*}}takesPtrToClass{{.*}}

// FIXME: this crashes because this round-trips to UnsafePointer<FRT?>
// func takesPtrToFRT(x: UnsafePointer<FRT>) { takesValue(x) }
func takesPtrToFRT(x: UnsafePointer<FRT>) { takesValue(x) }
// CHECK: define {{.*}} void @{{.*}}takesPtrToFRT{{.*}}

func takesMutPtrToStruct(x: UnsafeMutablePointer<PlainStruct>) { takesValue(x) }
// CHECK: define {{.*}} void @{{.*}}takesMutPtrToStruct{{.*}}

func takesMutPtrToClass(x: UnsafeMutablePointer<CxxClass>) { takesValue(x) }
// CHECK: define {{.*}} void @{{.*}}takesMutPtrToClass{{.*}}

// FIXME: this crashes because this round-trips to UnsafeMutablePointer<FRT?>
// func takesMutPtrToFRT(x: UnsafeMutablePointer<FRT>) { takesValue(x) }
func takesMutPtrToFRT(x: UnsafeMutablePointer<FRT>) { takesValue(x) }
// CHECK: define {{.*}} void @{{.*}}takesMutPtrToFRT{{.*}}

func takesCPtr() {
// FIXME: optional pointers are not yet supported but they should be; this crashes
// takesValue(intPtr)
takesValue(intPtr)

// It's fine if we dereference it, though
takesValue(intPtr!)
Expand Down Expand Up @@ -92,9 +91,6 @@ func takesSwiftClosureTakingCxxClass() { takesValue({(x: CxxClass) in takesValue
func takesTakesCxxClass() { takesValue(takesCxxClass) }

func takesSwiftClosureReturningFRT() { takesValue({() -> FRT in FRT()}) }
func takesSwiftClosureTakingFRT() { takesValue({(x: FRT) in takesValue(x)}) }

// FIXME: this crashes due to pointer round-tripping
// func takesSwiftClosureTakingFRT() { takesValue({(x: FRT) in takesValue(x)}) }

// FIXME: this crashes due to pointer round-tripping
// func takesTakesFRT() { takesValue(takesFRT) }
func takesTakesFRT() { takesValue(takesFRT) }