import torch
import triton
import triton.language as tl
# https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/rope.py
# BSD 2-CLAUSE LICENSE
# Copyright 2024 LinkedIn Corporation
# All Rights Reserved.
# Redistribution and use in source and binary forms, with or
# without modification, are permitted provided that the following
# conditions are met:
# 1. Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following
# disclaimer in the documentation and/or other materials provided
# with the distribution.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
@triton.jit
def _triton_rope(
q_ptr,
q_row_stride,
k_ptr,
k_row_stride,
cos,
cos_row_stride,
sin,
sin_row_stride,
sl,
bs: tl.constexpr,
cos_bs: tl.constexpr,
n_qh: tl.constexpr,
n_kh: tl.constexpr,
hd: tl.constexpr,
pad_n_qh: tl.constexpr,
pad_n_kh: tl.constexpr,
pad_hd: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
BACKWARD_PASS: tl.constexpr = False,
):
# q size: (bsz, seq_len, num_q_heads, head_dim)
# q stride: (seq_len * num_q_heads * head_dim, num_q_heads * head_dim, head_dim, 1)
# k size: (bsz, seq_len, num_kv_heads, head_dim)
# k stride: (seq_len * num_kv_heads * head_dim, num_kv_heads * head_dim, head_dim, 1)
# cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
# stride: (seq_len * head_dim, head_dim, 1)
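    # e.g. (illustrative shapes) a contiguous q of shape (2, 128, 32, 64) has
    # q_row_stride = 32 * 64 = 2048, so each program instance owns the
    # (num_q_heads, head_dim) slab of exactly one token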
pid = tl.program_id(0)
# locate start address
q_ptr = q_ptr + pid * q_row_stride
k_ptr = k_ptr + pid * k_row_stride
# ####################################################################
# get the cos(mθ_{i...d/2}) and sin(mθ_{i...d/2}) for token position
# m of this program instance
# ####################################################################
# 1. program instances are laid out in a 1D vector of size bsz * seq_len, which
# effectively represents a 2D grid of size [bsz, seq_len] with seq_len dimension
# being the fastest changing dimension. Thus we can simply do pid // sl to get the batch index
# and pid % sl to get the sequence index.
# 2. We only need the left half of cos and sin matrix because the right half is just
# a clone of the left half.
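    # e.g. with bs = 2 and sl = 3 the grid has 6 programs; pid = 4 handles
    # batch_idx = 4 // 3 = 1 and cos_row_idx = 4 % 3 = 1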
batch_idx = pid // sl
cos_row_idx = pid % sl
cos = cos + tl.where(
cos_bs == 1,
cos_row_idx * cos_row_stride,
batch_idx * (sl * cos_row_stride) + cos_row_idx * cos_row_stride,
)
sin = sin + tl.where(
cos_bs == 1,
cos_row_idx * sin_row_stride,
batch_idx * (sl * sin_row_stride) + cos_row_idx * sin_row_stride,
)
cos_offsets = tl.arange(0, pad_hd // 2)
cos_mask = cos_offsets < hd // 2
cos_row = tl.load(cos + cos_offsets, mask=cos_mask, other=0)
sin_row = tl.load(sin + cos_offsets, mask=cos_mask, other=0)
# ####################################################################
# Load the left and right half of q and k for the current
# program instance (i.e. for the current token) separately
# ####################################################################
# left half of the head
first_half_q_offsets = tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)
first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)
q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(sin_row.dtype)
k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(sin_row.dtype)
# right half of the head
second_half_q_offsets = first_half_q_offsets + (hd // 2)
second_half_k_offsets = first_half_k_offsets + (hd // 2)
second_q_mask = first_q_mask
second_k_mask = first_k_mask
q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(sin_row.dtype)
k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(sin_row.dtype)
if not BACKWARD_PASS:
# y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
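        # i.e. apply the 2D rotation [[cos, -sin], [sin, cos]] to each pair
        # (x_i, x_{i + hd/2}); cos_row/sin_row broadcast across all heads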
new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row
tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row
tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)
new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row
tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)
new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row
tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)
else:
        # the backward pass applies the transpose (inverse) of the forward rotation,
        # i.e. the same formula with the sign of sin flipped:
        # dx = [dy1, dy2] * [cos, cos] + [dy2, -dy1] * [sin, sin]
new_q_tile_1 = q_tile_1 * cos_row + q_tile_2 * sin_row
tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
new_q_tile_2 = q_tile_2 * cos_row - q_tile_1 * sin_row
tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)
new_k_tile_1 = k_tile_1 * cos_row + k_tile_2 * sin_row
tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)
new_k_tile_2 = k_tile_2 * cos_row - k_tile_1 * sin_row
tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)


def rope_forward(q, k, cos, sin):
    # q and k are expected in the physical (bsz, seq_len, n_heads, head_dim) layout;
    # the kernel walks the flattened bsz * seq_len token rows of the contiguous storage
batch_size, seq_len, n_q_head, head_dim = q.shape
n_kv_head = k.shape[2]
pad_hd = triton.next_power_of_2(head_dim)
pad_n_q_head = triton.next_power_of_2(n_q_head)
pad_n_kv_head = triton.next_power_of_2(n_kv_head)
BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head)
n_row = batch_size * seq_len
    # ensure the tensors passed to the kernel are contiguous (a no-op if they already are);
    # note that the kernel writes the rotated values back into these buffers in place
q = q.contiguous()
k = k.contiguous()
cos = cos.contiguous()
sin = sin.contiguous()
cos_batch_size = cos.shape[0]
_triton_rope[(n_row,)](
q,
q.stride(1),
k,
k.stride(1),
cos,
cos.stride(-2),
sin,
sin.stride(-2),
seq_len,
batch_size,
cos_batch_size,
n_q_head,
n_kv_head,
head_dim,
pad_n_q_head,
pad_n_kv_head,
pad_hd,
BLOCK_SIZE=BLOCK_SIZE,
BACKWARD_PASS=False,
)
return q, k, cos, sin


def rope_backward(dq, dk, cos, sin):
batch_size, seq_len, n_q_head, head_dim = dq.shape
cos_batch_size = cos.shape[0]
n_kv_head = dk.shape[2]
pad_hd = triton.next_power_of_2(head_dim)
pad_n_q_head = triton.next_power_of_2(n_q_head)
pad_n_kv_head = triton.next_power_of_2(n_kv_head)
BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head)
n_row = batch_size * seq_len
# ensure dq and dk are contiguous
dq = dq.contiguous()
dk = dk.contiguous()
    # the backward pass reuses the same kernel with BACKWARD_PASS=True,
    # which applies the inverse rotation (sign of sin flipped)
_triton_rope[(n_row,)](
dq,
dq.stride(1),
dk,
dk.stride(1),
cos,
cos.stride(-2),
sin,
sin.stride(-2),
seq_len,
batch_size,
cos_batch_size,
n_q_head,
n_kv_head,
head_dim,
pad_n_q_head,
pad_n_kv_head,
pad_hd,
BLOCK_SIZE=BLOCK_SIZE,
BACKWARD_PASS=True,
)
return dq, dk


class LigerRopeFunction(torch.autograd.Function):
"""
    Triton implementation of the Rotary Positional Embedding (RoPE) operation. Note that
    this implements the HuggingFace Llama & Mistral version, whose rotation matrix differs
    slightly from the one in the original RoPE paper.
Please find the corresponding HuggingFace implementation here:
https://github.com/huggingface/transformers/blob/v4.40.2/src/transformers/models/llama/modeling_llama.py#L184
For more details about the rotation matrix used here, please refer to:
https://discuss.huggingface.co/t/is-llama-rotary-embedding-implementation-correct/44509/2
"""
@staticmethod
def forward(ctx, q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""
        q size: (bsz, seq_len, n_q_head, head_dim)
        k size: (bsz, seq_len, n_kv_head, head_dim)
cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
sin size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
"""
q, k, cos, sin = rope_forward(q, k, cos, sin)
ctx.save_for_backward(cos, sin)
return q, k
    @staticmethod
    def backward(ctx, dq, dk):
"""
        dq size: (bsz, seq_len, n_q_head, head_dim)
        dk size: (bsz, seq_len, n_kv_head, head_dim)
cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
sin size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
"""
cos, sin = ctx.saved_tensors
dq, dk = rope_backward(dq, dk, cos, sin)
return dq, dk, None, None, None, None
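

# ---------------------------------------------------------------------------
# Minimal self-check sketch, assuming q/k in the (bsz, seq_len, n_heads, head_dim)
# layout expected by rope_forward above and a cos/sin table built the way
# HuggingFace rotary embeddings build it (d/2 frequencies duplicated along the
# last dim). Shapes are illustrative, not prescribed by the kernel; requires a
# CUDA device since the kernel runs through Triton.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    def rotate_half(x):
        # split the head dim into two halves and rotate: (x1, x2) -> (-x2, x1)
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    torch.manual_seed(0)
    device = "cuda"
    bsz, seq_len, n_q_head, n_kv_head, head_dim = 2, 128, 8, 2, 64

    q = torch.randn(bsz, seq_len, n_q_head, head_dim, device=device)
    k = torch.randn(bsz, seq_len, n_kv_head, head_dim, device=device)

    # rotary table: theta_i = 10000^(-2i/d), duplicated across the two halves
    inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim))
    positions = torch.arange(seq_len, device=device).float()
    freqs = torch.outer(positions, inv_freq)                  # (seq_len, head_dim // 2)
    emb = torch.cat((freqs, freqs), dim=-1)                   # (seq_len, head_dim)
    cos, sin = emb.cos()[None, :, :], emb.sin()[None, :, :]   # (1, seq_len, head_dim)

    # the kernel writes in place, so keep copies around for the reference check
    q_ref_in, k_ref_in = q.clone(), k.clone()
    q_out, k_out = LigerRopeFunction.apply(q, k, cos, sin)

    # reference: y = x * cos + rotate_half(x) * sin, broadcast over the head axis
    cos_b, sin_b = cos[:, :, None, :], sin[:, :, None, :]
    q_ref = q_ref_in * cos_b + rotate_half(q_ref_in) * sin_b
    k_ref = k_ref_in * cos_b + rotate_half(k_ref_in) * sin_b

    print("max |q_triton - q_ref|:", (q_out - q_ref).abs().max().item())
    print("max |k_triton - k_ref|:", (k_out - k_ref).abs().max().item())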