Skip to content

Commit 8746f8b

Browse files
authored
Merge pull request #7 from KernelTuner/develop
Add support for templated types as parameters in pragma kernels
2 parents 6e27d0f + 6c3ec70 commit 8746f8b

File tree

14 files changed

+208
-30
lines changed

14 files changed

+208
-30
lines changed

include/kernel_launcher/builder.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#define KERNEL_LAUNCHER_BUILDER_H
33

44
#include <unordered_map>
5+
#include <utility>
56
#include <vector>
67

78
#include "kernel_launcher/arg.h"
@@ -24,11 +25,13 @@ struct KernelInstance {
2425
CudaModule module,
2526
std::array<TypedExpr<uint32_t>, 3> block_size,
2627
std::array<TypedExpr<uint32_t>, 3> grid_size,
27-
TypedExpr<uint32_t> shared_mem) :
28+
TypedExpr<uint32_t> shared_mem = 0,
29+
std::vector<TypedExpr<bool>> assertions = {}) :
2830
module_(std::move(module)),
2931
block_size_(std::move(block_size)),
3032
grid_size_(std::move(grid_size)),
31-
shared_mem_(std::move(shared_mem)) {}
33+
shared_mem_(std::move(shared_mem)),
34+
assertions_(std::move(assertions)) {}
3235

3336
void launch(
3437
cudaStream_t stream,
@@ -52,6 +55,7 @@ struct KernelInstance {
5255
std::array<TypedExpr<uint32_t>, 3> block_size_ = {1, 1, 1};
5356
std::array<TypedExpr<uint32_t>, 3> grid_size_ = {0, 0, 0};
5457
TypedExpr<uint32_t> shared_mem_ = 0;
58+
std::vector<TypedExpr<bool>> assertions_;
5559
};
5660

5761
struct KernelBuilderSerializerHack;

include/kernel_launcher/compiler.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,9 @@ struct Compiler: ICompiler {
101101
*/
102102
template<typename C>
103103
Compiler(C&& compiler) :
104-
inner_(std::make_shared<typename std::decay<C>::type>(
105-
std::forward<C>(compiler))) {}
104+
inner_(
105+
std::make_shared<typename std::decay<C>::type>(
106+
std::forward<C>(compiler))) {}
106107

107108
void compile_ptx(
108109
KernelDef def,

include/kernel_launcher/config.h

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -203,18 +203,33 @@ struct ConfigSpace {
203203
*/
204204
void restriction(TypedExpr<bool> e);
205205

206+
/**
207+
* Returns the restrictions added by `restriction()`
208+
*/
209+
const std::vector<TypedExpr<bool>>& restrictions() const {
210+
return restrictions_;
211+
}
212+
206213
/**
207214
* Returns the default configuration for this configuration space.
208215
*/
209216
Config default_config() const;
210217

211218
/**
212-
* Check if the given configuration is a valid member of this configuration
213-
* space. This method essentially checks three things:
219+
* Check if the given configuration is a member of this configuration
220+
* space. This method essentially checks two things:
214221
*
215222
* * Does the configuration contain the correct parameters.
216223
* Do these parameters contain valid values.
217-
* * Does the configuration meet the restrictions.
224+
*
225+
* However, it does _not_ check if the configuration satisfies the
226+
* restrictions.
227+
*/
228+
bool contains(const Eval& config) const;
229+
230+
/**
231+
* Check if the given configuration is a valid member of this configuration
232+
* space and also meets the restrictions.
218233
*/
219234
bool is_valid(const Eval& config) const;
220235

include/kernel_launcher/utils.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#define KERNEL_LAUNCHER_UTILS_H
33

44
#include <cuda_runtime_api.h>
5+
#include <stdint.h>
56

67
#include <functional>
78
#include <iosfwd>
@@ -11,7 +12,6 @@
1112
#include <type_traits>
1213
#include <typeindex>
1314
#include <vector>
14-
#include <stdint.h>
1515

1616
namespace kernel_launcher {
1717

src/arg.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ KernelArg::KernelArg(KernelArg&& that) noexcept : KernelArg() {
3737

3838
KernelArg::~KernelArg() {
3939
if (is_scalar() && !is_inline_scalar(type_)) {
40-
delete[](char*) data_.large_scalar;
40+
delete[] (char*)data_.large_scalar;
4141
}
4242
}
4343

src/builder.cpp

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,18 @@ void KernelInstance::launch(
131131
}
132132
}
133133

134+
// We check for assertions now after printing the debug information. This
135+
// allows one to check the debugging output to see what were the arguments
136+
// provided to the kernel that caused the assertion to fail.
137+
for (const auto& assertion : assertions_) {
138+
if (!eval(assertion)) {
139+
std::stringstream ss;
140+
ss << "failed to launch kernel `" << module_.function_name()
141+
<< "`, assertion failed: `" << assertion.to_string() << "`";
142+
throw std::runtime_error(ss.str());
143+
}
144+
}
145+
134146
module_.launch(stream, grid_size, block_size, smem, ptrs.data());
135147
}
136148

@@ -168,6 +180,14 @@ KernelBuilder& KernelBuilder::buffer_size(ArgExpr arg, TypedExpr<size_t> len) {
168180
ArgsEval eval {args, fallback};
169181
size_t i = arg.get();
170182
size_t n = eval(len);
183+
184+
if (i >= args.size()) {
185+
throw std::runtime_error(
186+
"argument " + std::to_string(i)
187+
+ " is out of bounds for kernel that has only "
188+
+ std::to_string(args.size()) + " parameters");
189+
}
190+
171191
args[i] = args[i].to_array(n);
172192
});
173193
}
@@ -444,7 +464,34 @@ KernelInstance KernelBuilder::compile(
444464
const std::vector<TypeInfo>& param_types,
445465
const ICompiler& compiler,
446466
CudaContextHandle ctx) const {
467+
if (!contains(config)) {
468+
std::stringstream ss;
469+
ss << "invalid configuration: `" << config << "`";
470+
throw std::runtime_error(ss.str());
471+
}
472+
447473
DeviceAttrEval eval = {ctx.device(), config};
474+
std::vector<TypedExpr<bool>> assertions;
475+
476+
for (const auto& restriction : restrictions()) {
477+
auto r = restriction.resolve(eval);
478+
479+
if (!r.is_constant()) {
480+
// Any restriction that contains kernel arguments cannot be resolved
481+
// at this moment. We add these to the list of assertions
482+
// that will be checked each time the kernel gets launched.
483+
assertions.emplace_back(r);
484+
continue;
485+
}
486+
487+
if (!r.eval(eval)) {
488+
std::stringstream ss;
489+
ss << "configuration `" << config
490+
<< "` does not meet the following restriction: `"
491+
<< restriction.to_string() << "`";
492+
throw std::runtime_error(ss.str());
493+
}
494+
}
448495

449496
if (!is_valid(eval)) {
450497
std::stringstream ss;
@@ -469,7 +516,8 @@ KernelInstance KernelBuilder::compile(
469516
std::move(module),
470517
std::move(block_size),
471518
std::move(grid_size),
472-
shared_mem};
519+
std::move(shared_mem),
520+
std::move(assertions)};
473521
}
474522

475-
} // namespace kernel_launcher
523+
} // namespace kernel_launcher

src/config.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ Config ConfigSpace::default_config() const {
111111
return config;
112112
}
113113

114-
bool ConfigSpace::is_valid(const Eval& config) const {
114+
bool ConfigSpace::contains(const Eval& config) const {
115115
for (const auto& p : params_) {
116116
Value v;
117117

@@ -124,6 +124,14 @@ bool ConfigSpace::is_valid(const Eval& config) const {
124124
}
125125
}
126126

127+
return true;
128+
}
129+
130+
bool ConfigSpace::is_valid(const Eval& config) const {
131+
if (!contains(config)) {
132+
return false;
133+
}
134+
127135
for (const auto& r : restrictions_) {
128136
if (!config(r)) {
129137
return false;

src/internal/directives.cpp

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ static Expr parse_prim(TokenStream& stream, const Context& ctx) {
215215

216216
// NOLINTNEXTLINE(readability-function-cognitive-complexity)
217217
static Expr parse_binop(TokenStream& stream, const Context& ctx, int prec) {
218-
// TODO: == != <= >= && || %
218+
// TODO: %
219219
Expr lhs = parse_prim(stream, ctx);
220220

221221
while (true) {
@@ -416,6 +416,20 @@ static void parse_tuning_key_directive(
416416
builder.tuning_key(std::move(key));
417417
}
418418

419+
/**
 * Returns `true` if `value` is equal to at least one of the given `items`.
 *
 * Each item must be convertible to `const char*` (for example, a string
 * literal). Items are taken by const reference so that both lvalues and
 * temporaries are accepted. An empty list of items yields `false`.
 */
template<typename... Ts>
static bool one_of(const std::string& value, const Ts&... items) {
    // An initializer list, unlike a raw array, is allowed to be empty.
    const std::initializer_list<const char*> candidates = {items...};

    for (const char* candidate : candidates) {
        if (value == candidate) {
            return true;
        }
    }

    return false;
}
432+
419433
static void
420434
process_directive(TokenStream& stream, KernelBuilder& builder, Context& ctx) {
421435
while (!stream.next_if(TokenKind::DirectiveEnd)) {
@@ -424,25 +438,28 @@ process_directive(TokenStream& stream, KernelBuilder& builder, Context& ctx) {
424438

425439
if (name == "tune") {
426440
parse_tune_directive(stream, builder, ctx);
427-
} else if (name == "set") {
441+
} else if (one_of(name, "set", "let")) {
428442
parse_set_directive(stream, ctx);
429-
} else if (name == "buffers" || name == "buffer") {
443+
} else if (one_of(name, "buffer", "buffers")) {
430444
parse_buffer_directive(stream, builder, ctx);
431445
} else if (name == "tuning_key") {
432446
parse_tuning_key_directive(stream, builder, ctx);
433-
} else if (name == "grid_size" || name == "grid_dim") {
447+
} else if (one_of(name, "grid_size", "grid_dim")) {
434448
auto l = parse_expr_list3(stream, ctx);
435449
builder.grid_size(l[0], l[1], l[2]);
436-
} else if (name == "block_size" || name == "block_dim") {
450+
} else if (one_of(name, "block_size", "block_dim")) {
437451
auto l = parse_expr_list3(stream, ctx);
438452
builder.block_size(l[0], l[1], l[2]);
439-
} else if (name == "grid_divisor" || name == "grid_divisors") {
453+
} else if (one_of(name, "grid_divisor", "grid_divisors")) {
440454
auto l = parse_expr_list3(stream, ctx);
441455
builder.grid_divisors(l[0], l[1], l[2]);
442-
} else if (name == "problem_size" || name == "problem_dim") {
456+
} else if (one_of(name, "problem_size", "problem_dim")) {
443457
auto l = parse_expr_list3(stream, ctx);
444458
builder.problem_size(l[0], l[1], l[2]);
445-
} else if (name == "restriction" || name == "restrictions") {
459+
} else if (
460+
name == "if"
461+
|| one_of(name, "restriction", "restrictions", "restrict")
462+
|| one_of(name, "assertion", "assertions", "assert")) {
446463
for (const auto& expr : parse_expr_list(stream, ctx)) {
447464
builder.restriction(expr);
448465
}

src/internal/parser.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,18 @@ static std::vector<FunctionParam> parse_kernel_params(TokenStream& stream) {
4646
Token before_name = begin;
4747
Token name = stream.next();
4848
Token end = stream.peek();
49+
int template_depth = 0;
50+
51+
while (template_depth > 0
52+
|| !(
53+
end.kind == TokenKind::Comma
54+
|| end.kind == TokenKind::ParenR)) {
55+
if (name.kind == TokenKind::AngleL) {
56+
template_depth++;
57+
} else if (name.kind == TokenKind::AngleR && template_depth > 0) {
58+
template_depth--;
59+
}
4960

50-
while (end.kind != TokenKind::Comma && end.kind != TokenKind::ParenR) {
5161
before_name = name;
5262
name = stream.next();
5363
end = stream.peek();

src/internal/tokens.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,9 @@ static index_t advance_string(index_t i, const std::string& input) {
8787
TokenKind char2_to_kind(char a, char b) {
8888
if ((a == '=' && b == '=') || (a == '!' && b == '=')
8989
|| (a == '<' && b == '=') || (a == '>' && b == '=')
90-
|| (a == '&' && b == '&') || (a == '|' && b == '|')
91-
|| (a == '<' && b == '<') || (a == '>' && b == '>')
90+
|| (a == '&' && b == '&')
91+
|| (a == '|' && b == '|')
92+
//|| (a == '<' && b == '<') || (a == '>' && b == '>')
9293
|| (a == ':' && b == ':')) {
9394
return TokenKind::Punct;
9495
}

0 commit comments

Comments
 (0)