#include <torch/torch.h>
#include <cuda_fp16.h>  // half2 type and __h* intrinsics used below
#include <algorithm>    // std::min for the launch configuration
#include <cmath>        // sqrtf in the reference sketch below

#include "dispatch.h"
#include "reduce_kernel_utils.cuh"
#include "layernorm_kernels.h"

namespace llm::kernel {

// ... (generic layer_norm_kernel<T> implementation elided in this excerpt) ...

// equation: x -> (x - E[x]) / sqrt(Var[x] + eps) * w + b
// The mean and standard deviation are calculated over the last dimension.
// half2 specialization: each element packs two half values, so n counts
// half2 elements and each row holds 2 * n scalars.
template <>
__global__ void layer_norm_kernel<half2>(half2* __restrict__ out,
                                         const half2* __restrict__ input,
                                         const half2* __restrict__ weight,
                                         const half2* __restrict__ bias,
                                         const float epsilon,
                                         int n) {
  const int tidx = threadIdx.x;
  const int bidx = blockIdx.x;

  __shared__ half s_mean;
  __shared__ half s_variance;
  half2 mean = make_half2(__float2half(0.0f), __float2half(0.0f));
  half2 variance = make_half2(__float2half(0.0f), __float2half(0.0f));
  // calculate the mean of the input row.
  for (int i = tidx; i < n; i += blockDim.x) {
    const int idx = bidx * n + i;
    mean = __hadd2(mean, __ldg(&input[idx]));
  }
  mean = block_reduce_sum<half2>(mean);
  if (tidx == 0) {
    // each half2 holds two scalars, hence the divisor of 2 * n.
    s_mean = __hdiv(__hadd(mean.x, mean.y), __float2half((float)n * 2));
  }
  __syncthreads();

  // calculate the variance of the input row.
  for (int i = tidx; i < n; i += blockDim.x) {
    const half2 x = __hsub2(input[bidx * n + i], make_half2(s_mean, s_mean));
    variance = __hadd2(variance, __hmul2(x, x));
  }
  variance = block_reduce_sum<half2>(variance);
  if (tidx == 0) {
    // fold 1 / sqrt(var + eps) into s_variance so the normalization below
    // is a single multiply.
    s_variance = __hadd(variance.x, variance.y);
    s_variance = __hdiv(s_variance, __float2half((float)n * 2));
    s_variance = __hadd(s_variance, __float2half(epsilon));
    s_variance = hrsqrt(s_variance);
  }
  __syncthreads();

  // normalize, then apply the affine transform: out = x_hat * w + b.
  for (int i = tidx; i < n; i += blockDim.x) {
    const int idx = bidx * n + i;
    half2 local_out = __ldg(&input[idx]);
    local_out = __hsub2(local_out, make_half2(s_mean, s_mean));
    local_out = __hmul2(local_out, make_half2(s_variance, s_variance));
    local_out = __hmul2(local_out, __ldg(&weight[i]));
    if (bias != nullptr) {
      local_out = __hadd2(local_out, __ldg(&bias[i]));
    }
    out[idx] = local_out;
  }
}

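// For reference, a scalar sketch of what the kernel above computes for one
// row (hypothetical helper in float math, not part of the original file):
inline void layer_norm_row_ref(float* out,
                               const float* x,
                               const float* w,
                               const float* b,
                               float eps,
                               int h) {
  float mean = 0.0f;
  for (int i = 0; i < h; ++i) {
    mean += x[i];
  }
  mean /= h;
  float var = 0.0f;
  for (int i = 0; i < h; ++i) {
    var += (x[i] - mean) * (x[i] - mean);
  }
  // biased variance and fused reciprocal sqrt, matching the kernel.
  const float inv_std = 1.0f / sqrtf(var / h + eps);
  for (int i = 0; i < h; ++i) {
    out[i] = (x[i] - mean) * inv_std * w[i] + (b != nullptr ? b[i] : 0.0f);
  }
}
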
void layer_norm(torch::Tensor& out,
                torch::Tensor input,
                torch::Tensor weight,
                // ... (remaining parameters and dispatch body elided in
                // this excerpt) ...
  });
}

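// A minimal usage sketch (hypothetical, not part of the original file),
// assuming the parameters elided above are `torch::Tensor bias` and
// `float epsilon`, matching the kernel signatures.
inline void layer_norm_example() {
  const int64_t m = 4;
  const int64_t n = 1024;
  auto input = torch::randn(
      {m, n}, torch::dtype(torch::kFloat32).device(torch::kCUDA));
  auto weight = torch::ones({n}, input.options());
  auto bias = torch::zeros({n}, input.options());
  auto out = torch::empty_like(input);
  layer_norm(out, input, weight, bias, /*epsilon=*/1e-5f);
}
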
template <typename T>
void invoke_layernorm_kernel(T* out,
                             const T* input,
                             const T* weight,
                             const T* bias,
                             const float epsilon,
                             int m,
                             int n) {
  // one block per row; cap the block size at the 1024-thread hardware limit
  // (the half2 kernel's strided loop handles n > blockDim.x, and the generic
  // kernel is assumed to follow the same pattern).
  const int threads = std::min(n, 1024);
  layer_norm_kernel<T><<<m, threads>>>(out, input, weight, bias, epsilon, n);
}

// explicit instantiations so the definitions above are emitted in this
// translation unit for the types used elsewhere.
template void invoke_layernorm_kernel<half2>(half2* out,
                                             const half2* input,
                                             const half2* weight,
                                             const half2* bias,
                                             const float epsilon,
                                             int m,
                                             int n);
template void invoke_layernorm_kernel<float>(float* out,
                                             const float* input,
                                             const float* weight,
                                             const float* bias,
                                             const float epsilon,
                                             int m,
                                             int n);
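
// A sketch (hypothetical, not part of the original file) of driving the
// half2 path from a contiguous, 4-byte-aligned [m, h] row-major buffer of
// 16-bit floats: each row of h scalars is viewed as h / 2 packed half2
// values, so h must be even and weight/bias must also hold h scalars each.
inline void invoke_layernorm_half2_example(__half* out,
                                           const __half* input,
                                           const __half* weight,
                                           const __half* bias,
                                           float epsilon,
                                           int m,
                                           int h) {
  const int n = h / 2;  // number of half2 elements per row
  invoke_layernorm_kernel<half2>(reinterpret_cast<half2*>(out),
                                 reinterpret_cast<const half2*>(input),
                                 reinterpret_cast<const half2*>(weight),
                                 reinterpret_cast<const half2*>(bias),
                                 epsilon,
                                 m,
                                 n);
}
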
}  // namespace llm::kernel