Skip to content

[SYCL] Add template template parameter support #464

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 12, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions clang/lib/Sema/SemaSYCL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1382,8 +1382,28 @@ void SYCLIntegrationHeader::emitForwardClassDecls(
}
break;
}
case TemplateArgument::ArgKind::Template:
llvm_unreachable("template template arguments not supported");
case TemplateArgument::ArgKind::Template: {
// recursion is not required, since the maximum possible nesting level
// equals two for template argument
//
// for example:
// template <typename T> class Bar;
// template <template <typename> class> class Baz;
// template <template <template <typename> class> class T>
// class Foo;
//
// The Baz is a template class. The Baz<Bar> is a class. The class Foo
// should be specialized with template class, not a class. The correct
// specialization of template class Foo is Foo<Baz>. The incorrect
// specialization of template class Foo is Foo<Baz<Bar>>. In this case
// template class Foo specialized by class Baz<Bar>, not a template
// class template <template <typename> class> class T as it should.
TemplateDecl *TD = Arg.getAsTemplate().getAsTemplateDecl();
if (Printed.insert(TD).second) {
emitFwdDecl(O, TD);
}
break;
}
default:
break; // nop
}
Expand Down
72 changes: 72 additions & 0 deletions clang/test/CodeGenSYCL/template-template-parameter.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
// RUN: %clang -I %S/Inputs --sycl -Xclang -fsycl-int-header=%t.h %s
// RUN: FileCheck -input-file=%t.h %s

#include <sycl.hpp>

using namespace cl::sycl;

template <typename T> class Foo1;
// CHECK: template <typename T> class Foo1;
template <template <typename> class TT> class KernelName1;
// CHECK: template <template <typename> class TT> class KernelName1;
template <template <typename> class TT> void enqueue() {
queue q;
q.submit([&](handler &cgh) {
cgh.single_task<KernelName1<TT>>([](){});
});
}

template <typename TY> class Bar2;
// CHECK: template <typename TY> class Bar2;
template <template <typename> class TT> class Foo2;
// CHECK: template <template <typename> class TT> class Foo2;
template <class TTY> class KernelName2;
// CHECK: template <class TTY> class KernelName2;
template <class Y> void enqueue2() {
queue q;
q.submit([&](handler &cgh) {
cgh.single_task< KernelName2<Y> >([](){});
});
}

template <typename T> class Bar3;
// CHECK: template <typename T> class Bar3;
template <template <typename> class> class Baz3;
// CHECK: template <template <typename> class > class Baz3;
template <template <template <typename> class> class T> class Foo3;
// CHECK: template <template <template <typename> class > class T> class Foo3;
template <typename T , typename... Args> class Mist3;
// CHECK: template <typename T, typename ...Args> class Mist3;
template <typename T, template <typename, typename...> class, typename... Args> class Ice3;
// CHECK: template <typename T, template <typename, typename ...> class , typename ...Args> class Ice3;

int main() {
enqueue<Foo1>();

enqueue2<Foo2<Bar2>>();

queue q;

q.submit([&](handler &cgh) {
cgh.single_task<Bar3<int>>([](){});
});

q.submit([&](handler &cgh) {
cgh.single_task<Baz3<Bar3>>([](){});
});

q.submit([&](handler &cgh) {
cgh.single_task<Foo3<Baz3>>([](){});
});

q.submit([&](handler &cgh) {
cgh.single_task<Mist3<int, float, char, double>>([](){});
});

q.submit([&](handler &cgh) {
cgh.single_task<Ice3<int, Mist3, char, short, float>>([](){});
});

return 0;
}