diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_conv.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_conv.cu index e8ff23e1e4..856b686d15 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_conv.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_conv.cu @@ -37,9 +37,7 @@ struct ProblemSize { std::vector stride; std::vector dilation; bool operator==(const ProblemSize& ps) const { - return activation_shape[1] == ps.activation_shape[1] && - activation_shape[2] == ps.activation_shape[2] && - activation_shape[3] == ps.activation_shape[3] && + return activation_shape == ps.activation_shape && filter_shape == ps.filter_shape; } void print() const { @@ -68,11 +66,6 @@ inline void hash_combine(std::size_t& seed, std::size_t value) { struct ProblemSizeHash { std::size_t operator()(const ProblemSize& ps) const { std::size_t seed = 0; - // Only hash spatial dimensions (D, H, W) from activation_shape, not batch - // (N) or channels (C) - hash_combine(seed, std::hash{}(ps.activation_shape[1])); - hash_combine(seed, std::hash{}(ps.activation_shape[2])); - hash_combine(seed, std::hash{}(ps.activation_shape[3])); // Hash the entire filter_shape auto vec_hash = [](const std::vector& v) { std::size_t h = 0; @@ -80,42 +73,130 @@ struct ProblemSizeHash { hash_combine(h, std::hash{}(x)); return h; }; + hash_combine(seed, vec_hash(ps.activation_shape)); hash_combine(seed, vec_hash(ps.filter_shape)); - // Exclude padding, stride, and dilation from hash + // hash_combine(seed, vec_hash(ps.padding)); + // hash_combine(seed, vec_hash(ps.stride)); + // hash_combine(seed, vec_hash(ps.dilation)); return seed; } }; // clang-format off +// Tuned on GB200 std::unordered_map kernel_map = { {{{1,1,192,128,1024}, {512,1,1,1,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_128x256x128_2x1x1}, -{{{1,1,192,128,160}, {320,1,1,1,160}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1}, +{{{1,1,192,128,160}, {320,1,1,1,160}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_128x128x128_1x1x1}, {{{1,1,384,256,512}, {256,1,1,1,512}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1}, -{{{1,1,96,64,320}, {640,1,1,1,320}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_128x256x128_1x2x1}, +{{{1,1,96,64,320}, {640,1,1,1,320}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_128x128x128_1x1x1}, {{{1,3,194,130,1024}, {512,3,3,3,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1}, -{{{1,3,194,130,160}, {320,3,3,3,160}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1}, +{{{1,3,194,130,160}, {320,3,3,3,160}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x128x128_4x1x1}, {{{1,3,194,130,320}, {320,3,3,3,320}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x128x128_4x1x1}, -{{{1,3,194,130,512}, {512,3,3,3,512}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1}, -{{{1,3,386,258,160}, {160,3,3,3,160}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1}, -{{{1,3,386,258,256}, {256,3,3,3,256}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1}, -{{{1,3,386,258,512}, {256,3,3,3,512}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1}, -{{{1,3,48,32,1024}, {2048,3,1,1,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x128x128_4x1x1}, +{{{1,3,194,130,512}, {512,3,3,3,512}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x512x128_2x2x1}, +{{{1,3,386,258,160}, {160,3,3,3,160}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_4x1x1}, +{{{1,3,386,258,256}, {256,3,3,3,256}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1}, +{{{1,3,386,258,512}, {256,3,3,3,512}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_4x1x1}, +{{{1,3,48,32,1024}, {2048,3,1,1,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_128x128x128_2x1x1}, {{{1,3,50,34,1024}, {1024,3,3,3,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_128x256x128_2x1x1}, -{{{1,3,50,34,48}, {1024,3,3,3,48}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_4x1x1}, +{{{1,3,50,34,48}, {1024,3,3,3,48}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x1024x128_4x4x1}, {{{1,3,50,34,640}, {640,3,3,3,640}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x128x128_4x1x1}, -{{{1,3,50,34,640}, {96,3,3,3,640}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_4x2x1}, +{{{1,3,50,34,640}, {96,3,3,3,640}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_128x128x128_2x1x1}, {{{1,3,98,66,1024}, {1024,3,3,3,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_4x1x1}, {{{1,3,98,66,1024}, {1024,3,3,3,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_4x1x1}, -{{{1,3,98,66,320}, {640,3,3,3,320}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1}, -{{{1,3,98,66,640}, {640,3,3,3,640}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_128x256x128_2x1x1}, +{{{1,3,98,66,320}, {640,3,3,3,320}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1}, +{{{1,3,98,66,640}, {640,3,3,3,640}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1}, {{{1,4,192,128,1024}, {512,1,1,1,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_128x256x128_2x1x1}, {{{1,4,384,256,512}, {256,1,1,1,512}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1}, {{{1,4,96,64,1024}, {2048,3,1,1,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1}, {{{1,4,98,66,1024}, {1024,3,3,3,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1}, -{{{1,6,194,130,1024}, {512,3,3,3,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1}, -{{{1,6,194,130,512}, {512,3,3,3,512}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1}, -{{{1,6,386,258,256}, {256,3,3,3,256}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1}, -{{{1,6,386,258,512}, {256,3,3,3,512}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1}, +{{{1,6,194,130,1024}, {512,3,3,3,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x512x128_2x2x1}, +{{{1,6,194,130,512}, {512,3,3,3,512}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x512x128_2x2x1}, +{{{1,6,386,258,256}, {256,3,3,3,256}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1}, +{{{1,6,386,258,512}, {256,3,3,3,512}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1}, +{{{2,1,192,128,1024}, {512,1,1,1,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_128x256x128_2x1x1}, +{{{2,1,192,128,160}, {320,1,1,1,160}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1}, +{{{2,1,384,256,512}, {256,1,1,1,512}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1}, +{{{2,1,96,64,320}, {640,1,1,1,320}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1}, +{{{2,3,194,130,1024}, {512,3,3,3,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x512x128_2x2x1}, +{{{2,3,194,130,160}, {320,3,3,3,160}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_4x1x1}, +{{{2,3,194,130,320}, {320,3,3,3,320}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x128x128_4x1x1}, +{{{2,3,194,130,512}, {512,3,3,3,512}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_4x1x1}, +{{{2,3,386,258,160}, {160,3,3,3,160}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1}, +{{{2,3,386,258,256}, {256,3,3,3,256}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1}, +{{{2,3,386,258,512}, {256,3,3,3,512}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1}, +{{{2,3,48,32,1024}, {2048,3,1,1,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_4x1x1}, +{{{2,3,50,34,1024}, {1024,3,3,3,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_128x256x128_2x2x1}, +{{{2,3,50,34,48}, {1024,3,3,3,48}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x1024x128_4x4x1}, +{{{2,3,50,34,640}, {640,3,3,3,640}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_128x256x128_2x1x1}, +{{{2,3,50,34,640}, {96,3,3,3,640}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x128x128_4x1x1}, +{{{2,3,98,66,1024}, {1024,3,3,3,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1}, +{{{2,3,98,66,1024}, {1024,3,3,3,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1}, +{{{2,3,98,66,320}, {640,3,3,3,320}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1}, +{{{2,3,98,66,640}, {640,3,3,3,640}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_4x1x1}, +{{{2,4,192,128,1024}, {512,1,1,1,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1}, +{{{2,4,384,256,512}, {256,1,1,1,512}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1}, +{{{2,4,96,64,1024}, {2048,3,1,1,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x512x128_2x2x1}, +{{{2,4,98,66,1024}, {1024,3,3,3,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1}, +{{{2,6,194,130,1024}, {512,3,3,3,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1}, +{{{2,6,194,130,512}, {512,3,3,3,512}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1}, +{{{2,6,386,258,256}, {256,3,3,3,256}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1}, +{{{2,6,386,258,512}, {256,3,3,3,512}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1}, +{{{4,1,192,128,1024}, {512,1,1,1,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1}, +{{{4,1,192,128,160}, {320,1,1,1,160}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1}, +{{{4,1,384,256,512}, {256,1,1,1,512}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1}, +{{{4,1,96,64,320}, {640,1,1,1,320}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1}, +{{{4,3,194,130,1024}, {512,3,3,3,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1}, +{{{4,3,194,130,160}, {320,3,3,3,160}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1}, +{{{4,3,194,130,320}, {320,3,3,3,320}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1}, +{{{4,3,194,130,512}, {512,3,3,3,512}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x512x128_2x2x1}, +{{{4,3,386,258,160}, {160,3,3,3,160}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1}, +{{{4,3,386,258,256}, {256,3,3,3,256}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1}, +{{{4,3,386,258,512}, {256,3,3,3,512}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1}, +{{{4,3,48,32,1024}, {2048,3,1,1,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1}, +{{{4,3,50,34,1024}, {1024,3,3,3,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_4x1x1}, +{{{4,3,50,34,48}, {1024,3,3,3,48}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x512x128_4x2x1}, +{{{4,3,50,34,640}, {640,3,3,3,640}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1}, +{{{4,3,50,34,640}, {96,3,3,3,640}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_128x128x128_2x1x1}, +{{{4,3,98,66,1024}, {1024,3,3,3,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1}, +{{{4,3,98,66,1024}, {1024,3,3,3,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x512x128_2x2x1}, +{{{4,3,98,66,320}, {640,3,3,3,320}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1}, +{{{4,3,98,66,640}, {640,3,3,3,640}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1}, +{{{4,4,192,128,1024}, {512,1,1,1,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1}, +{{{4,4,384,256,512}, {256,1,1,1,512}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1}, +{{{4,4,96,64,1024}, {2048,3,1,1,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1}, +{{{4,4,98,66,1024}, {1024,3,3,3,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1}, +{{{4,6,194,130,1024}, {512,3,3,3,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x512x128_2x2x1}, +{{{4,6,194,130,512}, {512,3,3,3,512}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1}, +{{{4,6,386,258,256}, {256,3,3,3,256}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1}, +{{{4,6,386,258,512}, {256,3,3,3,512}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1}, +{{{8,1,192,128,1024}, {512,1,1,1,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1}, +{{{8,1,192,128,160}, {320,1,1,1,160}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1}, +{{{8,1,384,256,512}, {256,1,1,1,512}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1}, +{{{8,1,96,64,320}, {640,1,1,1,320}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1}, +{{{8,3,194,130,1024}, {512,3,3,3,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1}, +{{{8,3,194,130,160}, {320,3,3,3,160}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1}, +{{{8,3,194,130,320}, {320,3,3,3,320}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1}, +{{{8,3,194,130,512}, {512,3,3,3,512}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1}, +{{{8,3,386,258,160}, {160,3,3,3,160}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1}, +{{{8,3,386,258,256}, {256,3,3,3,256}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1}, +{{{8,3,386,258,512}, {256,3,3,3,512}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1}, +{{{8,3,48,32,1024}, {2048,3,1,1,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1}, +{{{8,3,50,34,1024}, {1024,3,3,3,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1}, +{{{8,3,50,34,48}, {1024,3,3,3,48}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x512x128_4x2x1}, +{{{8,3,50,34,640}, {640,3,3,3,640}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1}, +{{{8,3,50,34,640}, {96,3,3,3,640}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_128x128x128_1x1x1}, +{{{8,3,98,66,1024}, {1024,3,3,3,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1}, +{{{8,3,98,66,1024}, {1024,3,3,3,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1}, +{{{8,3,98,66,320}, {640,3,3,3,320}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1}, +{{{8,3,98,66,640}, {640,3,3,3,640}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1}, +{{{8,4,192,128,1024}, {512,1,1,1,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1}, +{{{8,4,384,256,512}, {256,1,1,1,512}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1}, +{{{8,4,96,64,1024}, {2048,3,1,1,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1}, +{{{8,4,98,66,1024}, {1024,3,3,3,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1}, +{{{8,6,194,130,1024}, {512,3,3,3,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1}, +{{{8,6,194,130,512}, {512,3,3,3,512}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1}, +{{{8,6,386,258,256}, {256,3,3,3,256}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1}, +{{{8,6,386,258,512}, {256,3,3,3,512}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1}, }; // clang-format on @@ -132,14 +213,46 @@ Kernel_f8f8bf16_conv get_kernel_via_heuristic( padding, stride, dilation}; + + // Try exact match first auto it = kernel_map.find(ps); if (it != kernel_map.end()) { return it->second; - } else { - std::cout << "warning: not found - "; - ps.print(); - std::cout << std::endl; } + +#if 1 + // If no exact match, look for configs with same spatial dims and filter + // but use the one with the largest batch that is <= current batch + int64_t current_batch = ps.activation_shape[0]; + Kernel_f8f8bf16_conv best_kernel = nullptr; + int64_t best_batch = 0; + + for (const auto& [candidate_ps, kernel] : kernel_map) { + // Check if spatial dimensions and filter match + if (candidate_ps.activation_shape[1] == ps.activation_shape[1] && + candidate_ps.activation_shape[2] == ps.activation_shape[2] && + candidate_ps.activation_shape[3] == ps.activation_shape[3] && + candidate_ps.filter_shape == ps.filter_shape) { + int64_t candidate_batch = candidate_ps.activation_shape[0]; + + // Use config with largest batch that is <= current batch + if (candidate_batch <= current_batch && candidate_batch > best_batch) { + best_batch = candidate_batch; + best_kernel = kernel; + } + } + } + + if (best_kernel != nullptr) { + return best_kernel; + } +#endif + + // No suitable config found + std::cout << "warning: not found - "; + ps.print(); + std::cout << std::endl; + // Fallback kernel return f8f8bf16_conv_256x256x128_2x1x1; } diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_conv/f8f8bf16_conv_128x128x128_2x1x1.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_conv/f8f8bf16_conv_128x128x128_2x1x1.cu new file mode 100644 index 0000000000..1fa39f8221 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_conv/f8f8bf16_conv_128x128x128_2x1x1.cu @@ -0,0 +1,32 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "f8f8bf16_conv_common.cuh" + +namespace fbgemm_gpu { + +at::Tensor f8f8bf16_conv_128x128x128_2x1x1( + at::Tensor activation, // FP8 - NDHWC layout + at::Tensor filter, // FP8 - KTRSC layout + at::Tensor scale, + std::vector padding, // [pad_d, pad_h, pad_w] + std::vector stride, // [stride_d, stride_h, stride_w] + std::vector dilation) { // [dilation_d, dilation_h, dilation_w] + + return f8f8bf16_conv_impl< + 128, + 128, + 128, + 2, + 1, + 1, + cutlass::conv::KernelImplicitTmaWarpSpecialized2SmSm100>( + activation, filter, scale, padding, stride, dilation); +} + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_conv/f8f8bf16_conv_256x1024x128_4x4x1.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_conv/f8f8bf16_conv_256x1024x128_4x4x1.cu new file mode 100644 index 0000000000..0df15235a6 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_conv/f8f8bf16_conv_256x1024x128_4x4x1.cu @@ -0,0 +1,32 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "f8f8bf16_conv_common.cuh" + +namespace fbgemm_gpu { + +at::Tensor f8f8bf16_conv_256x1024x128_4x4x1( + at::Tensor activation, // FP8 - NDHWC layout + at::Tensor filter, // FP8 - KTRSC layout + at::Tensor scale, + std::vector padding, // [pad_d, pad_h, pad_w] + std::vector stride, // [stride_d, stride_h, stride_w] + std::vector dilation) { // [dilation_d, dilation_h, dilation_w] + + return f8f8bf16_conv_impl< + 128, + 256, + 128, + 4, + 4, + 1, + cutlass::conv::KernelImplicitTmaWarpSpecialized2SmSm100>( + activation, filter, scale, padding, stride, dilation); +} + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_conv/f8f8bf16_conv_512x1024x128_4x4x1.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_conv/f8f8bf16_conv_512x1024x128_4x4x1.cu new file mode 100644 index 0000000000..0fc353d519 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_conv/f8f8bf16_conv_512x1024x128_4x4x1.cu @@ -0,0 +1,32 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "f8f8bf16_conv_common.cuh" + +namespace fbgemm_gpu { + +at::Tensor f8f8bf16_conv_512x1024x128_4x4x1( + at::Tensor activation, // FP8 - NDHWC layout + at::Tensor filter, // FP8 - KTRSC layout + at::Tensor scale, + std::vector padding, // [pad_d, pad_h, pad_w] + std::vector stride, // [stride_d, stride_h, stride_w] + std::vector dilation) { // [dilation_d, dilation_h, dilation_w] + + return f8f8bf16_conv_impl< + 256, + 256, + 128, + 4, + 4, + 1, + cutlass::conv::KernelImplicitTmaWarpSpecialized2SmSm100>( + activation, filter, scale, padding, stride, dilation); +} + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_conv/f8f8bf16_conv_512x512x128_4x2x1.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_conv/f8f8bf16_conv_512x512x128_4x2x1.cu new file mode 100644 index 0000000000..af7648894d --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_conv/f8f8bf16_conv_512x512x128_4x2x1.cu @@ -0,0 +1,32 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "f8f8bf16_conv_common.cuh" + +namespace fbgemm_gpu { + +at::Tensor f8f8bf16_conv_512x512x128_4x2x1( + at::Tensor activation, // FP8 - NDHWC layout + at::Tensor filter, // FP8 - KTRSC layout + at::Tensor scale, + std::vector padding, // [pad_d, pad_h, pad_w] + std::vector stride, // [stride_d, stride_h, stride_w] + std::vector dilation) { // [dilation_d, dilation_h, dilation_w] + + return f8f8bf16_conv_impl< + 256, + 256, + 128, + 4, + 2, + 1, + cutlass::conv::KernelImplicitTmaWarpSpecialized2SmSm100>( + activation, filter, scale, padding, stride, dilation); +} + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_conv/f8f8bf16_conv_manifest.cuh b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_conv/f8f8bf16_conv_manifest.cuh index 831e311ccb..78194bff09 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_conv/f8f8bf16_conv_manifest.cuh +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_conv/f8f8bf16_conv_manifest.cuh @@ -26,6 +26,14 @@ at::Tensor f8f8bf16_conv_128x128x128_1x1x1( std::vector stride, // [stride_d, stride_h, stride_w] std::vector dilation); +at::Tensor f8f8bf16_conv_128x128x128_2x1x1( + at::Tensor activation, // FP8 - NDHWC layout + at::Tensor filter, // FP8 - KTRSC layout + at::Tensor scale, + std::vector padding, // [pad_d, pad_h, pad_w] + std::vector stride, // [stride_d, stride_h, stride_w] + std::vector dilation); + at::Tensor f8f8bf16_conv_128x256x128_1x2x1( at::Tensor activation, // FP8 - NDHWC layout at::Tensor filter, // FP8 - KTRSC layout @@ -90,6 +98,14 @@ at::Tensor f8f8bf16_conv_256x512x128_2x2x1( std::vector stride, // [stride_d, stride_h, stride_w] std::vector dilation); +at::Tensor f8f8bf16_conv_256x1024x128_4x4x1( + at::Tensor activation, // FP8 - NDHWC layout + at::Tensor filter, // FP8 - KTRSC layout + at::Tensor scale, + std::vector padding, // [pad_d, pad_h, pad_w] + std::vector stride, // [stride_d, stride_h, stride_w] + std::vector dilation); + at::Tensor f8f8bf16_conv_512x256x128_4x1x1( at::Tensor activation, // FP8 - NDHWC layout at::Tensor filter, // FP8 - KTRSC layout @@ -98,6 +114,22 @@ at::Tensor f8f8bf16_conv_512x256x128_4x1x1( std::vector stride, // [stride_d, stride_h, stride_w] std::vector dilation); +at::Tensor f8f8bf16_conv_512x512x128_4x2x1( + at::Tensor activation, // FP8 - NDHWC layout + at::Tensor filter, // FP8 - KTRSC layout + at::Tensor scale, + std::vector padding, // [pad_d, pad_h, pad_w] + std::vector stride, // [stride_d, stride_h, stride_w] + std::vector dilation); + +at::Tensor f8f8bf16_conv_512x1024x128_4x4x1( + at::Tensor activation, // FP8 - NDHWC layout + at::Tensor filter, // FP8 - KTRSC layout + at::Tensor scale, + std::vector padding, // [pad_d, pad_h, pad_w] + std::vector stride, // [stride_d, stride_h, stride_w] + std::vector dilation); + using Kernel_f8f8bf16_conv = at::Tensor (*)( at::Tensor activation, // FP8 - NDHWC layout at::Tensor filter, // FP8 - KTRSC layout