|
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <torch/all.h>

#include <algorithm>
#include <cmath>
#include <limits>

#include "dispatch_utils.h"

#ifndef USE_ROCM
  #include <cub/util_type.cuh>
  #include <cub/cub.cuh>
#else
  #include <hipcub/util_type.hpp>
  #include <hipcub/hipcub.hpp>
#endif
|
|
|
// Converts a float to int8: round to the nearest integer, then saturate to
// the int8 range.
//
// CUDA: a single PTX `cvt.rni.sat.s8.f32` instruction (`.rni` rounds to the
//       nearest integer with ties to even, `.sat` saturates to the
//       destination range).
// ROCm: std::nearbyint followed by a clamp to [int8 min, int8 max].
//       NOTE(review): std::nearbyint follows the current rounding mode, so
//       this presumably relies on the default round-to-nearest-even mode --
//       confirm. NaN input is not handled specially on this path.
static inline __device__ int8_t float_to_int8_rn(float x) {
#ifdef USE_ROCM
  static constexpr float kInt8Lo =
      static_cast<float>(std::numeric_limits<int8_t>::min());
  static constexpr float kInt8Hi =
      static_cast<float>(std::numeric_limits<int8_t>::max());

  // Round first, then saturate, so out-of-range inputs pin to the bounds
  // before the narrowing cast.
  float const saturated = std::clamp(std::nearbyint(x), kInt8Lo, kInt8Hi);
  return static_cast<int8_t>(saturated);
#else
  // The instruction writes its result into a 32-bit register ("=r").
  uint32_t packed;
  asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(packed) : "f"(x));
  // Read the register's storage back as an int8.
  return reinterpret_cast<const int8_t&>(packed);
#endif
}
|
|
|
// Converts a float to int32: round to the nearest integer, then saturate to
// the int32 range.
//
// CUDA: a single PTX `cvt.rni.sat.s32.f32` instruction (`.rni` rounds to the
//       nearest integer with ties to even, `.sat` saturates).
// ROCm: std::nearbyint, then explicit saturation. The bounds are compared in
//       float *before* the cast so that an out-of-range value never reaches
//       static_cast<int32_t>.
//       NOTE(review): std::nearbyint follows the current rounding mode, so
//       this presumably relies on the default round-to-nearest-even mode --
//       confirm. NaN input is not handled specially on this path.
static inline __device__ int32_t float_to_int32_rn(float x) {
#ifdef USE_ROCM
  static constexpr int32_t kInt32Lo = std::numeric_limits<int32_t>::min();
  static constexpr int32_t kInt32Hi = std::numeric_limits<int32_t>::max();
  static constexpr float kInt32LoF = static_cast<float>(kInt32Lo);
  static constexpr float kInt32HiF = static_cast<float>(kInt32Hi);

  float const rounded = std::nearbyint(x);

  // Saturate at the top of the range ...
  if (rounded >= kInt32HiF) return kInt32Hi;
  // ... and at the bottom.
  if (rounded <= kInt32LoF) return kInt32Lo;

  // In range: the cast is now well defined.
  return static_cast<int32_t>(rounded);
#else
  // The instruction writes its result into a 32-bit register ("=r").
  uint32_t raw;
  asm volatile("cvt.rni.sat.s32.f32 %0, %1;" : "=r"(raw) : "f"(x));
  // Same 32 bits, viewed as a signed integer.
  return reinterpret_cast<const int32_t&>(raw);
#endif
}
|
|
|
// Narrows an int32 to int8 with saturation: values outside the int8 range
// are pinned to the nearest int8 bound.
//
// CUDA: a single PTX `cvt.sat.s8.s32` instruction.
// ROCm: std::clamp to [int8 min, int8 max] followed by a narrowing cast.
static inline __device__ int8_t int32_to_int8(int32_t x) {
#ifdef USE_ROCM
  static constexpr int32_t kInt8Lo =
      static_cast<int32_t>(std::numeric_limits<int8_t>::min());
  static constexpr int32_t kInt8Hi =
      static_cast<int32_t>(std::numeric_limits<int8_t>::max());

  // After the clamp the value is guaranteed to fit in an int8.
  return static_cast<int8_t>(std::clamp(x, kInt8Lo, kInt8Hi));
#else
  // The instruction writes its result into a 32-bit register ("=r").
  uint32_t packed;
  asm volatile("cvt.sat.s8.s32 %0, %1;" : "=r"(packed) : "r"(x));
  // Read the register's storage back as an int8.
  return reinterpret_cast<const int8_t&>(packed);
#endif
}
|
|
|
namespace vllm {

// Common launch layout for every kernel in this namespace:
//   * blockIdx.x selects one row ("token") of `hidden_size` elements;
//   * the threads of the block stride over that row (i = threadIdx.x,
//     threadIdx.x + blockDim.x, ...), so any block size is valid for the
//     element loops;
//   * the dynamic kernels additionally use cub::BlockReduce<float, 1024>
//     with `blockDim.x` as the number of valid inputs. The host launchers in
//     this file never start more than 1024 threads per block.

// Static per-tensor symmetric quantization:
//   out[i] = saturate_int8(round(input[i] / *scale_ptr))
// `scale_ptr` points to a single scale shared by all rows.
template <typename scalar_t, typename scale_type>
__global__ void static_scaled_int8_quant_kernel(
    scalar_t const* __restrict__ input, int8_t* __restrict__ out,
    scale_type const* scale_ptr, const int hidden_size) {
  int const tid = threadIdx.x;
  int64_t const token_idx = blockIdx.x;
  scale_type const scale = *scale_ptr;

  // Move both pointers to this block's row. token_idx is 64-bit so the
  // offset does not overflow for large tensors.
  out += token_idx * hidden_size;
  input += token_idx * hidden_size;

  for (int i = tid; i < hidden_size; i += blockDim.x) {
    out[i] = float_to_int8_rn(static_cast<float>(input[i]) / scale);
  }
}

// Static per-tensor asymmetric quantization:
//   out[i] = saturate_int8(round(input[i] / *scale_ptr) + *azp_ptr)
// `scale_ptr` / `azp_ptr` each point to a single value shared by all rows.
template <typename scalar_t, typename scale_type, typename azp_type>
__global__ void static_scaled_int8_azp_quant_kernel(
    scalar_t const* __restrict__ input, int8_t* __restrict__ out,
    scale_type const* scale_ptr, azp_type const* azp_ptr,
    const int hidden_size) {
  int const tid = threadIdx.x;
  int64_t const token_idx = blockIdx.x;
  scale_type const scale = *scale_ptr;
  azp_type const azp = *azp_ptr;

  // Move both pointers to this block's row.
  out += token_idx * hidden_size;
  input += token_idx * hidden_size;

  for (int i = tid; i < hidden_size; i += blockDim.x) {
    auto const val = static_cast<float>(input[i]);
    // Round in int32 first so that adding the zero point cannot wrap an
    // int8; saturate to int8 only at the very end.
    auto const quant_val = int32_to_int8(float_to_int32_rn(val / scale) + azp);
    out[i] = quant_val;
  }
}

// Dynamic per-token symmetric quantization.
// For each row: absmax = max_i |input[i]|, scale[row] = absmax / 127 and
//   out[i] = saturate_int8(round(input[i] * 127 / absmax)).
// An all-zero row yields scale[row] = 0 and an all-zero output row.
template <typename scalar_t, typename scale_type>
__global__ void dynamic_scaled_int8_quant_kernel(
    scalar_t const* __restrict__ input, int8_t* __restrict__ out,
    scale_type* scale, const int hidden_size) {
  int const tid = threadIdx.x;
  int64_t const token_idx = blockIdx.x;

  // Move both pointers to this block's row.
  out += token_idx * hidden_size;
  input += token_idx * hidden_size;

  // Pass 1: per-thread maximum of |x| over this thread's share of the row.
  float absmax_val = 0.0f;
  for (int i = tid; i < hidden_size; i += blockDim.x) {
    float val = static_cast<float>(input[i]);
    val = val > 0.0f ? val : -val;
    absmax_val = val > absmax_val ? val : absmax_val;
  }

  // Block-wide maximum. Per the CUB documentation the reduced value is only
  // valid in thread 0, so it is published to the block via shared memory.
  using BlockReduce = cub::BlockReduce<float, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStorage;
  float const block_absmax_val_maybe =
      BlockReduce(reduceStorage).Reduce(absmax_val, cub::Max{}, blockDim.x);
  __shared__ float block_absmax_val;
  if (tid == 0) {
    block_absmax_val = block_absmax_val_maybe;
    scale[token_idx] = block_absmax_val / 127.0f;
  }
  __syncthreads();

  // Pass 2: quantize. For an all-zero row absmax is 0; dividing by it would
  // give an infinite multiplier and 0 * inf = NaN for every element, so use
  // a zero multiplier instead (every output element is then exactly 0).
  float const absmax = block_absmax_val;
  float const tmp_scale = absmax > 0.0f ? 127.0f / absmax : 0.0f;
  for (int i = tid; i < hidden_size; i += blockDim.x) {
    out[i] = float_to_int8_rn(static_cast<float>(input[i]) * tmp_scale);
  }
}

// Dynamic per-token asymmetric quantization.
// For each row: scale[row] = (max - min) / 255,
//               azp[row]   = round(-128 - min / scale[row]) and
//   out[i] = saturate_int8(round(input[i] / scale[row]) + azp[row]),
// i.e. `min` maps to -128 and `max` to 127.
//
// Degenerate rows (max == min, so the scale above would be 0): the range is
// widened to include 0; if it is still empty (all-zero row) the scale falls
// back to 1. Either way scale[row] > 0 and (q - azp) * scale reproduces the
// constant input value.
template <typename scalar_t, typename scale_type, typename azp_type>
__global__ void dynamic_scaled_int8_azp_quant_kernel(
    scalar_t const* __restrict__ input, int8_t* __restrict__ out,
    scale_type* scale, azp_type* azp, const int hidden_size) {
  int64_t const token_idx = blockIdx.x;

  // Move both pointers to this block's row.
  out += token_idx * hidden_size;
  input += token_idx * hidden_size;

  // Pass 1: per-thread min / max. The max must start at the most negative
  // finite float: numeric_limits<float>::min() is the smallest *positive*
  // value and would be wrongly reported as the max of an all-negative row.
  float max_val = std::numeric_limits<float>::lowest();
  float min_val = std::numeric_limits<float>::max();
  for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
    auto val = static_cast<float>(input[i]);
    max_val = std::max(max_val, val);
    min_val = std::min(min_val, val);
  }

  // Block-wide max and min. Both reductions share the same temp storage, so
  // a barrier is required between them. Per the CUB documentation the
  // reduced values are only valid in thread 0.
  using BlockReduce = cub::BlockReduce<float, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStorage;
  max_val = BlockReduce(reduceStorage).Reduce(max_val, cub::Max{}, blockDim.x);
  __syncthreads();
  min_val = BlockReduce(reduceStorage).Reduce(min_val, cub::Min{}, blockDim.x);

  __shared__ scale_type scale_sh;
  __shared__ azp_type azp_sh;

  // Thread 0 derives scale and zero point, stores them to global memory and
  // broadcasts them to the block through shared memory.
  if (threadIdx.x == 0) {
    float range_lo = min_val;
    float scale_val = (max_val - min_val) / 255.0f;
    if (!(scale_val > 0.0f)) {
      // Constant row: avoid dividing by a zero scale below.
      range_lo = std::min(min_val, 0.0f);
      scale_val = (std::max(max_val, 0.0f) - range_lo) / 255.0f;
      if (!(scale_val > 0.0f)) {
        scale_val = 1.0f;
      }
    }

    // Zero point such that range_lo maps to -128.
    auto const azp_float = std::nearbyint(-128.0f - range_lo / scale_val);
    auto const azp_val = static_cast<azp_type>(azp_float);

    scale[token_idx] = scale_sh = scale_val;
    azp[token_idx] = azp_sh = azp_val;
  }

  // Make scale_sh / azp_sh visible to every thread of the block.
  __syncthreads();

  float const scale_val = scale_sh;
  azp_type const azp_val = azp_sh;

  // Pass 2: quantize. Round in int32, add the zero point, then saturate.
  for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
    auto const val = static_cast<float>(input[i]);
    auto const quant_val =
        int32_to_int8(float_to_int32_rn(val / scale_val) + azp_val);
    out[i] = quant_val;
  }
}

}  // namespace vllm
|
|
|
// Quantizes `input` to int8 into `out` using one pre-computed scale (and,
// optionally, one zero point) for the whole tensor.
//
//   out   : int8 tensor, contiguous, same number of elements as `input`.
//   input : floating-point tensor, contiguous; the last dimension is treated
//           as the row ("hidden") dimension, all leading dimensions are
//           flattened into rows.
//   scale : float tensor with exactly one element.
//   azp   : optional int32 tensor with exactly one element. When present the
//           asymmetric kernel is used.
//
// One thread block is launched per row on the current CUDA stream. An empty
// `input` is a no-op. Throws (TORCH_CHECK) on a violated precondition or a
// failed kernel launch.
void static_scaled_int8_quant(torch::Tensor& out,
                              torch::Tensor const& input,
                              torch::Tensor const& scale,
                              std::optional<torch::Tensor> const& azp) {
  TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
  TORCH_CHECK(out.is_contiguous(), "out must be contiguous");
  TORCH_CHECK(scale.numel() == 1, "scale must hold exactly one element");
  TORCH_CHECK(!azp || azp->numel() == 1,
              "azp must hold exactly one element");
  // The kernels write one int8 per input element; a smaller `out` would be
  // written out of bounds.
  TORCH_CHECK(out.numel() == input.numel(),
              "out must have the same number of elements as input");

  // Nothing to do for an empty tensor. This also avoids dividing by a zero
  // hidden_size below and launching a kernel with an empty grid.
  if (input.numel() == 0) {
    return;
  }

  int const hidden_size = input.size(-1);
  int const num_tokens = input.numel() / hidden_size;
  dim3 const grid(num_tokens);
  dim3 const block(std::min(hidden_size, 1024));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "static_scaled_int8_quant_kernel", [&] {
        if (!azp) {
          vllm::static_scaled_int8_quant_kernel<scalar_t, float>
              <<<grid, block, 0, stream>>>(
                  input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
                  scale.data_ptr<float>(), hidden_size);
        } else {
          vllm::static_scaled_int8_azp_quant_kernel<scalar_t, float, int32_t>
              <<<grid, block, 0, stream>>>(
                  input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
                  scale.data_ptr<float>(), azp->data_ptr<int32_t>(),
                  hidden_size);
        }
      });
  // Kernel launches do not report errors themselves; surface bad launch
  // configurations here instead of at some unrelated later call.
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
|
|
|
// Quantizes `input` to int8 into `out`, computing a scale (and, optionally,
// a zero point) per row on the fly.
//
//   out    : int8 tensor, contiguous, same number of elements as `input`.
//   input  : floating-point tensor, contiguous; the last dimension is treated
//            as the row ("hidden") dimension, all leading dimensions are
//            flattened into rows.
//   scales : float tensor, contiguous, receives one scale per row (output).
//   azp    : optional int32 tensor, contiguous, receives one zero point per
//            row (output). When present the asymmetric kernel is used.
//
// One thread block is launched per row on the current CUDA stream. An empty
// `input` is a no-op. Throws (TORCH_CHECK) on a violated precondition or a
// failed kernel launch.
void dynamic_scaled_int8_quant(
    torch::Tensor& out,
    torch::Tensor const& input,
    torch::Tensor& scales, std::optional<torch::Tensor> const& azp) {
  TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
  TORCH_CHECK(out.is_contiguous(), "out must be contiguous");
  TORCH_CHECK(scales.is_contiguous(), "scales must be contiguous");
  TORCH_CHECK(!azp || azp->is_contiguous(), "azp must be contiguous");
  // The kernels write one int8 per input element; a smaller `out` would be
  // written out of bounds.
  TORCH_CHECK(out.numel() == input.numel(),
              "out must have the same number of elements as input");

  // Nothing to do for an empty tensor. This also avoids dividing by a zero
  // hidden_size below and launching a kernel with an empty grid.
  if (input.numel() == 0) {
    return;
  }

  int const hidden_size = input.size(-1);
  int const num_tokens = input.numel() / hidden_size;
  // Each block writes scales[row] (and azp[row]); make sure the slot exists.
  TORCH_CHECK(scales.numel() >= num_tokens,
              "scales must hold one element per token");
  TORCH_CHECK(!azp || azp->numel() >= num_tokens,
              "azp must hold one element per token");

  dim3 const grid(num_tokens);
  dim3 const block(std::min(hidden_size, 1024));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "dynamic_scaled_int8_quant_kernel", [&] {
        if (!azp) {
          vllm::dynamic_scaled_int8_quant_kernel<scalar_t, float>
              <<<grid, block, 0, stream>>>(
                  input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
                  scales.data_ptr<float>(), hidden_size);
        } else {
          vllm::dynamic_scaled_int8_azp_quant_kernel<scalar_t, float, int32_t>
              <<<grid, block, 0, stream>>>(
                  input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
                  scales.data_ptr<float>(), azp->data_ptr<int32_t>(),
                  hidden_size);
        }
      });
  // Kernel launches do not report errors themselves; surface bad launch
  // configurations here instead of at some unrelated later call.
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
|
|