|
22 | 22 | #include <map>
|
23 | 23 | #include <sstream>
|
24 | 24 |
|
| 25 | +template <typename T> |
| 26 | +static Fortran::semantics::Scope *GetScope( |
| 27 | + Fortran::semantics::SemanticsContext &context, const T &x) { |
| 28 | + std::optional<Fortran::parser::CharBlock> source{GetSource(x)}; |
| 29 | + return source ? &context.FindScope(*source) : nullptr; |
| 30 | +} |
| 31 | + |
25 | 32 | namespace Fortran::semantics {
|
26 | 33 |
|
27 | 34 | template <typename T> class DirectiveAttributeVisitor {
|
@@ -324,11 +331,6 @@ class OmpAttributeVisitor : DirectiveAttributeVisitor<llvm::omp::Directive> {
|
324 | 331 | return true;
|
325 | 332 | }
|
326 | 333 |
|
327 |
| - bool Pre(const parser::SpecificationPart &x) { |
328 |
| - Walk(std::get<std::list<parser::OpenMPDeclarativeConstruct>>(x.t)); |
329 |
| - return true; |
330 |
| - } |
331 |
| - |
332 | 334 | bool Pre(const parser::StmtFunctionStmt &x) {
|
333 | 335 | const auto &parsedExpr{std::get<parser::Scalar<parser::Expr>>(x.t)};
|
334 | 336 | if (const auto *expr{GetExpr(context_, parsedExpr)}) {
|
@@ -375,7 +377,38 @@ class OmpAttributeVisitor : DirectiveAttributeVisitor<llvm::omp::Directive> {
|
375 | 377 | void Post(const parser::OpenMPDeclareSimdConstruct &) { PopContext(); }
|
376 | 378 |
|
377 | 379 | bool Pre(const parser::OpenMPRequiresConstruct &x) {
|
| 380 | + using Flags = WithOmpDeclarative::RequiresFlags; |
| 381 | + using Requires = WithOmpDeclarative::RequiresFlag; |
378 | 382 | PushContext(x.source, llvm::omp::Directive::OMPD_requires);
|
| 383 | + |
| 384 | + // Gather information from the clauses. |
| 385 | + Flags flags; |
| 386 | + std::optional<common::OmpAtomicDefaultMemOrderType> memOrder; |
| 387 | + for (const auto &clause : std::get<parser::OmpClauseList>(x.t).v) { |
| 388 | + flags |= common::visit( |
| 389 | + common::visitors{ |
| 390 | + [&memOrder]( |
| 391 | + const parser::OmpClause::AtomicDefaultMemOrder &atomic) { |
| 392 | + memOrder = atomic.v.v; |
| 393 | + return Flags{}; |
| 394 | + }, |
| 395 | + [](const parser::OmpClause::ReverseOffload &) { |
| 396 | + return Flags{Requires::ReverseOffload}; |
| 397 | + }, |
| 398 | + [](const parser::OmpClause::UnifiedAddress &) { |
| 399 | + return Flags{Requires::UnifiedAddress}; |
| 400 | + }, |
| 401 | + [](const parser::OmpClause::UnifiedSharedMemory &) { |
| 402 | + return Flags{Requires::UnifiedSharedMemory}; |
| 403 | + }, |
| 404 | + [](const parser::OmpClause::DynamicAllocators &) { |
| 405 | + return Flags{Requires::DynamicAllocators}; |
| 406 | + }, |
| 407 | + [](const auto &) { return Flags{}; }}, |
| 408 | + clause.u); |
| 409 | + } |
| 410 | + // Merge clauses into parents' symbols details. |
| 411 | + AddOmpRequiresToScope(currScope(), flags, memOrder); |
379 | 412 | return true;
|
380 | 413 | }
|
381 | 414 | void Post(const parser::OpenMPRequiresConstruct &) { PopContext(); }
|
@@ -672,6 +705,9 @@ class OmpAttributeVisitor : DirectiveAttributeVisitor<llvm::omp::Directive> {
|
672 | 705 |
|
673 | 706 | bool HasSymbolInEnclosingScope(const Symbol &, Scope &);
|
674 | 707 | std::int64_t ordCollapseLevel{0};
|
| 708 | + |
| 709 | + void AddOmpRequiresToScope(Scope &, WithOmpDeclarative::RequiresFlags, |
| 710 | + std::optional<common::OmpAtomicDefaultMemOrderType>); |
675 | 711 | };
|
676 | 712 |
|
677 | 713 | template <typename T>
|
@@ -2175,6 +2211,77 @@ void ResolveOmpParts(
|
2175 | 2211 | }
|
2176 | 2212 | }
|
2177 | 2213 |
|
| 2214 | +void ResolveOmpTopLevelParts( |
| 2215 | + SemanticsContext &context, const parser::Program &program) { |
| 2216 | + if (!context.IsEnabled(common::LanguageFeature::OpenMP)) { |
| 2217 | + return; |
| 2218 | + } |
| 2219 | + |
| 2220 | + // Gather REQUIRES clauses from all non-module top-level program unit symbols, |
| 2221 | + // combine them together ensuring compatibility and apply them to all these |
| 2222 | + // program units. Modules are skipped because their REQUIRES clauses should be |
| 2223 | + // propagated via USE statements instead. |
| 2224 | + WithOmpDeclarative::RequiresFlags combinedFlags; |
| 2225 | + std::optional<common::OmpAtomicDefaultMemOrderType> combinedMemOrder; |
| 2226 | + |
| 2227 | + // Function to go through non-module top level program units and extract |
| 2228 | + // REQUIRES information to be processed by a function-like argument. |
| 2229 | + auto processProgramUnits{[&](auto processFn) { |
| 2230 | + for (const parser::ProgramUnit &unit : program.v) { |
| 2231 | + if (!std::holds_alternative<common::Indirection<parser::Module>>( |
| 2232 | + unit.u) && |
| 2233 | + !std::holds_alternative<common::Indirection<parser::Submodule>>( |
| 2234 | + unit.u)) { |
| 2235 | + Symbol *symbol{common::visit( |
| 2236 | + [&context]( |
| 2237 | + auto &x) { return GetScope(context, x.value())->symbol(); }, |
| 2238 | + unit.u)}; |
| 2239 | + |
| 2240 | + common::visit( |
| 2241 | + [&](auto &details) { |
| 2242 | + if constexpr (std::is_convertible_v<decltype(&details), |
| 2243 | + WithOmpDeclarative *>) { |
| 2244 | + processFn(*symbol, details); |
| 2245 | + } |
| 2246 | + }, |
| 2247 | + symbol->details()); |
| 2248 | + } |
| 2249 | + } |
| 2250 | + }}; |
| 2251 | + |
| 2252 | + // Combine global REQUIRES information from all program units except modules |
| 2253 | + // and submodules. |
| 2254 | + processProgramUnits([&](Symbol &symbol, WithOmpDeclarative &details) { |
| 2255 | + if (const WithOmpDeclarative::RequiresFlags * |
| 2256 | + flags{details.ompRequires()}) { |
| 2257 | + combinedFlags |= *flags; |
| 2258 | + } |
| 2259 | + if (const common::OmpAtomicDefaultMemOrderType * |
| 2260 | + memOrder{details.ompAtomicDefaultMemOrder()}) { |
| 2261 | + if (combinedMemOrder && *combinedMemOrder != *memOrder) { |
| 2262 | + context.Say(symbol.scope()->sourceRange(), |
| 2263 | + "Conflicting '%s' REQUIRES clauses found in compilation " |
| 2264 | + "unit"_err_en_US, |
| 2265 | + parser::ToUpperCaseLetters(llvm::omp::getOpenMPClauseName( |
| 2266 | + llvm::omp::Clause::OMPC_atomic_default_mem_order) |
| 2267 | + .str())); |
| 2268 | + } |
| 2269 | + combinedMemOrder = *memOrder; |
| 2270 | + } |
| 2271 | + }); |
| 2272 | + |
| 2273 | + // Update all program units except modules and submodules with the combined |
| 2274 | + // global REQUIRES information. |
| 2275 | + processProgramUnits([&](Symbol &, WithOmpDeclarative &details) { |
| 2276 | + if (combinedFlags.any()) { |
| 2277 | + details.set_ompRequires(combinedFlags); |
| 2278 | + } |
| 2279 | + if (combinedMemOrder) { |
| 2280 | + details.set_ompAtomicDefaultMemOrder(*combinedMemOrder); |
| 2281 | + } |
| 2282 | + }); |
| 2283 | +} |
| 2284 | + |
2178 | 2285 | void OmpAttributeVisitor::CheckDataCopyingClause(
|
2179 | 2286 | const parser::Name &name, const Symbol &symbol, Symbol::Flag ompFlag) {
|
2180 | 2287 | const auto *checkSymbol{&symbol};
|
@@ -2322,4 +2429,44 @@ void OmpAttributeVisitor::CheckNameInAllocateStmt(
|
2322 | 2429 | parser::ToUpperCaseLetters(
|
2323 | 2430 | llvm::omp::getOpenMPDirectiveName(GetContext().directive).str()));
|
2324 | 2431 | }
|
| 2432 | + |
| 2433 | +void OmpAttributeVisitor::AddOmpRequiresToScope(Scope &scope, |
| 2434 | + WithOmpDeclarative::RequiresFlags flags, |
| 2435 | + std::optional<common::OmpAtomicDefaultMemOrderType> memOrder) { |
| 2436 | + Scope *scopeIter = &scope; |
| 2437 | + do { |
| 2438 | + if (Symbol * symbol{scopeIter->symbol()}) { |
| 2439 | + common::visit( |
| 2440 | + [&](auto &details) { |
| 2441 | + // Store clauses information into the symbol for the parent and |
| 2442 | + // enclosing modules, programs, functions and subroutines. |
| 2443 | + if constexpr (std::is_convertible_v<decltype(&details), |
| 2444 | + WithOmpDeclarative *>) { |
| 2445 | + if (flags.any()) { |
| 2446 | + if (const WithOmpDeclarative::RequiresFlags * |
| 2447 | + otherFlags{details.ompRequires()}) { |
| 2448 | + flags |= *otherFlags; |
| 2449 | + } |
| 2450 | + details.set_ompRequires(flags); |
| 2451 | + } |
| 2452 | + if (memOrder) { |
| 2453 | + if (details.has_ompAtomicDefaultMemOrder() && |
| 2454 | + *details.ompAtomicDefaultMemOrder() != *memOrder) { |
| 2455 | + context_.Say(scopeIter->sourceRange(), |
| 2456 | + "Conflicting '%s' REQUIRES clauses found in compilation " |
| 2457 | + "unit"_err_en_US, |
| 2458 | + parser::ToUpperCaseLetters(llvm::omp::getOpenMPClauseName( |
| 2459 | + llvm::omp::Clause::OMPC_atomic_default_mem_order) |
| 2460 | + .str())); |
| 2461 | + } |
| 2462 | + details.set_ompAtomicDefaultMemOrder(*memOrder); |
| 2463 | + } |
| 2464 | + } |
| 2465 | + }, |
| 2466 | + symbol->details()); |
| 2467 | + } |
| 2468 | + scopeIter = &scopeIter->parent(); |
| 2469 | + } while (!scopeIter->IsGlobal()); |
| 2470 | +} |
| 2471 | + |
2325 | 2472 | } // namespace Fortran::semantics
|
0 commit comments