Commit 7de5c75

Author: Xianzhe Dong
[op] optimize layernorm kernel for half2 type
1 parent bec08f7

File tree: 2 files changed (+137, -1 lines)
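Editorial note (not part of the commit): the optimization relies on CUDA's half2 type, which packs two fp16 values into one 32-bit register so that intrinsics such as __hadd2 and __hmul2 process two elements per instruction and each __ldg pulls two values per load. Below is a minimal standalone sketch of that packed arithmetic, using only cuda_fp16.h intrinsics that also appear in the diff; the kernel and variable names here are made up for illustration.

// Sketch only: demonstrates half2 packed math; not part of this commit.
#include <cuda_fp16.h>
#include <cstdio>

__global__ void scale_shift(half2* data, half2 scale, half2 shift, int n) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    // data[i] * scale + shift, applied to both fp16 lanes with a single
    // __hmul2 and a single __hadd2 instruction.
    data[i] = __hadd2(__hmul2(data[i], scale), shift);
  }
}

int main() {
  const int n = 4;  // four half2 elements == eight fp16 values
  half2 h[n];
  for (int i = 0; i < n; ++i) {
    h[i] = make_half2(__float2half(float(2 * i)), __float2half(float(2 * i + 1)));
  }
  half2* d = nullptr;
  cudaMalloc(&d, n * sizeof(half2));
  cudaMemcpy(d, h, n * sizeof(half2), cudaMemcpyHostToDevice);
  scale_shift<<<1, n>>>(d, __float2half2_rn(0.5f), __float2half2_rn(1.0f), n);
  cudaMemcpy(h, d, n * sizeof(half2), cudaMemcpyDeviceToHost);
  for (int i = 0; i < n; ++i) {
    printf("%g %g\n", __half2float(h[i].x), __half2float(h[i].y));
  }
  cudaFree(d);
  return 0;
}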

src/kernels/layernorm_kernels.cu

Lines changed: 112 additions & 0 deletions
@@ -2,6 +2,7 @@
 #include <torch/torch.h>
 #include "dispatch.h"
 #include "reduce_kernel_utils.cuh"
+#include "layernorm_kernels.h"

 namespace llm::kernel {

@@ -111,6 +112,67 @@ __global__ void layer_norm_kernel(T* __restrict__ out,
   }
 }

+// equation: x -> (x - E[x]) / sqrt(Var[x] + eps) * w + b
+// The mean and standard deviation are calculated over the last dimension.
+// Each half2 element packs two fp16 values, so a row of n half2 elements
+// covers 2 * n scalar values.
+template <>
+__global__ void layer_norm_kernel<half2>(half2* __restrict__ out,
+                                         const half2* __restrict__ input,
+                                         const half2* __restrict__ weight,
+                                         const half2* __restrict__ bias,
+                                         const float epsilon,
+                                         int n) {
+  const int tidx = threadIdx.x;
+  const int bidx = blockIdx.x;
+
+  __shared__ half s_mean;
+  __shared__ half s_variance;
+  half2 mean = make_half2(__float2half(0.0f), __float2half(0.0f));
+  half2 variance = make_half2(__float2half(0.0f), __float2half(0.0f));
+
+  // calculate the mean of the input.
+  for (int i = tidx; i < n; i += blockDim.x) {
+    const int idx = bidx * n + i;
+    mean = __hadd2(mean, __ldg(&input[idx]));
+  }
+  mean = block_reduce_sum<half2>(mean);
+  if (tidx == 0) {
+    s_mean = __hdiv(__hadd(mean.x, mean.y), __float2half((float)n * 2));
+  }
+  __syncthreads();
+
+  // calculate the variance of the input.
+  for (int i = tidx; i < n; i += blockDim.x) {
+    const half2 x = __hsub2(input[bidx * n + i], make_half2(s_mean, s_mean));
+    variance = __hadd2(variance, __hmul2(x, x));
+  }
+  variance = block_reduce_sum<half2>(variance);
+  if (tidx == 0) {
+    s_variance = __hadd(variance.x, variance.y);
+    s_variance = __hdiv(s_variance, __float2half((float)n * 2));
+    s_variance = __hadd(s_variance, __float2half(epsilon));
+    s_variance = hrsqrt(s_variance);
+  }
+  __syncthreads();
+
+  // normalize, then apply the elementwise affine transform.
+  for (int i = tidx; i < n; i += blockDim.x) {
+    const int idx = bidx * n + i;
+    half2 local_out = __ldg(&input[idx]);
+    local_out = __hsub2(local_out, make_half2(s_mean, s_mean));
+    local_out = __hmul2(local_out, make_half2(s_variance, s_variance));
+    local_out = __hmul2(local_out, __ldg(&weight[i]));
+    if (bias != nullptr) {
+      local_out = __hadd2(local_out, __ldg(&bias[i]));
+    }
+    out[idx] = local_out;
+  }
+}
+
 void layer_norm(torch::Tensor& out,
                 torch::Tensor input,
                 torch::Tensor weight,
@@ -135,4 +197,54 @@ void layer_norm(torch::Tensor& out,
   });
 }

+template <typename T>
+void invoke_layernorm_kernel(T* out,
+                             const T* input,
+                             const T* weight,
+                             const T* bias,
+                             const float epsilon,
+                             int m,
+                             int n) {
+  // One block per row, one thread per element of the last dimension.
+  layer_norm_kernel<T><<<m, n>>>(out, input, weight, bias, epsilon, n);
+}
+
+template <>
+void invoke_layernorm_kernel<half2>(half2* out,
+                                    const half2* input,
+                                    const half2* weight,
+                                    const half2* bias,
+                                    const float epsilon,
+                                    int m,
+                                    int n) {
+  layer_norm_kernel<half2><<<m, n>>>(out, input, weight, bias, epsilon, n);
+}
+
+template <>
+void invoke_layernorm_kernel<float>(float* out,
+                                    const float* input,
+                                    const float* weight,
+                                    const float* bias,
+                                    const float epsilon,
+                                    int m,
+                                    int n) {
+  layer_norm_kernel<float><<<m, n>>>(out, input, weight, bias, epsilon, n);
+}
+
 } // namespace llm::kernel
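Usage sketch (editorial, not in the commit): the half2 path expects the caller to reinterpret a contiguous fp16 row of hidden_size scalars as hidden_size / 2 half2 elements and to pass that halved width as n. The wrapper name layer_norm_half2 and the hidden_size parameter below are assumptions for illustration; only invoke_layernorm_kernel itself comes from this commit. Since the launcher uses <<<m, n>>> (one block per row, one thread per packed element), this sketch also assumes n stays within the 1024 threads-per-block limit.

// Hypothetical caller of the launcher added above. Assumes device pointers to
// contiguous fp16 data of shape [m, hidden_size] with hidden_size even.
#include <cuda_fp16.h>

#include "layernorm_kernels.h"

void layer_norm_half2(half* out, const half* input, const half* weight,
                      const half* bias, float epsilon, int m, int hidden_size) {
  // Two fp16 values per half2 element, so the packed row width is halved.
  const int n = hidden_size / 2;
  llm::kernel::invoke_layernorm_kernel<half2>(
      reinterpret_cast<half2*>(out),
      reinterpret_cast<const half2*>(input),
      reinterpret_cast<const half2*>(weight),
      reinterpret_cast<const half2*>(bias),
      epsilon,
      m,
      n);
}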

src/kernels/layernorm_kernels.h

Lines changed: 25 additions & 1 deletion
@@ -7,11 +7,35 @@ void rms_norm(torch::Tensor& out,
               torch::Tensor input,
               torch::Tensor weight,
               float epsilon);
-
+
 void layer_norm(torch::Tensor& out,
                 torch::Tensor input,
                 torch::Tensor weight,
                 torch::Tensor bias,
                 float epsilon);

+template <typename T>
+void invoke_layernorm_kernel(T* out,
+                             const T* input,
+                             const T* weight,
+                             const T* bias,
+                             const float epsilon,
+                             int m,
+                             int n);
+
 } // namespace llm::kernel
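For checking the new kernels, the equation in the kernel comment, x -> (x - E[x]) / sqrt(Var[x] + eps) * w + b with the mean and biased variance taken over the last dimension, can be compared against a plain float reference on the host. This is an editorial sketch, not part of the commit; the function name layer_norm_ref is made up.

// Editorial reference implementation in plain float, row-major [m, n] layout.
// Matches the kernel's math: biased variance, epsilon added before the rsqrt.
#include <cmath>
#include <cstddef>
#include <vector>

void layer_norm_ref(std::vector<float>& out,
                    const std::vector<float>& input,
                    const std::vector<float>& weight,
                    const std::vector<float>& bias,  // may be empty (no bias)
                    float epsilon,
                    std::size_t m,
                    std::size_t n) {
  out.resize(m * n);
  for (std::size_t row = 0; row < m; ++row) {
    const float* x = input.data() + row * n;
    // E[x] over the last dimension.
    float mean = 0.0f;
    for (std::size_t i = 0; i < n; ++i) mean += x[i];
    mean /= static_cast<float>(n);
    // Biased Var[x] over the last dimension.
    float var = 0.0f;
    for (std::size_t i = 0; i < n; ++i) var += (x[i] - mean) * (x[i] - mean);
    var /= static_cast<float>(n);
    const float rstd = 1.0f / std::sqrt(var + epsilon);
    // Normalize, then apply the elementwise affine transform.
    for (std::size_t i = 0; i < n; ++i) {
      const float b = bias.empty() ? 0.0f : bias[i];
      out[row * n + i] = (x[i] - mean) * rstd * weight[i] + b;
    }
  }
}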
