Skip to content

Commit

Permalink
[SYCLomatic] Refine the migration of vector type operator function #2099
Browse files Browse the repository at this point in the history


Signed-off-by: intwanghao <[email protected]>
  • Loading branch information
intwanghao authored Jul 5, 2024
1 parent 7c5de2a commit 6ea742c
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 8 deletions.
33 changes: 27 additions & 6 deletions clang/lib/DPCT/ASTTraversal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2889,9 +2889,16 @@ void VectorTypeOperatorRule::registerMatcher(MatchFinder &MF) {
this);

// Matches call of user overloaded operator
MF.addMatcher(cxxOperatorCallExpr(callee(vectorTypeOverLoadedOperator()))
.bind("callOverloadedOperator"),
MF.addMatcher(cxxOperatorCallExpr(callee(vectorTypeOverLoadedOperator()),
hasAncestor(vectorTypeOverLoadedOperator()))
.bind("callOverloadedOperatorInOverloadedOperator"),
this);

MF.addMatcher(
cxxOperatorCallExpr(callee(vectorTypeOverLoadedOperator()),
unless(hasAncestor(vectorTypeOverLoadedOperator())))
.bind("callOverloadedOperatorNotInOverloadedOperator"),
this);
}

const char VectorTypeOperatorRule::NamespaceName[] =
Expand Down Expand Up @@ -2973,10 +2980,15 @@ void VectorTypeOperatorRule::MigrateOverloadedOperatorDecl(
}

void VectorTypeOperatorRule::MigrateOverloadedOperatorCall(
const MatchFinder::MatchResult &Result, const CXXOperatorCallExpr *CE) {
const MatchFinder::MatchResult &Result, const CXXOperatorCallExpr *CE,
bool InOverloadedOperator) {
if (!CE)
return;

if (!InOverloadedOperator &&
(DpctGlobalInfo::findAncestor<FunctionTemplateDecl>(CE) ||
DpctGlobalInfo::findAncestor<ClassTemplateDecl>(CE))) {
return;
}
// Explicitly call user overloaded operator
//
// For non-assignment operator:
Expand Down Expand Up @@ -3010,8 +3022,17 @@ void VectorTypeOperatorRule::runRule(const MatchFinder::MatchResult &Result) {
Result, getNodeAsType<FunctionDecl>(Result, "overloadedOperatorDecl"));

// Explicitly call user overloaded operator
MigrateOverloadedOperatorCall(Result, getNodeAsType<CXXOperatorCallExpr>(
Result, "callOverloadedOperator"));
MigrateOverloadedOperatorCall(
Result,
getNodeAsType<CXXOperatorCallExpr>(
Result, "callOverloadedOperatorInOverloadedOperator"),
true);

MigrateOverloadedOperatorCall(
Result,
getNodeAsType<CXXOperatorCallExpr>(
Result, "callOverloadedOperatorNotInOverloadedOperator"),
false);
}

REGISTER_RULE(VectorTypeOperatorRule, PassKind::PK_Migration)
Expand Down
2 changes: 1 addition & 1 deletion clang/lib/DPCT/ASTTraversal.h
Original file line number Diff line number Diff line change
Expand Up @@ -574,7 +574,7 @@ class VectorTypeOperatorRule
const FunctionDecl *FD);
void MigrateOverloadedOperatorCall(
const ast_matchers::MatchFinder::MatchResult &Result,
const CXXOperatorCallExpr *CE);
const CXXOperatorCallExpr *CE, bool InOverloadedOperator);

private:
static const char NamespaceName[];
Expand Down
39 changes: 38 additions & 1 deletion clang/test/dpct/double2_overloaded_operator.cu
Original file line number Diff line number Diff line change
Expand Up @@ -520,4 +520,41 @@ void foo(){
int2 i2;
A2 a;
a - i2;
}
}

inline __device__ float2 operator+(const float2 & a, const float2 & b) {
return {a.x + b.x, a.y + b.y};
}

// CHECK: template<typename T>
// CHECK: struct Sum {
// CHECK: inline Sum() {}
// CHECK: inline T operator()(const T &a, const T &b) const {
// CHECK: return a + b;
// CHECK: }
// CHECK: };
template<typename T>
struct Sum {
inline __device__ Sum() {}
inline __device__ T operator()(const T &a, const T &b) const {
return a + b;
}
};

// CHECK: template <typename T>
// CHECK: void bar() {
// CHECK: T a, b, c;
// CHECK: c = a + b;
// CHECK: }
template <typename T>
__device__ void bar() {
T a, b, c;
c = a + b;
}

__global__ void kernel() {
bar<float2>();
bar<float>();
Sum<float2> a;
Sum<float> b;
}

0 comments on commit 6ea742c

Please sign in to comment.