Skip to content

Commit d4c7a5d

Browse files
author
Xianzhe Dong
committed
[ut] add layernorm kernel unit test
1 parent 7de5c75 commit d4c7a5d

File tree

5 files changed

+237
-69
lines changed

5 files changed

+237
-69
lines changed

src/kernels/CMakeLists.txt

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
include(cc_library)
2+
include(cc_binary)
23

34
cc_library(
45
NAME
@@ -72,6 +73,24 @@ cc_library(
7273
torch
7374
)
7475

76+
# cc_test(
77+
# NAME
78+
# layernorm_kernels_test
79+
# SRCS
80+
# layernrom_kernels_test.cu
81+
# layernorm_kernels.cu
82+
# DEPS
83+
# DEFINES
84+
# )
85+
# Declares the `layernorm_kernels_test` target via cc_binary from the test
# driver plus the layernorm kernel source, with `torch` as its only dependency.
# The equivalent cc_test declaration is kept commented out just before this.
# NOTE(review): "layernrom_kernels_test.cu" looks like a misspelling of
# "layernorm" - confirm it matches the file name on disk before renaming.
cc_binary(
86+
NAME
87+
layernorm_kernels_test
88+
SRCS
89+
layernrom_kernels_test.cu
90+
layernorm_kernels.cu
91+
DEPS
92+
torch
93+
)
94+
7595
add_subdirectory(flash_attn)
7696
add_subdirectory(flash_infer)
77-

src/kernels/layernorm_kernels.cu

Lines changed: 29 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
#include <ATen/cuda/CUDAContext.h>
22
#include <torch/torch.h>
3+
34
#include "dispatch.h"
4-
#include "reduce_kernel_utils.cuh"
55
#include "layernorm_kernels.h"
6+
#include "reduce_kernel_utils.cuh"
67

78
namespace llm::kernel {
89

@@ -116,11 +117,11 @@ __global__ void layer_norm_kernel(T* __restrict__ out,
116117
// The mean and standard-deviation are calculated over the last dimension
117118
template <>
118119
__global__ void layer_norm_kernel<half2>(half2* __restrict__ out,
119-
const half2* __restrict__ input,
120-
const half2* __restrict__ weight,
121-
const half2* __restrict__ bias,
122-
const float epsilon,
123-
int n) {
120+
const half2* __restrict__ input,
121+
const half2* __restrict__ weight,
122+
const half2* __restrict__ bias,
123+
const float epsilon,
124+
int n) {
124125
const int tidx = threadIdx.x;
125126
const int bidx = blockIdx.x;
126127

@@ -147,7 +148,6 @@ __global__ void layer_norm_kernel<half2>(half2* __restrict__ out,
147148
}
148149
variance = block_reduce_sum<half2>(variance);
149150
if (tidx == 0) {
150-
// const half2 e = make_half2(__float2half(epsilon), __float2half(epsilon));
151151
s_variance = __hadd(variance.x, variance.y);
152152
s_variance = __hdiv(s_variance, __float2half((float)n * 2));
153153
s_variance = __hadd(s_variance, __float2half(epsilon));
@@ -157,16 +157,11 @@ __global__ void layer_norm_kernel<half2>(half2* __restrict__ out,
157157

158158
for (int i = tidx; i < n; i += blockDim.x) {
159159
const int idx = bidx * n + i;
160-
// float local_out =
161-
// (__ldg(&input[idx]) - s_mean) * s_variance * __ldg(&weight[i]);
162-
// if (bias != nullptr) {
163-
// local_out += __ldg(&bias[i]);
164-
// }
165160
half2 local_out = __ldg(&input[idx]);
166161
local_out = __hsub2(local_out, make_half2(s_mean, s_mean));
167162
local_out = __hmul2(local_out, make_half2(s_variance, s_variance));
168163
local_out = __hmul2(local_out, __ldg(&weight[i]));
169-
if (bias != nullptr){
164+
if (bias != nullptr) {
170165
local_out = __hadd2(local_out, __ldg(&bias[i]));
171166
}
172167
out[idx] = local_out;
@@ -199,52 +194,34 @@ void layer_norm(torch::Tensor& out,
199194

200195
// Host-side launcher for layer_norm_kernel<T>.
//
// Launches a grid of `m` blocks with `n` threads per block and passes `n`
// through to the kernel as the length of the normalized dimension.
//
// - out:     device pointer receiving the result
// - input:   device pointer to the values to normalize
// - weight:  device pointer to the per-element scale
// - bias:    device pointer to the per-element shift
// - epsilon: forwarded unchanged to the kernel
// - m:       number of blocks launched
// - n:       threads per block and the `n` argument of the kernel
//
// No stream argument is passed to the launch.
// Raises a c10 error if the launch is rejected. CUDA caps a block at 1024
// threads, so n > 1024 is now reported instead of silently leaving `out`
// untouched.
template <typename T>
void invoke_layernorm_kernel(T* out,
                             const T* input,
                             const T* weight,
                             const T* bias,
                             const float epsilon,
                             int m,
                             int n) {
  // An empty problem has nothing to compute, and a zero-sized grid or block
  // would be an invalid launch configuration.
  if (m <= 0 || n <= 0) {
    return;
  }
  layer_norm_kernel<T><<<m, n>>>(out, input, weight, bias, epsilon, n);
  // Kernel launches do not return a status; pick up launch errors explicitly.
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
210205

211206
// half2 specialization of the layernorm launcher.
//
// Launches layer_norm_kernel<half2> with `m` blocks of `n` threads and passes
// `n` to the kernel. All pointers are device pointers to half2 elements.
//
// No stream argument is passed to the launch.
// Raises a c10 error if the launch is rejected. CUDA caps a block at 1024
// threads, so n > 1024 is now reported instead of failing silently.
template <>
void invoke_layernorm_kernel<half2>(half2* out,
                                    const half2* input,
                                    const half2* weight,
                                    const half2* bias,
                                    const float epsilon,
                                    int m,
                                    int n) {
  // An empty problem has nothing to compute, and a zero-sized grid or block
  // would be an invalid launch configuration.
  if (m <= 0 || n <= 0) {
    return;
  }
  layer_norm_kernel<half2><<<m, n>>>(out, input, weight, bias, epsilon, n);
  // Kernel launches do not return a status; pick up launch errors explicitly.
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
221216
// float specialization of the layernorm launcher.
//
// Launches layer_norm_kernel<float> with `m` blocks of `n` threads and passes
// `n` to the kernel. All pointers are device pointers to float elements.
//
// No stream argument is passed to the launch.
// Raises a c10 error if the launch is rejected. CUDA caps a block at 1024
// threads, so n > 1024 is now reported instead of failing silently.
template <>
void invoke_layernorm_kernel<float>(float* out,
                                    const float* input,
                                    const float* weight,
                                    const float* bias,
                                    const float epsilon,
                                    int m,
                                    int n) {
  // An empty problem has nothing to compute, and a zero-sized grid or block
  // would be an invalid launch configuration.
  if (m <= 0 || n <= 0) {
    return;
  }
  layer_norm_kernel<float><<<m, n>>>(out, input, weight, bias, epsilon, n);
  // Kernel launches do not return a status; pick up launch errors explicitly.
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
226+
227+
} // namespace llm::kernel

src/kernels/layernorm_kernels.h

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -22,20 +22,4 @@ void invoke_layernorm_kernel(T* out,
2222
const float epsilon,
2323
int m,
2424
int n);
25-
26-
// void invoke_float_layernorm_kernel(float* out,
27-
// const float* input,
28-
// const float* weight,
29-
// const float* bias,
30-
// const float epsilon,
31-
// int m,
32-
// int n);
33-
34-
// void invoke_half2_layernorm_kernel(half2* out,
35-
// const half2* input,
36-
// const half2* weight,
37-
// const half2* bias,
38-
// const float epsilon,
39-
// int m,
40-
// int n);
4125
} // namespace llm::kernel
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
#include <cuda_fp16.h>

#include <cstdio>
#include <cstdlib>
#include <vector>

#include "layernorm_kernels.h"
6+
7+
// Prints the m x n row-major matrix `a` to stdout: one text line per row, each
// element converted to float and formatted with "%f ", followed by one blank
// line after the last row.
template <typename T>
void printMatrix(T* a, int m, int n) {
  for (int row = 0; row < m; ++row) {
    // First element of the current row in the flat buffer.
    T* row_begin = a + row * n;
    for (int col = 0; col < n; ++col) {
      const float value = (float)row_begin[col];
      printf("%f ", value);
    }
    puts("");
  }
  puts("");
}
17+
18+
// half2 specialization: every element holds two half values, so each matrix
// entry is printed as its .x and .y components (converted to float), giving
// 2 * n numbers per text line. A blank line follows the last row.
template <>
void printMatrix<half2>(half2* a, int m, int n) {
  for (int row = 0; row < m; ++row) {
    for (int col = 0; col < n; ++col) {
      const half2& packed = a[row * n + col];
      const float lo = __half2float(packed.x);
      const float hi = __half2float(packed.y);
      printf("%f %f ", lo, hi);
    }
    puts("");
  }
  puts("");
}
29+
30+
// Smoke test for the half2 layernorm launcher.
//
// Builds a 2 x 2 matrix of half2 values (weight = 1, bias = 0), runs
// llm::kernel::invoke_layernorm_kernel<half2> on the device, copies the result
// back and prints input, weight, bias and output for visual inspection.
// Any failing CUDA runtime call aborts the program with a message, and all
// host/device buffers are released before returning.
void layernorm_kernel_half2_test() {
  // Stop at the first failing CUDA call instead of printing garbage later.
  const auto check = [](cudaError_t status, const char* what) {
    if (status != cudaSuccess) {
      fprintf(stderr, "%s failed: %s\n", what, cudaGetErrorString(status));
      exit(EXIT_FAILURE);
    }
  };

  const float epsilon = 1e-6f;
  const int m = 2;
  const int n = 2;
  const size_t count = static_cast<size_t>(m) * n;
  const size_t bytes = count * sizeof(half2);

  // Host buffers; std::vector releases them automatically.
  std::vector<half2> out(count);
  std::vector<half2> input(count);
  std::vector<half2> weight(count);
  std::vector<half2> bias(count);

  for (int i = 0; i < m; i++) {
    for (int j = 0; j < n; j++) {
      const int idx = i * n + j;
      input[idx] = half2(__float2half((float)(i * n + j * 2)),
                         __float2half((float)(i * n + j * 2 + 1)));
      weight[idx] = half2(__float2half(1.0f), __float2half(1.0f));
      bias[idx] = half2(__float2half(0.0f), __float2half(0.0f));
    }
  }

  half2* dout = nullptr;
  half2* dinput = nullptr;
  half2* dweight = nullptr;
  half2* dbias = nullptr;
  check(cudaMalloc((void**)&dout, bytes), "cudaMalloc(out)");
  check(cudaMalloc((void**)&dinput, bytes), "cudaMalloc(input)");
  check(cudaMalloc((void**)&dweight, bytes), "cudaMalloc(weight)");
  check(cudaMalloc((void**)&dbias, bytes), "cudaMalloc(bias)");

  check(cudaMemcpy(dinput, input.data(), bytes, cudaMemcpyHostToDevice),
        "cudaMemcpy(input)");
  check(cudaMemcpy(dweight, weight.data(), bytes, cudaMemcpyHostToDevice),
        "cudaMemcpy(weight)");
  check(cudaMemcpy(dbias, bias.data(), bytes, cudaMemcpyHostToDevice),
        "cudaMemcpy(bias)");

  llm::kernel::invoke_layernorm_kernel<half2>(
      dout, dinput, dweight, dbias, epsilon, m, n);
  // A launch reports configuration errors only through cudaGetLastError().
  check(cudaGetLastError(), "half2 layernorm kernel launch");

  // The blocking copy also surfaces errors raised while the kernel executed.
  check(cudaMemcpy(out.data(), dout, bytes, cudaMemcpyDeviceToHost),
        "cudaMemcpy(out)");

  check(cudaFree(dout), "cudaFree(out)");
  check(cudaFree(dinput), "cudaFree(input)");
  check(cudaFree(dweight), "cudaFree(weight)");
  check(cudaFree(dbias), "cudaFree(bias)");

  printf("---------- test half2 layernorm kernel -----------\n");
  printf("input:\n");
  printMatrix<half2>(input.data(), m, n);
  printf("weights:\n");
  printMatrix<half2>(weight.data(), m, n);
  printf("bias:\n");
  printMatrix<half2>(bias.data(), m, n);
  printf("outputs:\n");
  printMatrix<half2>(out.data(), m, n);
}
77+
78+
// Smoke test for the float layernorm launcher.
//
// Builds a 2 x 4 float matrix holding 0..7 (weight = 1, bias = 0), runs
// llm::kernel::invoke_layernorm_kernel<float> on the device, copies the result
// back and prints input, weight, bias and output for visual inspection.
// Any failing CUDA runtime call aborts the program with a message, and all
// host/device buffers are released before returning.
void layernorm_kernel_float_test() {
  // Stop at the first failing CUDA call instead of printing garbage later.
  const auto check = [](cudaError_t status, const char* what) {
    if (status != cudaSuccess) {
      fprintf(stderr, "%s failed: %s\n", what, cudaGetErrorString(status));
      exit(EXIT_FAILURE);
    }
  };

  const float epsilon = 1e-6f;
  const int m = 2;
  const int n = 4;
  const size_t count = static_cast<size_t>(m) * n;
  const size_t bytes = count * sizeof(float);

  // Host buffers; std::vector releases them automatically.
  std::vector<float> out(count);
  std::vector<float> input(count);
  std::vector<float> weight(count);
  std::vector<float> bias(count);

  for (int i = 0; i < m; i++) {
    for (int j = 0; j < n; j++) {
      const int idx = i * n + j;
      input[idx] = (float)idx;
      weight[idx] = 1.0f;
      bias[idx] = 0.0f;
    }
  }

  float* dout = nullptr;
  float* dinput = nullptr;
  float* dweight = nullptr;
  float* dbias = nullptr;
  check(cudaMalloc((void**)&dout, bytes), "cudaMalloc(out)");
  check(cudaMalloc((void**)&dinput, bytes), "cudaMalloc(input)");
  check(cudaMalloc((void**)&dweight, bytes), "cudaMalloc(weight)");
  check(cudaMalloc((void**)&dbias, bytes), "cudaMalloc(bias)");

  check(cudaMemcpy(dinput, input.data(), bytes, cudaMemcpyHostToDevice),
        "cudaMemcpy(input)");
  check(cudaMemcpy(dweight, weight.data(), bytes, cudaMemcpyHostToDevice),
        "cudaMemcpy(weight)");
  check(cudaMemcpy(dbias, bias.data(), bytes, cudaMemcpyHostToDevice),
        "cudaMemcpy(bias)");

  llm::kernel::invoke_layernorm_kernel<float>(
      dout, dinput, dweight, dbias, epsilon, m, n);
  // A launch reports configuration errors only through cudaGetLastError().
  check(cudaGetLastError(), "float layernorm kernel launch");

  // The blocking copy also surfaces errors raised while the kernel executed.
  check(cudaMemcpy(out.data(), dout, bytes, cudaMemcpyDeviceToHost),
        "cudaMemcpy(out)");

  check(cudaFree(dout), "cudaFree(out)");
  check(cudaFree(dinput), "cudaFree(input)");
  check(cudaFree(dweight), "cudaFree(weight)");
  check(cudaFree(dbias), "cudaFree(bias)");

  printf("---------- test float layernorm kernel -----------\n");
  printf("input:\n");
  printMatrix<float>(input.data(), m, n);
  printf("weights:\n");
  printMatrix<float>(weight.data(), m, n);
  printf("bias:\n");
  printMatrix<float>(bias.data(), m, n);
  printf("outputs:\n");
  printMatrix<float>(out.data(), m, n);
}
124+
125+
// Entry point of the test binary: runs the float case first, then the half2
// case. Each case prints its own report; the exit status is always 0.
int main() {
  // float path
  layernorm_kernel_float_test();

  // half2 path
  layernorm_kernel_half2_test();

  return 0;
}

src/kernels/reduce_kernel_utils.cuh

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,36 @@ __inline__ __device__ T warp_reduce_sum(T val) {
2424
return val;
2525
}
2626

27+
// Sums a half value across the 32 lanes of a warp.
// - val: this lane's contribution.
// Each step exchanges values with the lane whose id differs by `offset`
// (XOR shuffle) and adds the partner's value; `offset` goes 16, 8, 4, 2, 1.
// The shuffles use FINAL_MASK as the participant mask and a width of 32.
template <>
__inline__ __device__ half warp_reduce_sum<half>(half val) {
#pragma unroll
  for (int offset = 16; offset > 0; offset >>= 1) {
    const half partner = __shfl_xor_sync(FINAL_MASK, val, offset, 32);
    val = __hadd(val, partner);
  }
  return val;
}
41+
42+
// Sums a half2 value across the 32 lanes of a warp; __hadd2 adds both packed
// halves in one operation.
// - val: this lane's contribution.
// Each step exchanges values with the lane whose id differs by `offset`
// (XOR shuffle) and adds the partner's value; `offset` goes 16, 8, 4, 2, 1.
// The shuffles use FINAL_MASK as the participant mask and a width of 32.
template <>
__inline__ __device__ half2 warp_reduce_sum<half2>(half2 val) {
#pragma unroll
  for (int offset = 16; offset > 0; offset >>= 1) {
    const half2 partner = __shfl_xor_sync(FINAL_MASK, val, offset, 32);
    val = __hadd2(val, partner);
  }
  return val;
}
56+
2757
// performs a parallel reduction operation across the threads within a single
2858
// warp (32 threads).
2959
// - val: The value to be reduced within a warp.
@@ -63,6 +93,35 @@ __inline__ __device__ T block_reduce_sum(T val) {
6393
return val;
6494
}
6595

96+
/* Sum a half2 value over all threads of a block (half2 specialization).
 * Stage 1 reduces inside every warp, stage 2 reduces the per-warp partial sums
 * that lane 0 of each warp stored in shared memory. */
template <>
__inline__ __device__ half2 block_reduce_sum<half2>(half2 val) {
  // One slot per warp; 32 slots cover up to 32 warps per block.
  static __shared__ half2 warp_sums[32];

  const int lane_id = threadIdx.x & 0x1f;  // position inside the warp
  const int warp_id = threadIdx.x >> 5;    // threadIdx.x / 32

  // Stage 1: reduce within each warp.
  val = warp_reduce_sum<half2>(val);

  // Lane 0 of every warp publishes that warp's partial sum.
  if (lane_id == 0) {
    warp_sums[warp_id] = val;
  }
  // All partial sums must be written before any thread reads them.
  __syncthreads();

  // Threads whose index is below blockDim.x / 32 (computed in floating point,
  // so a trailing partial warp still counts) pick up one partial sum each;
  // all other threads contribute zero.
  const bool has_partial = threadIdx.x < (blockDim.x / 32.f);
  if (has_partial) {
    val = warp_sums[lane_id];
  } else {
    val = make_half2(__float2half(0.0f), __float2half(0.0f));
  }

  // Stage 2: reduce the partial sums.
  return warp_reduce_sum<half2>(val);
}
124+
66125
/* Calculate the max of all elements in a thread block */
67126
template <typename T>
68127
__inline__ __device__ T block_reduce_max(T val) {

0 commit comments

Comments
 (0)