diff --git a/tools/swift-plugin-server/Sources/CSwiftPluginServer/PluginServer.cpp b/tools/swift-plugin-server/Sources/CSwiftPluginServer/PluginServer.cpp index 707b40a5d182c..b84bdeaa173f7 100644 --- a/tools/swift-plugin-server/Sources/CSwiftPluginServer/PluginServer.cpp +++ b/tools/swift-plugin-server/Sources/CSwiftPluginServer/PluginServer.cpp @@ -13,11 +13,17 @@ #include "PluginServer.h" #include "swift/ABI/MetadataValues.h" #include "swift/Demangling/Demangle.h" +#include "llvm/Support/DynamicLibrary.h" +#if defined(_WIN32) +#include +#elif defined(__unix__) || defined(__APPLE__) #include +#include +#endif + #include #include -#include using namespace swift; @@ -32,6 +38,46 @@ struct ConnectionHandle { } // namespace const void *PluginServer_createConnection(const char **errorMessage) { +#if defined(_WIN32) + struct unique_fd { + unique_fd(int fd) : fd_(fd) {} + unique_fd(const unique_fd &) = delete; + unique_fd &operator=(const unique_fd &) = delete; + unique_fd &operator=(unique_fd &&) = delete; + unique_fd(unique_fd &&uf) : fd_(uf.fd_) { uf.fd_ = -1; } + ~unique_fd() { if (fd_ > 0) _close(fd_); } + + int operator*() const { return fd_; } + int release() { int fd = fd_; fd_ = -1; return fd; } + + private: + int fd_; + }; + + unique_fd ifd{_dup(_fileno(stdin))}; + if (*ifd < 0) { + *errorMessage = _strerror(nullptr); + return nullptr; + } + + if (_close(_fileno(stdin)) < 0) { + *errorMessage = _strerror(nullptr); + return nullptr; + } + + unique_fd ofd{_dup(_fileno(stdout))}; + if (*ofd < 0) { + *errorMessage = _strerror(nullptr); + return nullptr; + } + + if (_dup2(_fileno(stderr), _fileno(stdout)) < 0) { + *errorMessage = _strerror(nullptr); + return nullptr; + } + + return new ConnectionHandle(ifd.release(), ofd.release()); +#else // Duplicate the `stdin` file descriptor, which we will then use for // receiving messages from the plugin host. auto inputFD = dup(STDIN_FILENO); @@ -65,37 +111,48 @@ const void *PluginServer_createConnection(const char **errorMessage) { // Open a message channel for communicating with the plugin host. return new ConnectionHandle(inputFD, outputFD); +#endif } -void PluginServer_destroyConnection(const void *connHandle) { - const auto *conn = static_cast(connHandle); - delete conn; +void PluginServer_destroyConnection(const void *server) { + delete static_cast(server); } -long PluginServer_read(const void *connHandle, void *data, - unsigned long nbyte) { - const auto *conn = static_cast(connHandle); - return ::read(conn->inputFD, data, nbyte); +size_t PluginServer_read(const void *server, void *data, size_t nbyte) { + const auto *connection = static_cast(server); +#if defined(_WIN32) + return _read(connection->inputFD, data, nbyte); +#else + return ::read(connection->inputFD, data, nbyte); +#endif } -long PluginServer_write(const void *connHandle, const void *data, - unsigned long nbyte) { - const auto *conn = static_cast(connHandle); - return ::write(conn->outputFD, data, nbyte); +size_t PluginServer_write(const void *server, const void *data, size_t nbyte) { + const auto *connection = static_cast(server); +#if defined(_WIN32) + return _write(connection->outputFD, data, nbyte); +#else + return ::write(connection->outputFD, data, nbyte); +#endif } -void *PluginServer_dlopen(const char *filename, const char **errorMessage) { - auto *handle = ::dlopen(filename, RTLD_LAZY | RTLD_LOCAL); - if (!handle) { - *errorMessage = dlerror(); - } - return handle; +void *PluginServer_load(const char *plugin, const char **errorMessage) { + // Use a static allocation for the error as the client will not release the + // string. POSIX 2008 (IEEE-1003.1-2008) specifies that it is implementation + // defined if `dlerror` is re-entrant. Take advantage of that and make it + // thread-unsafe. This ensures that the string outlives the call permitting + // the client to duplicate it. + static std::string error; + auto library = llvm::sys::DynamicLibrary::getLibrary(plugin, &error); + if (library.isValid()) + return library.getOSSpecificHandle(); + *errorMessage = error.c_str(); + return nullptr; } const void *PluginServer_lookupMacroTypeMetadataByExternalName( const char *moduleName, const char *typeName, void *libraryHint, const char **errorMessage) { - // Look up the type metadata accessor as a struct, enum, or class. const Demangle::Node::Kind typeKinds[] = { Demangle::Node::Kind::Structure, @@ -108,8 +165,12 @@ const void *PluginServer_lookupMacroTypeMetadataByExternalName( auto symbolName = mangledNameForTypeMetadataAccessor(moduleName, typeName, typeKind); - auto *handle = libraryHint ? libraryHint : RTLD_DEFAULT; - accessorAddr = ::dlsym(handle, symbolName.c_str()); +#if !defined(_WIN32) + if (libraryHint == nullptr) + libraryHint = RTLD_DEFAULT; +#endif + accessorAddr = llvm::sys::DynamicLibrary{libraryHint} + .getAddressOfSymbol(symbolName.c_str()); if (accessorAddr) break; } diff --git a/tools/swift-plugin-server/Sources/CSwiftPluginServer/include/PluginServer.h b/tools/swift-plugin-server/Sources/CSwiftPluginServer/include/PluginServer.h index dba78fedf8139..92e62c3aa7bc4 100644 --- a/tools/swift-plugin-server/Sources/CSwiftPluginServer/include/PluginServer.h +++ b/tools/swift-plugin-server/Sources/CSwiftPluginServer/include/PluginServer.h @@ -13,6 +13,9 @@ #ifndef SWIFT_PLUGINSERVER_PLUGINSERVER_H #define SWIFT_PLUGINSERVER_PLUGINSERVER_H +#include +#include + #ifdef __cplusplus extern "C" { #endif @@ -29,18 +32,17 @@ const void *PluginServer_createConnection(const char **errorMessage); void PluginServer_destroyConnection(const void *connHandle); /// Read bytes from the IPC communication handle. -long PluginServer_read(const void *connHandle, void *data, unsigned long nbyte); +size_t PluginServer_read(const void *connHandle, void *data, size_t nbyte); /// Write bytes to the IPC communication handle. -long PluginServer_write(const void *connHandle, const void *data, - unsigned long nbyte); +size_t PluginServer_write(const void *connHandle, const void *data, size_t nbyte); //===----------------------------------------------------------------------===// // Dynamic link //===----------------------------------------------------------------------===// /// Load a dynamic link library, and return the handle. -void *PluginServer_dlopen(const char *filename, const char **errorMessage); +void *PluginServer_load(const char *filename, const char **errorMessage); /// Resolve a type metadata by a pair of the module name and the type name. /// 'libraryHint' is a diff --git a/tools/swift-plugin-server/Sources/swift-plugin-server/swift-plugin-server.swift b/tools/swift-plugin-server/Sources/swift-plugin-server/swift-plugin-server.swift index e1f8bae8b898a..8c06e6dcd5d5a 100644 --- a/tools/swift-plugin-server/Sources/swift-plugin-server/swift-plugin-server.swift +++ b/tools/swift-plugin-server/Sources/swift-plugin-server/swift-plugin-server.swift @@ -47,7 +47,7 @@ extension SwiftPluginServer: PluginProvider { /// Load a macro implementation from the dynamic link library. func loadPluginLibrary(libraryPath: String, moduleName: String) throws { var errorMessage: UnsafePointer? - guard let dlHandle = PluginServer_dlopen(libraryPath, &errorMessage) else { + guard let dlHandle = PluginServer_load(libraryPath, &errorMessage) else { throw PluginServerError(message: String(cString: errorMessage!)) } loadedLibraryPlugins[moduleName] = dlHandle @@ -172,7 +172,7 @@ final class PluginHostConnection: MessageConnection { var ptr = buffer.baseAddress! while (bytesToWrite > 0) { - let writtenSize = PluginServer_write(handle, ptr, UInt(bytesToWrite)) + let writtenSize = PluginServer_write(handle, ptr, Int(bytesToWrite)) if (writtenSize <= 0) { // error e.g. broken pipe. break @@ -193,7 +193,7 @@ final class PluginHostConnection: MessageConnection { var ptr = buffer.baseAddress! while bytesToRead > 0 { - let readSize = PluginServer_read(handle, ptr, UInt(bytesToRead)) + let readSize = PluginServer_read(handle, ptr, Int(bytesToRead)) if (readSize <= 0) { // 0: EOF (the host closed), -1: Broken pipe (the host crashed?) break;