@@ -215,7 +215,7 @@ static Expr parse_prim(TokenStream& stream, const Context& ctx) {
215215
216216// NOLINTNEXTLINE(readability-function-cognitive-complexity)
217217static Expr parse_binop (TokenStream& stream, const Context& ctx, int prec) {
218- // TODO: == != <= >= && || %
218+ // TODO: %
219219 Expr lhs = parse_prim (stream, ctx);
220220
221221 while (true ) {
@@ -416,6 +416,20 @@ static void parse_tuning_key_directive(
416416 builder.tuning_key (std::move (key));
417417}
418418
419+ template <typename ... Ts>
420+ static bool one_of (const std::string& value, Ts&... items) {
421+ static constexpr size_t N = sizeof ...(Ts);
422+ const char * array[N] = {items...};
423+
424+ for (size_t i = 0 ; i < N; i++) {
425+ if (value == array[i]) {
426+ return true ;
427+ }
428+ }
429+
430+ return false ;
431+ }
432+
419433static void
420434process_directive (TokenStream& stream, KernelBuilder& builder, Context& ctx) {
421435 while (!stream.next_if (TokenKind::DirectiveEnd)) {
@@ -424,25 +438,28 @@ process_directive(TokenStream& stream, KernelBuilder& builder, Context& ctx) {
424438
425439 if (name == " tune" ) {
426440 parse_tune_directive (stream, builder, ctx);
427- } else if (name == " set" ) {
441+ } else if (one_of ( name, " set" , " let " ) ) {
428442 parse_set_directive (stream, ctx);
429- } else if (name == " buffers " || name == " buffer " ) {
443+ } else if (one_of ( name, " buffer " , " buffers " ) ) {
430444 parse_buffer_directive (stream, builder, ctx);
431445 } else if (name == " tuning_key" ) {
432446 parse_tuning_key_directive (stream, builder, ctx);
433- } else if (name == " grid_size" || name == " grid_dim" ) {
447+ } else if (one_of ( name, " grid_size" , " grid_dim" ) ) {
434448 auto l = parse_expr_list3 (stream, ctx);
435449 builder.grid_size (l[0 ], l[1 ], l[2 ]);
436- } else if (name == " block_size" || name == " block_dim" ) {
450+ } else if (one_of ( name, " block_size" , " block_dim" ) ) {
437451 auto l = parse_expr_list3 (stream, ctx);
438452 builder.block_size (l[0 ], l[1 ], l[2 ]);
439- } else if (name == " grid_divisor" || name == " grid_divisors" ) {
453+ } else if (one_of ( name, " grid_divisor" , " grid_divisors" ) ) {
440454 auto l = parse_expr_list3 (stream, ctx);
441455 builder.grid_divisors (l[0 ], l[1 ], l[2 ]);
442- } else if (name == " problem_size" || name == " problem_dim" ) {
456+ } else if (one_of ( name, " problem_size" , " problem_dim" ) ) {
443457 auto l = parse_expr_list3 (stream, ctx);
444458 builder.problem_size (l[0 ], l[1 ], l[2 ]);
445- } else if (name == " restriction" || name == " restrictions" ) {
459+ } else if (
460+ name == " if"
461+ || one_of (name, " restriction" , " restrictions" , " restrict" )
462+ || one_of (name, " assertion" , " assertions" , " assert" )) {
446463 for (const auto & expr : parse_expr_list (stream, ctx)) {
447464 builder.restriction (expr);
448465 }
0 commit comments