Skip to content

Commit e67e09a

Browse files
authored
[Flang][OpenMP][Sema] Adding parsing and semantic support for scan directive. (#102792)
1 parent fd5fcfb commit e67e09a

File tree

20 files changed

+349
-107
lines changed

20 files changed

+349
-107
lines changed

flang/include/flang/Semantics/openmp-directive-sets.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,12 @@ static const OmpDirectiveSet workShareSet{
290290
} | allDoSet,
291291
};
292292

293+
//===----------------------------------------------------------------------===//
294+
// Directive sets for parent directives that do allow/not allow a construct
295+
//===----------------------------------------------------------------------===//
296+
297+
static const OmpDirectiveSet scanParentAllowedSet{allDoSet | allSimdSet};
298+
293299
//===----------------------------------------------------------------------===//
294300
// Directive sets for allowed/not allowed nested directives
295301
//===----------------------------------------------------------------------===//

flang/include/flang/Semantics/symbol.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -755,7 +755,8 @@ class Symbol {
755755
OmpDeclarativeAllocateDirective, OmpExecutableAllocateDirective,
756756
OmpDeclareSimd, OmpDeclareTarget, OmpThreadprivate, OmpDeclareReduction,
757757
OmpFlushed, OmpCriticalLock, OmpIfSpecified, OmpNone, OmpPreDetermined,
758-
OmpImplicit, OmpDependObject);
758+
OmpImplicit, OmpDependObject, OmpInclusiveScan, OmpExclusiveScan,
759+
OmpInScanReduction);
759760
using Flags = common::EnumSet<Flag, Flag_enumSize>;
760761

761762
const Scope &owner() const { return *owner_; }

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2520,6 +2520,9 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
25202520
case llvm::omp::Directive::OMPD_parallel:
25212521
genStandaloneParallel(converter, symTable, semaCtx, eval, loc, queue, item);
25222522
break;
2523+
case llvm::omp::Directive::OMPD_scan:
2524+
TODO(loc, "Unhandled directive " + llvm::omp::getOpenMPDirectiveName(dir));
2525+
break;
25232526
case llvm::omp::Directive::OMPD_section:
25242527
llvm_unreachable("genOMPDispatch: OMPD_section");
25252528
// Lowered in the enclosing genSectionsOp.

flang/lib/Parser/openmp-parsers.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -558,6 +558,8 @@ TYPE_PARSER(
558558
construct<OmpClause>(construct<OmpClause::DynamicAllocators>()) ||
559559
"ENTER" >> construct<OmpClause>(construct<OmpClause::Enter>(
560560
parenthesized(Parser<OmpObjectList>{}))) ||
561+
"EXCLUSIVE" >> construct<OmpClause>(construct<OmpClause::Exclusive>(
562+
parenthesized(Parser<OmpObjectList>{}))) ||
561563
"FILTER" >> construct<OmpClause>(construct<OmpClause::Filter>(
562564
parenthesized(scalarIntExpr))) ||
563565
"FINAL" >> construct<OmpClause>(construct<OmpClause::Final>(
@@ -577,6 +579,8 @@ TYPE_PARSER(
577579
"IF" >> construct<OmpClause>(construct<OmpClause::If>(
578580
parenthesized(Parser<OmpIfClause>{}))) ||
579581
"INBRANCH" >> construct<OmpClause>(construct<OmpClause::Inbranch>()) ||
582+
"INCLUSIVE" >> construct<OmpClause>(construct<OmpClause::Inclusive>(
583+
parenthesized(Parser<OmpObjectList>{}))) ||
580584
"IS_DEVICE_PTR" >> construct<OmpClause>(construct<OmpClause::IsDevicePtr>(
581585
parenthesized(Parser<OmpObjectList>{}))) ||
582586
"LASTPRIVATE" >> construct<OmpClause>(construct<OmpClause::Lastprivate>(
@@ -789,6 +793,7 @@ TYPE_PARSER(sourced(construct<OpenMPFlushConstruct>(verbatim("FLUSH"_tok),
789793
TYPE_PARSER(sourced(construct<OmpSimpleStandaloneDirective>(first(
790794
"BARRIER" >> pure(llvm::omp::Directive::OMPD_barrier),
791795
"ORDERED" >> pure(llvm::omp::Directive::OMPD_ordered),
796+
"SCAN" >> pure(llvm::omp::Directive::OMPD_scan),
792797
"TARGET ENTER DATA" >> pure(llvm::omp::Directive::OMPD_target_enter_data),
793798
"TARGET EXIT DATA" >> pure(llvm::omp::Directive::OMPD_target_exit_data),
794799
"TARGET UPDATE" >> pure(llvm::omp::Directive::OMPD_target_update),

flang/lib/Parser/unparse.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2393,6 +2393,9 @@ class UnparseVisitor {
23932393
case llvm::omp::Directive::OMPD_barrier:
23942394
Word("BARRIER ");
23952395
break;
2396+
case llvm::omp::Directive::OMPD_scan:
2397+
Word("SCAN ");
2398+
break;
23962399
case llvm::omp::Directive::OMPD_taskwait:
23972400
Word("TASKWAIT ");
23982401
break;

flang/lib/Semantics/check-omp-structure.cpp

Lines changed: 144 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "flang/Parser/parse-tree.h"
1313
#include "flang/Semantics/expression.h"
1414
#include "flang/Semantics/tools.h"
15+
#include <variant>
1516

1617
namespace Fortran::semantics {
1718

@@ -746,62 +747,69 @@ void OmpStructureChecker::CheckSIMDNest(const parser::OpenMPConstruct &c) {
746747
// current context yet.
747748
// TODO: Check for declare simd regions.
748749
bool eligibleSIMD{false};
749-
common::visit(Fortran::common::visitors{
750-
// Allow `!$OMP ORDERED SIMD`
751-
[&](const parser::OpenMPBlockConstruct &c) {
752-
const auto &beginBlockDir{
753-
std::get<parser::OmpBeginBlockDirective>(c.t)};
754-
const auto &beginDir{
755-
std::get<parser::OmpBlockDirective>(beginBlockDir.t)};
756-
if (beginDir.v == llvm::omp::Directive::OMPD_ordered) {
757-
const auto &clauses{
758-
std::get<parser::OmpClauseList>(beginBlockDir.t)};
759-
for (const auto &clause : clauses.v) {
760-
if (std::get_if<parser::OmpClause::Simd>(&clause.u)) {
761-
eligibleSIMD = true;
762-
break;
763-
}
764-
}
765-
}
766-
},
767-
[&](const parser::OpenMPSimpleStandaloneConstruct &c) {
768-
const auto &dir{
769-
std::get<parser::OmpSimpleStandaloneDirective>(c.t)};
770-
if (dir.v == llvm::omp::Directive::OMPD_ordered) {
771-
const auto &clauses{
772-
std::get<parser::OmpClauseList>(c.t)};
773-
for (const auto &clause : clauses.v) {
774-
if (std::get_if<parser::OmpClause::Simd>(&clause.u)) {
775-
eligibleSIMD = true;
776-
break;
777-
}
778-
}
779-
}
780-
},
781-
// Allowing SIMD construct
782-
[&](const parser::OpenMPLoopConstruct &c) {
783-
const auto &beginLoopDir{
784-
std::get<parser::OmpBeginLoopDirective>(c.t)};
785-
const auto &beginDir{
786-
std::get<parser::OmpLoopDirective>(beginLoopDir.t)};
787-
if ((beginDir.v == llvm::omp::Directive::OMPD_simd) ||
788-
(beginDir.v == llvm::omp::Directive::OMPD_do_simd)) {
789-
eligibleSIMD = true;
790-
}
791-
},
792-
[&](const parser::OpenMPAtomicConstruct &c) {
793-
// Allow `!$OMP ATOMIC`
794-
eligibleSIMD = true;
795-
},
796-
[&](const auto &c) {},
797-
},
750+
common::visit(
751+
Fortran::common::visitors{
752+
// Allow `!$OMP ORDERED SIMD`
753+
[&](const parser::OpenMPBlockConstruct &c) {
754+
const auto &beginBlockDir{
755+
std::get<parser::OmpBeginBlockDirective>(c.t)};
756+
const auto &beginDir{
757+
std::get<parser::OmpBlockDirective>(beginBlockDir.t)};
758+
if (beginDir.v == llvm::omp::Directive::OMPD_ordered) {
759+
const auto &clauses{
760+
std::get<parser::OmpClauseList>(beginBlockDir.t)};
761+
for (const auto &clause : clauses.v) {
762+
if (std::get_if<parser::OmpClause::Simd>(&clause.u)) {
763+
eligibleSIMD = true;
764+
break;
765+
}
766+
}
767+
}
768+
},
769+
[&](const parser::OpenMPStandaloneConstruct &c) {
770+
if (const auto &simpleConstruct =
771+
std::get_if<parser::OpenMPSimpleStandaloneConstruct>(
772+
&c.u)) {
773+
const auto &dir{std::get<parser::OmpSimpleStandaloneDirective>(
774+
simpleConstruct->t)};
775+
if (dir.v == llvm::omp::Directive::OMPD_ordered) {
776+
const auto &clauses{
777+
std::get<parser::OmpClauseList>(simpleConstruct->t)};
778+
for (const auto &clause : clauses.v) {
779+
if (std::get_if<parser::OmpClause::Simd>(&clause.u)) {
780+
eligibleSIMD = true;
781+
break;
782+
}
783+
}
784+
} else if (dir.v == llvm::omp::Directive::OMPD_scan) {
785+
eligibleSIMD = true;
786+
}
787+
}
788+
},
789+
// Allowing SIMD construct
790+
[&](const parser::OpenMPLoopConstruct &c) {
791+
const auto &beginLoopDir{
792+
std::get<parser::OmpBeginLoopDirective>(c.t)};
793+
const auto &beginDir{
794+
std::get<parser::OmpLoopDirective>(beginLoopDir.t)};
795+
if ((beginDir.v == llvm::omp::Directive::OMPD_simd) ||
796+
(beginDir.v == llvm::omp::Directive::OMPD_do_simd)) {
797+
eligibleSIMD = true;
798+
}
799+
},
800+
[&](const parser::OpenMPAtomicConstruct &c) {
801+
// Allow `!$OMP ATOMIC`
802+
eligibleSIMD = true;
803+
},
804+
[&](const auto &c) {},
805+
},
798806
c.u);
799807
if (!eligibleSIMD) {
800808
context_.Say(parser::FindSourceLocation(c),
801809
"The only OpenMP constructs that can be encountered during execution "
802810
"of a 'SIMD' region are the `ATOMIC` construct, the `LOOP` construct, "
803-
"the `SIMD` construct and the `ORDERED` construct with the `SIMD` "
804-
"clause."_err_en_US);
811+
"the `SIMD` construct, the `SCAN` construct and the `ORDERED` "
812+
"construct with the `SIMD` clause."_err_en_US);
805813
}
806814
}
807815

@@ -965,6 +973,49 @@ void OmpStructureChecker::CheckDistLinear(
965973
}
966974

967975
void OmpStructureChecker::Leave(const parser::OpenMPLoopConstruct &x) {
976+
const auto &beginLoopDir{std::get<parser::OmpBeginLoopDirective>(x.t)};
977+
const auto &clauseList{std::get<parser::OmpClauseList>(beginLoopDir.t)};
978+
979+
// A few semantic checks for InScan reduction are performed below as SCAN
980+
// constructs inside LOOP may add the relevant information. Scan reduction is
981+
// supported only in loop constructs, so same checks are not applicable to
982+
// other directives.
983+
for (const auto &clause : clauseList.v) {
984+
if (const auto *reductionClause{
985+
std::get_if<parser::OmpClause::Reduction>(&clause.u)}) {
986+
const auto &maybeModifier{
987+
std::get<std::optional<ReductionModifier>>(reductionClause->v.t)};
988+
if (maybeModifier && *maybeModifier == ReductionModifier::Inscan) {
989+
const auto &objectList{
990+
std::get<parser::OmpObjectList>(reductionClause->v.t)};
991+
auto checkReductionSymbolInScan = [&](const parser::Name *name) {
992+
if (auto &symbol = name->symbol) {
993+
if (!symbol->test(Symbol::Flag::OmpInclusiveScan) &&
994+
!symbol->test(Symbol::Flag::OmpExclusiveScan)) {
995+
context_.Say(name->source,
996+
"List item %s must appear in EXCLUSIVE or "
997+
"INCLUSIVE clause of an "
998+
"enclosed SCAN directive"_err_en_US,
999+
name->ToString());
1000+
}
1001+
}
1002+
};
1003+
for (const auto &ompObj : objectList.v) {
1004+
common::visit(
1005+
common::visitors{
1006+
[&](const parser::Designator &designator) {
1007+
if (const auto *name{semantics::getDesignatorNameIfDataRef(
1008+
designator)}) {
1009+
checkReductionSymbolInScan(name);
1010+
}
1011+
},
1012+
[&](const auto &name) { checkReductionSymbolInScan(&name); },
1013+
},
1014+
ompObj.u);
1015+
}
1016+
}
1017+
}
1018+
}
9681019
if (llvm::omp::allSimdSet.test(GetContext().directive)) {
9691020
ExitDirectiveNest(SIMDNest);
9701021
}
@@ -1652,19 +1703,32 @@ void OmpStructureChecker::Leave(const parser::OpenMPAllocatorsConstruct &x) {
16521703
dirContext_.pop_back();
16531704
}
16541705

1706+
void OmpStructureChecker::CheckScan(
1707+
const parser::OpenMPSimpleStandaloneConstruct &x) {
1708+
if (std::get<parser::OmpClauseList>(x.t).v.size() != 1) {
1709+
context_.Say(x.source,
1710+
"Exactly one of EXCLUSIVE or INCLUSIVE clause is expected"_err_en_US);
1711+
}
1712+
if (!CurrentDirectiveIsNested() ||
1713+
!llvm::omp::scanParentAllowedSet.test(GetContextParent().directive)) {
1714+
context_.Say(x.source,
1715+
"Orphaned SCAN directives are prohibited; perhaps you forgot "
1716+
"to enclose the directive in to a WORKSHARING LOOP, a WORKSHARING "
1717+
"LOOP SIMD or a SIMD directive."_err_en_US);
1718+
}
1719+
}
1720+
16551721
void OmpStructureChecker::CheckBarrierNesting(
16561722
const parser::OpenMPSimpleStandaloneConstruct &x) {
16571723
// A barrier region may not be `closely nested` inside a worksharing, loop,
16581724
// task, taskloop, critical, ordered, atomic, or master region.
16591725
// TODO: Expand the check to include `LOOP` construct as well when it is
16601726
// supported.
1661-
if (GetContext().directive == llvm::omp::Directive::OMPD_barrier) {
1662-
if (IsCloselyNestedRegion(llvm::omp::nestedBarrierErrSet)) {
1663-
context_.Say(parser::FindSourceLocation(x),
1664-
"`BARRIER` region may not be closely nested inside of `WORKSHARING`, "
1665-
"`LOOP`, `TASK`, `TASKLOOP`,"
1666-
"`CRITICAL`, `ORDERED`, `ATOMIC` or `MASTER` region."_err_en_US);
1667-
}
1727+
if (IsCloselyNestedRegion(llvm::omp::nestedBarrierErrSet)) {
1728+
context_.Say(parser::FindSourceLocation(x),
1729+
"`BARRIER` region may not be closely nested inside of `WORKSHARING`, "
1730+
"`LOOP`, `TASK`, `TASKLOOP`,"
1731+
"`CRITICAL`, `ORDERED`, `ATOMIC` or `MASTER` region."_err_en_US);
16681732
}
16691733
}
16701734

@@ -1848,7 +1912,16 @@ void OmpStructureChecker::Enter(
18481912
const parser::OpenMPSimpleStandaloneConstruct &x) {
18491913
const auto &dir{std::get<parser::OmpSimpleStandaloneDirective>(x.t)};
18501914
PushContextAndClauseSets(dir.source, dir.v);
1851-
CheckBarrierNesting(x);
1915+
switch (dir.v) {
1916+
case llvm::omp::Directive::OMPD_barrier:
1917+
CheckBarrierNesting(x);
1918+
break;
1919+
case llvm::omp::Directive::OMPD_scan:
1920+
CheckScan(x);
1921+
break;
1922+
default:
1923+
break;
1924+
}
18521925
}
18531926

18541927
void OmpStructureChecker::Leave(
@@ -2687,8 +2760,8 @@ CHECK_SIMPLE_CLAUSE(Full, OMPC_full)
26872760
CHECK_SIMPLE_CLAUSE(Grainsize, OMPC_grainsize)
26882761
CHECK_SIMPLE_CLAUSE(Hint, OMPC_hint)
26892762
CHECK_SIMPLE_CLAUSE(Holds, OMPC_holds)
2690-
CHECK_SIMPLE_CLAUSE(InReduction, OMPC_in_reduction)
26912763
CHECK_SIMPLE_CLAUSE(Inclusive, OMPC_inclusive)
2764+
CHECK_SIMPLE_CLAUSE(InReduction, OMPC_in_reduction)
26922765
CHECK_SIMPLE_CLAUSE(Match, OMPC_match)
26932766
CHECK_SIMPLE_CLAUSE(Nontemporal, OMPC_nontemporal)
26942767
CHECK_SIMPLE_CLAUSE(NumTasks, OMPC_num_tasks)
@@ -2781,7 +2854,11 @@ void OmpStructureChecker::Enter(const parser::OmpClause::Reduction &x) {
27812854
if (CheckReductionOperators(x)) {
27822855
CheckReductionTypeList(x);
27832856
}
2784-
CheckReductionModifier(x);
2857+
if (const auto &maybeModifier{
2858+
std::get<std::optional<ReductionModifier>>(x.v.t)}) {
2859+
const ReductionModifier modifier{*maybeModifier};
2860+
CheckReductionModifier(modifier);
2861+
}
27852862
}
27862863

27872864
bool OmpStructureChecker::CheckReductionOperators(
@@ -2824,6 +2901,7 @@ bool OmpStructureChecker::CheckReductionOperators(
28242901

28252902
return ok;
28262903
}
2904+
28272905
bool OmpStructureChecker::CheckIntrinsicOperator(
28282906
const parser::DefinedOperator::IntrinsicOperator &op) {
28292907

@@ -2958,14 +3036,11 @@ void OmpStructureChecker::CheckReductionTypeList(
29583036
}
29593037

29603038
void OmpStructureChecker::CheckReductionModifier(
2961-
const parser::OmpClause::Reduction &x) {
2962-
using ReductionModifier = parser::OmpReductionClause::ReductionModifier;
2963-
const auto &maybeModifier{std::get<std::optional<ReductionModifier>>(x.v.t)};
2964-
if (!maybeModifier || *maybeModifier == ReductionModifier::Default) {
2965-
// No modifier, or the default one is always ok.
3039+
const ReductionModifier &modifier) {
3040+
if (modifier == ReductionModifier::Default) {
3041+
// The default one is always ok.
29663042
return;
29673043
}
2968-
ReductionModifier modifier{*maybeModifier};
29693044
const DirectiveContext &dirCtx{GetContext()};
29703045
if (dirCtx.directive == llvm::omp::Directive::OMPD_loop) {
29713046
// [5.2:257:33-34]
@@ -2996,15 +3071,10 @@ void OmpStructureChecker::CheckReductionModifier(
29963071
// or "simd" directive.
29973072
// The worksharing-loop directives are OMPD_do and OMPD_for. Only the
29983073
// former is allowed in Fortran.
2999-
switch (dirCtx.directive) {
3000-
case llvm::omp::Directive::OMPD_do: // worksharing-loop
3001-
case llvm::omp::Directive::OMPD_do_simd: // worksharing-loop simd
3002-
case llvm::omp::Directive::OMPD_simd: // "simd"
3003-
break;
3004-
default:
3074+
if (!llvm::omp::scanParentAllowedSet.test(dirCtx.directive)) {
30053075
context_.Say(GetContext().clauseSource,
30063076
"Modifier 'INSCAN' on REDUCTION clause is only allowed with "
3007-
"worksharing-loop, worksharing-loop simd, "
3077+
"WORKSHARING LOOP, WORKSHARING LOOP SIMD, "
30083078
"or SIMD directive"_err_en_US);
30093079
}
30103080
} else {

flang/lib/Semantics/check-omp-structure.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ class OmpStructureChecker
7070
) {
7171
}
7272
using llvmOmpClause = const llvm::omp::Clause;
73+
using ReductionModifier = parser::OmpReductionClause::ReductionModifier;
7374

7475
void Enter(const parser::OpenMPConstruct &);
7576
void Leave(const parser::OpenMPConstruct &);
@@ -229,10 +230,11 @@ class OmpStructureChecker
229230
bool CheckIntrinsicOperator(
230231
const parser::DefinedOperator::IntrinsicOperator &);
231232
void CheckReductionTypeList(const parser::OmpClause::Reduction &);
232-
void CheckReductionModifier(const parser::OmpClause::Reduction &);
233+
void CheckReductionModifier(const ReductionModifier &);
233234
void CheckMasterNesting(const parser::OpenMPBlockConstruct &x);
234235
void ChecksOnOrderedAsBlock();
235236
void CheckBarrierNesting(const parser::OpenMPSimpleStandaloneConstruct &x);
237+
void CheckScan(const parser::OpenMPSimpleStandaloneConstruct &x);
236238
void ChecksOnOrderedAsStandalone();
237239
void CheckOrderedDependClause(std::optional<std::int64_t> orderedValue);
238240
void CheckReductionArraySection(const parser::OmpObjectList &ompObjectList);

0 commit comments

Comments
 (0)