-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[flang][OpenMP] Make object identity more precise #94495
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
The object identity requires more than just `Symbol`. Don't use `id()` to get the Symbol associated with the object, becase the return value will need to change. Instead use `sym()` which is added for that reason.
Derived type components may use a given `Symbol` regardless of what parent objects they are a part of. Because of that, simply using a symbol address is not sufficient to determine object identity. Make the designator a part of the IdTy. To compare identities, when symbols are equal (and non-null), compare the designators.
@llvm/pr-subscribers-flang-openmp @llvm/pr-subscribers-flang-fir-hlfir Author: Krzysztof Parzyszek (kparzysz) ChangesDerived type components may use a given Make the designator a part of the IdTy. To compare identities, when symbols are equal (and non-null), compare the designators. Full diff: https://github.com/llvm/llvm-project/pull/94495.diff 3 Files Affected:
diff --git a/flang/lib/Lower/OpenMP/Clauses.h b/flang/lib/Lower/OpenMP/Clauses.h
index f7cd0ea83ad12..98fb5dcf7722e 100644
--- a/flang/lib/Lower/OpenMP/Clauses.h
+++ b/flang/lib/Lower/OpenMP/Clauses.h
@@ -36,30 +36,64 @@ struct TypeTy : public evaluate::SomeType {
bool operator==(const TypeTy &t) const { return true; }
};
-using IdTy = semantics::Symbol *;
+template <typename ExprTy>
+struct IdTyTemplate {
+ // "symbol" is always non-null for id's of actual objects.
+ Fortran::semantics::Symbol *symbol;
+ std::optional<ExprTy> designator;
+
+ bool operator==(const IdTyTemplate &other) const {
+ // If symbols are different, then the objects are different.
+ if (symbol != other.symbol)
+ return false;
+ if (symbol == nullptr)
+ return true;
+ // Equal symbols don't necessarily indicate identical objects,
+ // for example, a derived object component may use a single symbol,
+ // which will refer to different objects for different designators,
+ // e.g. a%c and b%c.
+ return designator == other.designator;
+ }
+
+ operator bool() const { return symbol != nullptr; }
+};
+
using ExprTy = SomeExpr;
template <typename T>
using List = tomp::ListT<T>;
} // namespace Fortran::lower::omp
+// Specialization of the ObjectT template
namespace tomp::type {
template <>
-struct ObjectT<Fortran::lower::omp::IdTy, Fortran::lower::omp::ExprTy> {
- using IdTy = Fortran::lower::omp::IdTy;
+struct ObjectT<Fortran::lower::omp::IdTyTemplate<Fortran::lower::omp::ExprTy>,
+ Fortran::lower::omp::ExprTy> {
+ using IdTy = Fortran::lower::omp::IdTyTemplate<Fortran::lower::omp::ExprTy>;
using ExprTy = Fortran::lower::omp::ExprTy;
- IdTy id() const { return symbol; }
- Fortran::semantics::Symbol *sym() const { return symbol; }
- const std::optional<ExprTy> &ref() const { return designator; }
+ IdTy id() const { return identity; }
+ Fortran::semantics::Symbol *sym() const { return identity.symbol; }
+ const std::optional<ExprTy> &ref() const { return identity.designator; }
- IdTy symbol;
- std::optional<ExprTy> designator;
+ IdTy identity;
};
} // namespace tomp::type
namespace Fortran::lower::omp {
+using IdTy = IdTyTemplate<ExprTy>;
+}
+namespace std {
+template <>
+struct hash<Fortran::lower::omp::IdTy> {
+ size_t operator()(const Fortran::lower::omp::IdTy &id) const {
+ return static_cast<size_t>(reinterpret_cast<uintptr_t>(id.symbol));
+ }
+};
+} // namespace std
+
+namespace Fortran::lower::omp {
using Object = tomp::ObjectT<IdTy, ExprTy>;
using ObjectList = tomp::ObjectListT<IdTy, ExprTy>;
diff --git a/flang/lib/Lower/OpenMP/Utils.cpp b/flang/lib/Lower/OpenMP/Utils.cpp
index eff915f569f27..da94352a84a7c 100644
--- a/flang/lib/Lower/OpenMP/Utils.cpp
+++ b/flang/lib/Lower/OpenMP/Utils.cpp
@@ -188,7 +188,7 @@ void addChildIndexAndMapToParent(
std::map<const semantics::Symbol *,
llvm::SmallVector<OmpMapMemberIndicesData>> &parentMemberIndices,
mlir::omp::MapInfoOp &mapOp, semantics::SemanticsContext &semaCtx) {
- std::optional<evaluate::DataRef> dataRef = ExtractDataRef(object.designator);
+ std::optional<evaluate::DataRef> dataRef = ExtractDataRef(object.ref());
assert(dataRef.has_value() &&
"DataRef could not be extracted during mapping of derived type "
"cannot proceed");
diff --git a/flang/test/Lower/OpenMP/map-component-ref.f90 b/flang/test/Lower/OpenMP/map-component-ref.f90
index 2c582667f38d3..21b56ab303acd 100644
--- a/flang/test/Lower/OpenMP/map-component-ref.f90
+++ b/flang/test/Lower/OpenMP/map-component-ref.f90
@@ -1,21 +1,22 @@
! RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s
! RUN: bbc -fopenmp -emit-hlfir %s -o - | FileCheck %s
-! CHECK: %[[V0:[0-9]+]] = fir.alloca !fir.type<_QFfooTt0{a0:i32,a1:i32}> {bindc_name = "a", uniq_name = "_QFfooEa"}
-! CHECK: %[[V1:[0-9]+]]:2 = hlfir.declare %[[V0]] {uniq_name = "_QFfooEa"} : (!fir.ref<!fir.type<_QFfooTt0{a0:i32,a1:i32}>>) -> (!fir.ref<!fir.type<_QFfooTt0{a0:i32,a1:i32}>>, !fir.ref<!fir.type<_QFfooTt0{a0:i32,a1:i32}>>)
-! CHECK: %[[V2:[0-9]+]] = hlfir.designate %[[V1]]#0{"a1"} : (!fir.ref<!fir.type<_QFfooTt0{a0:i32,a1:i32}>>) -> !fir.ref<i32>
+! CHECK-LABEL: func.func @_QPfoo1
+! CHECK: %[[V0:[0-9]+]] = fir.alloca !fir.type<_QFfoo1Tt0{a0:i32,a1:i32}> {bindc_name = "a", uniq_name = "_QFfoo1Ea"}
+! CHECK: %[[V1:[0-9]+]]:2 = hlfir.declare %[[V0]] {uniq_name = "_QFfoo1Ea"} : (!fir.ref<!fir.type<_QFfoo1Tt0{a0:i32,a1:i32}>>) -> (!fir.ref<!fir.type<_QFfoo1Tt0{a0:i32,a1:i32}>>, !fir.ref<!fir.type<_QFfoo1Tt0{a0:i32,a1:i32}>>)
+! CHECK: %[[V2:[0-9]+]] = hlfir.designate %[[V1]]#0{"a1"} : (!fir.ref<!fir.type<_QFfoo1Tt0{a0:i32,a1:i32}>>) -> !fir.ref<i32>
! CHECK: %[[V3:[0-9]+]] = omp.map.info var_ptr(%[[V2]] : !fir.ref<i32>, i32) map_clauses(tofrom) capture(ByRef) -> !fir.ref<i32> {name = "a%a1"}
-! CHECK: %[[V4:[0-9]+]] = omp.map.info var_ptr(%[[V1]]#1 : !fir.ref<!fir.type<_QFfooTt0{a0:i32,a1:i32}>>, !fir.type<_QFfooTt0{a0:i32,a1:i32}>) map_clauses(tofrom) capture(ByRef) members(%[[V3]] : [1] : !fir.ref<i32>) -> !fir.ref<!fir.type<_QFfooTt0{a0:i32,a1:i32}>> {name = "a", partial_map = true}
-! CHECK: omp.target map_entries(%[[V3]] -> %arg0, %[[V4]] -> %arg1 : !fir.ref<i32>, !fir.ref<!fir.type<_QFfooTt0{a0:i32,a1:i32}>>) {
-! CHECK: ^bb0(%arg0: !fir.ref<i32>, %arg1: !fir.ref<!fir.type<_QFfooTt0{a0:i32,a1:i32}>>):
-! CHECK: %[[V5:[0-9]+]]:2 = hlfir.declare %arg1 {uniq_name = "_QFfooEa"} : (!fir.ref<!fir.type<_QFfooTt0{a0:i32,a1:i32}>>) -> (!fir.ref<!fir.type<_QFfooTt0{a0:i32,a1:i32}>>, !fir.ref<!fir.type<_QFfooTt0{a0:i32,a1:i32}>>)
+! CHECK: %[[V4:[0-9]+]] = omp.map.info var_ptr(%[[V1]]#1 : !fir.ref<!fir.type<_QFfoo1Tt0{a0:i32,a1:i32}>>, !fir.type<_QFfoo1Tt0{a0:i32,a1:i32}>) map_clauses(tofrom) capture(ByRef) members(%[[V3]] : [1] : !fir.ref<i32>) -> !fir.ref<!fir.type<_QFfoo1Tt0{a0:i32,a1:i32}>> {name = "a", partial_map = true}
+! CHECK: omp.target map_entries(%[[V3]] -> %arg0, %[[V4]] -> %arg1 : !fir.ref<i32>, !fir.ref<!fir.type<_QFfoo1Tt0{a0:i32,a1:i32}>>) {
+! CHECK: ^bb0(%arg0: !fir.ref<i32>, %arg1: !fir.ref<!fir.type<_QFfoo1Tt0{a0:i32,a1:i32}>>):
+! CHECK: %[[V5:[0-9]+]]:2 = hlfir.declare %arg1 {uniq_name = "_QFfoo1Ea"} : (!fir.ref<!fir.type<_QFfoo1Tt0{a0:i32,a1:i32}>>) -> (!fir.ref<!fir.type<_QFfoo1Tt0{a0:i32,a1:i32}>>, !fir.ref<!fir.type<_QFfoo1Tt0{a0:i32,a1:i32}>>)
! CHECK: %c0_i32 = arith.constant 0 : i32
-! CHECK: %[[V6:[0-9]+]] = hlfir.designate %[[V5]]#0{"a1"} : (!fir.ref<!fir.type<_QFfooTt0{a0:i32,a1:i32}>>) -> !fir.ref<i32>
+! CHECK: %[[V6:[0-9]+]] = hlfir.designate %[[V5]]#0{"a1"} : (!fir.ref<!fir.type<_QFfoo1Tt0{a0:i32,a1:i32}>>) -> !fir.ref<i32>
! CHECK: hlfir.assign %c0_i32 to %[[V6]] : i32, !fir.ref<i32>
! CHECK: omp.terminator
! CHECK: }
-subroutine foo()
+subroutine foo1()
implicit none
type t0
@@ -29,3 +30,25 @@ subroutine foo()
!$omp end target
end
+
+! CHECK-LABEL: func.func @_QPfoo2
+! CHECK-DAG: omp.map.info var_ptr(%{{[0-9]+}} : {{.*}} map_clauses(to) capture(ByRef) bounds(%{{[0-9]+}}) -> {{.*}} {name = "t%b(1_8)%a(1)"}
+! CHECK-DAG: omp.map.info var_ptr(%{{[0-9]+}} : {{.*}} map_clauses(from) capture(ByRef) bounds(%{{[0-9]+}}) -> {{.*}} {name = "u%b(1_8)%a(1)"}
+subroutine foo2()
+ implicit none
+
+ type t0
+ integer :: a(10)
+ end type
+
+ type t1
+ type(t0) :: b(10)
+ end type
+
+ type(t1) :: t, u
+
+!$omp target map(to: t%b(1)%a(1)) map(from: u%b(1)%a(1))
+ t%b(1)%a(1) = u%b(1)%a(1)
+!$omp end target
+
+end
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, and tested it locally and it seems to fix the issue I encountered! Thank you very much for this fix :-)
Derived type components may use a given
Symbol
regardless of what parent objects they are a part of. Because of that, simply using a symbol address is not sufficient to determine object identity.Make the designator a part of the IdTy. To compare identities, when symbols are equal (and non-null), compare the designators.