Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,7 @@ struct ProblemSize {
std::vector<int64_t> stride;
std::vector<int64_t> dilation;
bool operator==(const ProblemSize& ps) const {
return activation_shape[1] == ps.activation_shape[1] &&
activation_shape[2] == ps.activation_shape[2] &&
activation_shape[3] == ps.activation_shape[3] &&
return activation_shape == ps.activation_shape &&
filter_shape == ps.filter_shape;
}
void print() const {
Expand Down Expand Up @@ -68,54 +66,137 @@ inline void hash_combine(std::size_t& seed, std::size_t value) {
struct ProblemSizeHash {
std::size_t operator()(const ProblemSize& ps) const {
std::size_t seed = 0;
// Only hash spatial dimensions (D, H, W) from activation_shape, not batch
// (N) or channels (C)
hash_combine(seed, std::hash<int64_t>{}(ps.activation_shape[1]));
hash_combine(seed, std::hash<int64_t>{}(ps.activation_shape[2]));
hash_combine(seed, std::hash<int64_t>{}(ps.activation_shape[3]));
// Hash the entire filter_shape
auto vec_hash = [](const std::vector<int64_t>& v) {
std::size_t h = 0;
for (auto x : v)
hash_combine(h, std::hash<int64_t>{}(x));
return h;
};
hash_combine(seed, vec_hash(ps.activation_shape));
hash_combine(seed, vec_hash(ps.filter_shape));
// Exclude padding, stride, and dilation from hash
// hash_combine(seed, vec_hash(ps.padding));
// hash_combine(seed, vec_hash(ps.stride));
// hash_combine(seed, vec_hash(ps.dilation));
return seed;
}
};

// clang-format off
// Tuned on GB200
std::unordered_map<ProblemSize, Kernel_f8f8bf16_conv, ProblemSizeHash> kernel_map = {
{{{1,1,192,128,1024}, {512,1,1,1,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_128x256x128_2x1x1},
{{{1,1,192,128,160}, {320,1,1,1,160}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1},
{{{1,1,192,128,160}, {320,1,1,1,160}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_128x128x128_1x1x1},
{{{1,1,384,256,512}, {256,1,1,1,512}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1},
{{{1,1,96,64,320}, {640,1,1,1,320}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_128x256x128_1x2x1},
{{{1,1,96,64,320}, {640,1,1,1,320}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_128x128x128_1x1x1},
{{{1,3,194,130,1024}, {512,3,3,3,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1},
{{{1,3,194,130,160}, {320,3,3,3,160}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1},
{{{1,3,194,130,160}, {320,3,3,3,160}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x128x128_4x1x1},
{{{1,3,194,130,320}, {320,3,3,3,320}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x128x128_4x1x1},
{{{1,3,194,130,512}, {512,3,3,3,512}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1},
{{{1,3,386,258,160}, {160,3,3,3,160}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1},
{{{1,3,386,258,256}, {256,3,3,3,256}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1},
{{{1,3,386,258,512}, {256,3,3,3,512}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1},
{{{1,3,48,32,1024}, {2048,3,1,1,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x128x128_4x1x1},
{{{1,3,194,130,512}, {512,3,3,3,512}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x512x128_2x2x1},
{{{1,3,386,258,160}, {160,3,3,3,160}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_4x1x1},
{{{1,3,386,258,256}, {256,3,3,3,256}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1},
{{{1,3,386,258,512}, {256,3,3,3,512}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_4x1x1},
{{{1,3,48,32,1024}, {2048,3,1,1,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_128x128x128_2x1x1},
{{{1,3,50,34,1024}, {1024,3,3,3,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_128x256x128_2x1x1},
{{{1,3,50,34,48}, {1024,3,3,3,48}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_4x1x1},
{{{1,3,50,34,48}, {1024,3,3,3,48}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x1024x128_4x4x1},
{{{1,3,50,34,640}, {640,3,3,3,640}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x128x128_4x1x1},
{{{1,3,50,34,640}, {96,3,3,3,640}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_4x2x1},
{{{1,3,50,34,640}, {96,3,3,3,640}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_128x128x128_2x1x1},
{{{1,3,98,66,1024}, {1024,3,3,3,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_4x1x1},
{{{1,3,98,66,1024}, {1024,3,3,3,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_4x1x1},
{{{1,3,98,66,320}, {640,3,3,3,320}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1},
{{{1,3,98,66,640}, {640,3,3,3,640}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_128x256x128_2x1x1},
{{{1,3,98,66,320}, {640,3,3,3,320}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1},
{{{1,3,98,66,640}, {640,3,3,3,640}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1},
{{{1,4,192,128,1024}, {512,1,1,1,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_128x256x128_2x1x1},
{{{1,4,384,256,512}, {256,1,1,1,512}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1},
{{{1,4,96,64,1024}, {2048,3,1,1,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1},
{{{1,4,98,66,1024}, {1024,3,3,3,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1},
{{{1,6,194,130,1024}, {512,3,3,3,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1},
{{{1,6,194,130,512}, {512,3,3,3,512}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1},
{{{1,6,386,258,256}, {256,3,3,3,256}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1},
{{{1,6,386,258,512}, {256,3,3,3,512}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1},
{{{1,6,194,130,1024}, {512,3,3,3,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x512x128_2x2x1},
{{{1,6,194,130,512}, {512,3,3,3,512}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x512x128_2x2x1},
{{{1,6,386,258,256}, {256,3,3,3,256}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1},
{{{1,6,386,258,512}, {256,3,3,3,512}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1},
{{{2,1,192,128,1024}, {512,1,1,1,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_128x256x128_2x1x1},
{{{2,1,192,128,160}, {320,1,1,1,160}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1},
{{{2,1,384,256,512}, {256,1,1,1,512}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1},
{{{2,1,96,64,320}, {640,1,1,1,320}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1},
{{{2,3,194,130,1024}, {512,3,3,3,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x512x128_2x2x1},
{{{2,3,194,130,160}, {320,3,3,3,160}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_4x1x1},
{{{2,3,194,130,320}, {320,3,3,3,320}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x128x128_4x1x1},
{{{2,3,194,130,512}, {512,3,3,3,512}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_4x1x1},
{{{2,3,386,258,160}, {160,3,3,3,160}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1},
{{{2,3,386,258,256}, {256,3,3,3,256}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1},
{{{2,3,386,258,512}, {256,3,3,3,512}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1},
{{{2,3,48,32,1024}, {2048,3,1,1,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_4x1x1},
{{{2,3,50,34,1024}, {1024,3,3,3,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_128x256x128_2x2x1},
{{{2,3,50,34,48}, {1024,3,3,3,48}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x1024x128_4x4x1},
{{{2,3,50,34,640}, {640,3,3,3,640}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_128x256x128_2x1x1},
{{{2,3,50,34,640}, {96,3,3,3,640}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x128x128_4x1x1},
{{{2,3,98,66,1024}, {1024,3,3,3,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1},
{{{2,3,98,66,1024}, {1024,3,3,3,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1},
{{{2,3,98,66,320}, {640,3,3,3,320}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1},
{{{2,3,98,66,640}, {640,3,3,3,640}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_4x1x1},
{{{2,4,192,128,1024}, {512,1,1,1,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1},
{{{2,4,384,256,512}, {256,1,1,1,512}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1},
{{{2,4,96,64,1024}, {2048,3,1,1,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x512x128_2x2x1},
{{{2,4,98,66,1024}, {1024,3,3,3,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1},
{{{2,6,194,130,1024}, {512,3,3,3,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1},
{{{2,6,194,130,512}, {512,3,3,3,512}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1},
{{{2,6,386,258,256}, {256,3,3,3,256}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1},
{{{2,6,386,258,512}, {256,3,3,3,512}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1},
{{{4,1,192,128,1024}, {512,1,1,1,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1},
{{{4,1,192,128,160}, {320,1,1,1,160}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1},
{{{4,1,384,256,512}, {256,1,1,1,512}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1},
{{{4,1,96,64,320}, {640,1,1,1,320}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1},
{{{4,3,194,130,1024}, {512,3,3,3,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1},
{{{4,3,194,130,160}, {320,3,3,3,160}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1},
{{{4,3,194,130,320}, {320,3,3,3,320}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1},
{{{4,3,194,130,512}, {512,3,3,3,512}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x512x128_2x2x1},
{{{4,3,386,258,160}, {160,3,3,3,160}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1},
{{{4,3,386,258,256}, {256,3,3,3,256}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1},
{{{4,3,386,258,512}, {256,3,3,3,512}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1},
{{{4,3,48,32,1024}, {2048,3,1,1,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1},
{{{4,3,50,34,1024}, {1024,3,3,3,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_4x1x1},
{{{4,3,50,34,48}, {1024,3,3,3,48}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x512x128_4x2x1},
{{{4,3,50,34,640}, {640,3,3,3,640}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1},
{{{4,3,50,34,640}, {96,3,3,3,640}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_128x128x128_2x1x1},
{{{4,3,98,66,1024}, {1024,3,3,3,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1},
{{{4,3,98,66,1024}, {1024,3,3,3,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x512x128_2x2x1},
{{{4,3,98,66,320}, {640,3,3,3,320}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1},
{{{4,3,98,66,640}, {640,3,3,3,640}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1},
{{{4,4,192,128,1024}, {512,1,1,1,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1},
{{{4,4,384,256,512}, {256,1,1,1,512}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1},
{{{4,4,96,64,1024}, {2048,3,1,1,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1},
{{{4,4,98,66,1024}, {1024,3,3,3,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1},
{{{4,6,194,130,1024}, {512,3,3,3,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x512x128_2x2x1},
{{{4,6,194,130,512}, {512,3,3,3,512}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1},
{{{4,6,386,258,256}, {256,3,3,3,256}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1},
{{{4,6,386,258,512}, {256,3,3,3,512}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1},
{{{8,1,192,128,1024}, {512,1,1,1,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1},
{{{8,1,192,128,160}, {320,1,1,1,160}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1},
{{{8,1,384,256,512}, {256,1,1,1,512}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1},
{{{8,1,96,64,320}, {640,1,1,1,320}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1},
{{{8,3,194,130,1024}, {512,3,3,3,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1},
{{{8,3,194,130,160}, {320,3,3,3,160}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1},
{{{8,3,194,130,320}, {320,3,3,3,320}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1},
{{{8,3,194,130,512}, {512,3,3,3,512}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1},
{{{8,3,386,258,160}, {160,3,3,3,160}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1},
{{{8,3,386,258,256}, {256,3,3,3,256}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1},
{{{8,3,386,258,512}, {256,3,3,3,512}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1},
{{{8,3,48,32,1024}, {2048,3,1,1,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1},
{{{8,3,50,34,1024}, {1024,3,3,3,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1},
{{{8,3,50,34,48}, {1024,3,3,3,48}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x512x128_4x2x1},
{{{8,3,50,34,640}, {640,3,3,3,640}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1},
{{{8,3,50,34,640}, {96,3,3,3,640}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_128x128x128_1x1x1},
{{{8,3,98,66,1024}, {1024,3,3,3,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1},
{{{8,3,98,66,1024}, {1024,3,3,3,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1},
{{{8,3,98,66,320}, {640,3,3,3,320}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1},
{{{8,3,98,66,640}, {640,3,3,3,640}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1},
{{{8,4,192,128,1024}, {512,1,1,1,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1},
{{{8,4,384,256,512}, {256,1,1,1,512}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1},
{{{8,4,96,64,1024}, {2048,3,1,1,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1},
{{{8,4,98,66,1024}, {1024,3,3,3,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1},
{{{8,6,194,130,1024}, {512,3,3,3,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1},
{{{8,6,194,130,512}, {512,3,3,3,512}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1},
{{{8,6,386,258,256}, {256,3,3,3,256}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1},
{{{8,6,386,258,512}, {256,3,3,3,512}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1},
};
// clang-format on

Expand All @@ -132,14 +213,46 @@ Kernel_f8f8bf16_conv get_kernel_via_heuristic(
padding,
stride,
dilation};

// Try exact match first
auto it = kernel_map.find(ps);
if (it != kernel_map.end()) {
return it->second;
} else {
std::cout << "warning: not found - ";
ps.print();
std::cout << std::endl;
}

#if 1
// If no exact match, look for configs with same spatial dims and filter
// but use the one with the largest batch that is <= current batch
int64_t current_batch = ps.activation_shape[0];
Kernel_f8f8bf16_conv best_kernel = nullptr;
int64_t best_batch = 0;

for (const auto& [candidate_ps, kernel] : kernel_map) {
// Check if spatial dimensions and filter match
if (candidate_ps.activation_shape[1] == ps.activation_shape[1] &&
candidate_ps.activation_shape[2] == ps.activation_shape[2] &&
candidate_ps.activation_shape[3] == ps.activation_shape[3] &&
candidate_ps.filter_shape == ps.filter_shape) {
int64_t candidate_batch = candidate_ps.activation_shape[0];

// Use config with largest batch that is <= current batch
if (candidate_batch <= current_batch && candidate_batch > best_batch) {
best_batch = candidate_batch;
best_kernel = kernel;
}
}
}

if (best_kernel != nullptr) {
return best_kernel;
}
#endif

// No suitable config found
std::cout << "warning: not found - ";
ps.print();
std::cout << std::endl;

// Fallback kernel
return f8f8bf16_conv_256x256x128_2x1x1;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include "f8f8bf16_conv_common.cuh"

namespace fbgemm_gpu {

at::Tensor f8f8bf16_conv_128x128x128_2x1x1(
at::Tensor activation, // FP8 - NDHWC layout
at::Tensor filter, // FP8 - KTRSC layout
at::Tensor scale,
std::vector<int64_t> padding, // [pad_d, pad_h, pad_w]
std::vector<int64_t> stride, // [stride_d, stride_h, stride_w]
std::vector<int64_t> dilation) { // [dilation_d, dilation_h, dilation_w]

return f8f8bf16_conv_impl<
128,
128,
128,
2,
1,
1,
cutlass::conv::KernelImplicitTmaWarpSpecialized2SmSm100>(
activation, filter, scale, padding, stride, dilation);
}

} // namespace fbgemm_gpu
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include "f8f8bf16_conv_common.cuh"

namespace fbgemm_gpu {

at::Tensor f8f8bf16_conv_256x1024x128_4x4x1(
at::Tensor activation, // FP8 - NDHWC layout
at::Tensor filter, // FP8 - KTRSC layout
at::Tensor scale,
std::vector<int64_t> padding, // [pad_d, pad_h, pad_w]
std::vector<int64_t> stride, // [stride_d, stride_h, stride_w]
std::vector<int64_t> dilation) { // [dilation_d, dilation_h, dilation_w]

return f8f8bf16_conv_impl<
128,
256,
128,
4,
4,
1,
cutlass::conv::KernelImplicitTmaWarpSpecialized2SmSm100>(
activation, filter, scale, padding, stride, dilation);
}

} // namespace fbgemm_gpu
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include "f8f8bf16_conv_common.cuh"

namespace fbgemm_gpu {

at::Tensor f8f8bf16_conv_512x1024x128_4x4x1(
at::Tensor activation, // FP8 - NDHWC layout
at::Tensor filter, // FP8 - KTRSC layout
at::Tensor scale,
std::vector<int64_t> padding, // [pad_d, pad_h, pad_w]
std::vector<int64_t> stride, // [stride_d, stride_h, stride_w]
std::vector<int64_t> dilation) { // [dilation_d, dilation_h, dilation_w]

return f8f8bf16_conv_impl<
256,
256,
128,
4,
4,
1,
cutlass::conv::KernelImplicitTmaWarpSpecialized2SmSm100>(
activation, filter, scale, padding, stride, dilation);
}

} // namespace fbgemm_gpu
Loading
Loading