[Kernel] Full Tensor Parallelism for LoRA Layers (#3524)
Co-authored-by: Antoni Baum <[email protected]>
FurtherAI and Yard1 authored Apr 27, 2024
1 parent 18d23f6 commit eefeb16
Showing 19 changed files with 686 additions and 111 deletions.
1 change: 1 addition & 0 deletions csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu
@@ -2,3 +2,4 @@
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_bfloat16)
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_bfloat16, nv_bfloat16, nv_bfloat16)
1 change: 1 addition & 0 deletions csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu
@@ -2,3 +2,4 @@
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_bfloat16)
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_bfloat16, float, nv_bfloat16)
78 changes: 78 additions & 0 deletions csrc/punica/bgmv/bgmv_config.h
@@ -74,11 +74,89 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
// Keep above in sync with vllm/lora/layers::LogitsProcessorWithLoRA
// and vllm/tests/lora/test_punica.py

// Used for defining kernels that go from the full variety of input dims
// down to the narrow output dim (the LoRA rank).
// Needed for fully sharded column-parallel LoRA A, which shards the
// rank dimension across tensor-parallel ranks.
#define FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, narrow) \
f(in_T, out_T, W_T, 128, narrow) \
f(in_T, out_T, W_T, 256, narrow) \
f(in_T, out_T, W_T, 512, narrow) \
f(in_T, out_T, W_T, 640, narrow) \
f(in_T, out_T, W_T, 768, narrow) \
f(in_T, out_T, W_T, 1024, narrow) \
f(in_T, out_T, W_T, 1152, narrow) \
f(in_T, out_T, W_T, 1280, narrow) \
f(in_T, out_T, W_T, 1536, narrow) \
f(in_T, out_T, W_T, 1728, narrow) \
f(in_T, out_T, W_T, 1792, narrow) \
f(in_T, out_T, W_T, 2048, narrow) \
f(in_T, out_T, W_T, 2304, narrow) \
f(in_T, out_T, W_T, 2560, narrow) \
f(in_T, out_T, W_T, 2752, narrow) \
f(in_T, out_T, W_T, 2816, narrow) \
f(in_T, out_T, W_T, 3072, narrow) \
f(in_T, out_T, W_T, 3456, narrow) \
f(in_T, out_T, W_T, 3584, narrow) \
f(in_T, out_T, W_T, 4096, narrow) \
f(in_T, out_T, W_T, 4608, narrow) \
f(in_T, out_T, W_T, 5120, narrow) \
f(in_T, out_T, W_T, 5504, narrow) \
f(in_T, out_T, W_T, 5632, narrow) \
f(in_T, out_T, W_T, 6144, narrow) \
f(in_T, out_T, W_T, 6848, narrow) \
f(in_T, out_T, W_T, 6912, narrow) \
f(in_T, out_T, W_T, 7168, narrow) \
f(in_T, out_T, W_T, 8192, narrow) \
f(in_T, out_T, W_T, 9216, narrow) \
f(in_T, out_T, W_T, 10240, narrow) \
f(in_T, out_T, W_T, 11008, narrow) \
f(in_T, out_T, W_T, 12288, narrow) \
f(in_T, out_T, W_T, 13696, narrow) \
f(in_T, out_T, W_T, 13824, narrow) \
f(in_T, out_T, W_T, 14336, narrow) \
f(in_T, out_T, W_T, 15360, narrow) \
f(in_T, out_T, W_T, 16384, narrow) \
f(in_T, out_T, W_T, 20480, narrow) \
f(in_T, out_T, W_T, 22016, narrow) \
f(in_T, out_T, W_T, 24576, narrow) \
f(in_T, out_T, W_T, 27392, narrow) \
f(in_T, out_T, W_T, 28672, narrow) \
f(in_T, out_T, W_T, 32000, narrow) \
f(in_T, out_T, W_T, 32256, narrow) \
f(in_T, out_T, W_T, 32512, narrow) \
f(in_T, out_T, W_T, 32768, narrow) \
f(in_T, out_T, W_T, 33024, narrow) \
f(in_T, out_T, W_T, 36864, narrow) \
f(in_T, out_T, W_T, 43264, narrow) \
f(in_T, out_T, W_T, 49152, narrow) \
f(in_T, out_T, W_T, 64000, narrow) \
f(in_T, out_T, W_T, 64256, narrow) \
f(in_T, out_T, W_T, 64512, narrow) \
f(in_T, out_T, W_T, 102400, narrow) \
f(in_T, out_T, W_T, 102656, narrow) \
f(in_T, out_T, W_T, 102912, narrow) \
f(in_T, out_T, W_T, 128000, narrow) \
f(in_T, out_T, W_T, 128256, narrow) \
f(in_T, out_T, W_T, 128512, narrow) \
// Keep above in sync with vllm/lora/layers::SamplerWithLoRA


// Keep this in sync with vllm/config::LoRAConfig
#define FOR_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 8) \
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 16) \
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 32) \
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 64)


#define FOR_INST_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \
FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 1) \
FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 2) \
FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 4) \
f(in_T, out_T, W_T, 8, 64) \
f(in_T, out_T, W_T, 16, 64) \
f(in_T, out_T, W_T, 32, 64) \
f(in_T, out_T, W_T, 64, 64)

// clang-format on
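For orientation, here is a minimal PyTorch sketch (not part of the commit; all names and shapes are illustrative assumptions) of what the fully sharded column-parallel LoRA path computes with these one-sided kernels: LoRA A is sharded along the rank dimension, so each tensor-parallel rank runs a shrink with a narrow output (the 1, 2, or 4 above), the rank slices are gathered, and LoRA B then expands into each rank's output slice.

# Illustrative sketch (assumption, not the vLLM kernels): fully sharded
# column-parallel LoRA splits LoRA A along the rank dim r and LoRA B along
# the output dim, across tp tensor-parallel ranks.
import torch

torch.manual_seed(0)
bs, h_in, h_out, r, tp = 4, 4096, 4096, 8, 2

x = torch.randn(bs, h_in)
lora_a = torch.randn(r, h_in)      # shrink: h_in -> r
lora_b = torch.randn(h_out, r)     # expand: r -> h_out
scale = 1.0

# Reference (unsharded) LoRA delta.
ref = (x @ lora_a.t()) @ lora_b.t() * scale

# Each rank holds r/tp rows of A and h_out/tp rows of B.
a_shards = lora_a.chunk(tp, dim=0)
b_shards = lora_b.chunk(tp, dim=0)

# Stage 1 per rank: one-sided shrink with narrow output r/tp.
partial_rank = [x @ a.t() for a in a_shards]
# All-gather the rank slices so every rank sees the full (bs, r) activation.
full_rank = torch.cat(partial_rank, dim=1)
# Stage 2 per rank: expand into this rank's slice of the output.
out_shards = [full_rank @ b.t() * scale for b in b_shards]
sharded = torch.cat(out_shards, dim=1)

assert torch.allclose(ref, sharded, atol=1e-4)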
1 change: 1 addition & 0 deletions csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu
@@ -2,3 +2,4 @@
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_half)
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_half, nv_half, nv_half)
1 change: 1 addition & 0 deletions csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu
@@ -2,3 +2,4 @@
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_half)
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_half, float, nv_half)
1 change: 1 addition & 0 deletions csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu
@@ -2,3 +2,4 @@
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_bfloat16)
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, float, nv_bfloat16, nv_bfloat16)
1 change: 1 addition & 0 deletions csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu
@@ -2,3 +2,4 @@
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_half)
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, float, nv_half, nv_half)
5 changes: 4 additions & 1 deletion csrc/punica/bgmv/bgmv_impl.cuh
@@ -199,7 +199,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
constexpr int tz = 4;
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

if constexpr (feat_in < feat_out) {
if constexpr (feat_in <= feat_out) {
static_assert(feat_in % vec_size == 0);
constexpr int tx = feat_in / vec_size;

@@ -289,6 +289,9 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
int64_t y_offset, int64_t full_y_size, int64_t batch_size, \
int64_t num_layers, int64_t layer_idx, float scale);

#define INST_BGMV_ONESIDE(in_T, out_T, W_T, feat_in, feat_out) \
INST_BGMV(feat_in, feat_out, in_T, out_T, W_T)

#define INST_BGMV_TWOSIDE(in_T, out_T, W_T, narrow, wide) \
INST_BGMV(narrow, wide, in_T, out_T, W_T) \
INST_BGMV(wide, narrow, in_T, out_T, W_T)
1 change: 1 addition & 0 deletions csrc/punica/bgmv/generator.py
@@ -10,6 +10,7 @@
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, {input_dtype}, {output_dtype}, {weight_dtype})
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, {input_dtype}, {output_dtype}, {weight_dtype})
""".lstrip() # noqa: E501

for input_dtype in DTYPES:
2 changes: 1 addition & 1 deletion csrc/punica/punica_ops.cc
@@ -79,12 +79,12 @@ inline bool launch_bgmv_kernel(out_T *Y, const in_T *X, const W_T *W,
CASE_ONESIDE(in_T, out_T, W_T, wide, narrow)

FOR_BGMV_WIDE_NARROW(CASE, _, _, _)
FOR_INST_BGMV_WIDE_NARROW(CASE_ONESIDE, _, _, _)
#undef CASE
#undef CASE_ONESIDE
default:
return false;
}

return true;
}

29 changes: 23 additions & 6 deletions tests/lora/test_layers.py
@@ -8,6 +8,10 @@
import torch.nn.functional as F

from vllm.config import LoRAConfig
from vllm.lora.fully_sharded_layers import (
ColumnParallelLinearWithShardedLoRA,
MergedColumnParallelLinearWithShardedLoRA,
MergedQKVParallelLinearWithShardedLora, RowParallelLinearWithShardedLoRA)
# yapf conflicts with isort for this block
# yapf: disable
from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
@@ -524,13 +528,16 @@ def _pretest():
@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("orientation", ["row", "column"])
@pytest.mark.parametrize("fully_shard", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_linear_parallel(dist_init, num_loras, orientation, device) -> None:
def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
device) -> None:

torch.set_default_device(device)
max_loras = 8
lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8,
fully_sharded_loras=fully_shard,
lora_dtype=torch.float16)

def create_random_linear_parallel_layer():
@@ -540,14 +547,17 @@ def create_random_linear_parallel_layer():
bias=False,
params_dtype=torch.float16)
linear.weight.data = torch.rand_like(linear.weight.data)
lora_linear = RowParallelLinearWithLoRA(linear)
lora_linear = (RowParallelLinearWithLoRA(linear) if not fully_shard
else RowParallelLinearWithShardedLoRA(linear))
else:
linear = ColumnParallelLinear(4096,
4096,
bias=False,
params_dtype=torch.float16)
linear.weight.data = torch.rand_like(linear.weight.data)
lora_linear = ColumnParallelLinearWithLoRA(linear)
lora_linear = (ColumnParallelLinearWithLoRA(linear)
if not fully_shard else
ColumnParallelLinearWithShardedLoRA(linear))
lora_linear.create_lora_weights(max_loras, lora_config)

return linear, lora_linear
@@ -629,13 +639,16 @@ def create_random_linear_parallel_layer():
@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("repeats", [1, 2, 3])
@pytest.mark.parametrize("fully_shard", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_column_parallel_packed(dist_init, num_loras, repeats, device) -> None:
def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
device) -> None:

torch.set_default_device(device)
max_loras = 8
lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8,
fully_sharded_loras=fully_shard,
lora_dtype=torch.float16)

def create_column_parallel_packed_layer():
@@ -644,15 +657,19 @@ def create_column_parallel_packed_layer():
bias=False,
params_dtype=torch.float16)
linear.weight.data = torch.rand_like(linear.weight.data)
lora_linear = MergedColumnParallelLinearWithLoRA(linear)
lora_linear = (MergedColumnParallelLinearWithLoRA(linear)
if not fully_shard else
MergedColumnParallelLinearWithShardedLoRA(linear))
elif repeats == 3:
linear = QKVParallelLinear(4096,
64,
32,
bias=False,
params_dtype=torch.float16)
linear.weight.data = torch.rand_like(linear.weight.data)
lora_linear = MergedQKVParallelLinearWithLora(linear)
lora_linear = (MergedQKVParallelLinearWithLora(linear)
if not fully_shard else
MergedQKVParallelLinearWithShardedLora(linear))
else:
linear = QKVParallelLinear(4096,
64,
51 changes: 49 additions & 2 deletions tests/lora/test_punica.py
@@ -34,11 +34,14 @@ def _lora_ref_impl(
for i, lora_idx in zip(range(bs), indicies.cpu().tolist()):
xi = x[i].unsqueeze(0).to(torch.float32)
wa = wa_T_all[lora_idx, layer_idx].transpose(-1, -2).to(torch.float32)
wb = wb_T_all[lora_idx, layer_idx].transpose(-1, -2).to(torch.float32)
if wb_T_all is not None:
wb = wb_T_all[lora_idx, layer_idx].transpose(-1,
-2).to(torch.float32)

tmp = xi @ wa
y_stage_1[i] = tmp.squeeze(0)
y_final[i] += (tmp @ wb).squeeze(0) * s
y_final[i] += ((tmp @ wb).squeeze(0) *
s if wb_T_all is not None else y_stage_1[i])
return y_final, y_stage_1
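For clarity on the new branch: when wb_T_all is None the reference applies only the LoRA A shrink, which is what the new single-sided kernels compute for a sharded LoRA A slice. A small hedged sketch of that path (shapes arbitrary, not taken from the commit):

# Hedged sketch of the LoRA-A-only ("shrink") path exercised by the new
# wb_T_all=None branch: the output width equals the LoRA rank.
import torch

num_loras, bs, h1, r = 4, 32, 128, 8
wa_T_all = torch.randn(num_loras, 1, r, h1)          # (loras, layers, r, h1)
x = torch.randn(bs, h1)
y = torch.zeros(bs, r)
indices = torch.randint(num_loras, (bs,))

for i, lora_idx in zip(range(bs), indices.tolist()):
    wa = wa_T_all[lora_idx, 0].transpose(-1, -2)     # (h1, r)
    y[i] += (x[i].unsqueeze(0) @ wa).squeeze(0)      # shrink only, no LoRA B

print(y.shape)  # torch.Size([32, 8])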


@@ -91,12 +94,56 @@ def _lora_ref_impl(
128000,
128256,
]
H2 = [64] + H2
R = [1, 2, 4]
SEED = [0xabcdabcd987]
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]


@pytest.mark.parametrize("dtype_str", ["float16", "bfloat16"])
@pytest.mark.parametrize("h1", H1)
@pytest.mark.parametrize("r", R)
@pytest.mark.parametrize("seed", SEED)
@torch.inference_mode()
def test_lora_a_extra_shapes(dtype_str, h1, r, seed):
torch.manual_seed(seed)
num_loras = 4
num_layers = 1
bs = 32
dtype = getattr(torch, dtype_str)
device = torch.device("cuda")

wa_T_all = torch.randn(num_loras,
num_layers,
r,
h1,
dtype=dtype,
device=device)
indices = torch.randint(num_loras, (bs, ), dtype=torch.long, device=device)

for layer_idx in range(num_layers):
x = torch.randn(bs, h1, dtype=dtype, device=device)
y = torch.randn(bs, r, dtype=dtype, device=device)

y_ref = y.clone()
_lora_ref_impl(
y_ref,
x,
wa_T_all,
None,
indices,
layer_idx,
1.0,
)

y_our = y.clone()
punica.bgmv(y_our, x, wa_T_all, indices, layer_idx, 1.0)

assert_close(y_ref, y_our)


@pytest.mark.parametrize("dtype_str", ["float16", "bfloat16"])
@pytest.mark.parametrize("h1", H1)
@pytest.mark.parametrize("h2", H2)
1 change: 1 addition & 0 deletions vllm/config.py
@@ -862,6 +862,7 @@ def __repr__(self) -> str:
class LoRAConfig:
max_lora_rank: int
max_loras: int
fully_sharded_loras: bool = False
max_cpu_loras: Optional[int] = None
lora_dtype: Optional[torch.dtype] = None
lora_extra_vocab_size: int = 256
10 changes: 10 additions & 0 deletions vllm/engine/arg_utils.py
@@ -52,6 +52,7 @@ class EngineArgs:
enable_lora: bool = False
max_loras: int = 1
max_lora_rank: int = 16
fully_sharded_loras: bool = False
lora_extra_vocab_size: int = 256
lora_dtype = 'auto'
max_cpu_loras: Optional[int] = None
@@ -376,6 +377,14 @@ def add_cli_args(
help=('Maximum number of LoRAs to store in CPU memory. '
'Must be >= max_num_seqs. '
'Defaults to max_num_seqs.'))
parser.add_argument(
'--fully-sharded-loras',
action='store_true',
help=('By default, only half of the LoRA computation is '
'sharded with tensor parallelism. '
'Enabling this will use the fully sharded layers. '
'At high sequence length, max rank or '
'tensor parallel size, this is likely faster.'))
parser.add_argument("--device",
type=str,
default=EngineArgs.device,
@@ -509,6 +518,7 @@ def create_engine_config(self, ) -> EngineConfig:
lora_config = LoRAConfig(
max_lora_rank=self.max_lora_rank,
max_loras=self.max_loras,
fully_sharded_loras=self.fully_sharded_loras,
lora_extra_vocab_size=self.lora_extra_vocab_size,
lora_dtype=self.lora_dtype,
max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
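A hedged usage sketch (field values are arbitrary examples) of how the new flag reaches the config: the --fully-sharded-loras CLI flag above is forwarded by create_engine_config into LoRAConfig.fully_sharded_loras, and offline callers can set the same field directly.

# Hedged sketch mirroring what `--fully-sharded-loras` does via EngineArgs;
# the field values here are arbitrary examples.
import torch
from vllm.config import LoRAConfig

lora_config = LoRAConfig(
    max_lora_rank=64,           # larger ranks are where full sharding helps most
    max_loras=8,
    fully_sharded_loras=True,   # shard both the LoRA A and LoRA B computation
    lora_dtype=torch.float16,
)
print(lora_config.fully_sharded_loras)  # True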
(Diffs for the remaining changed files in this commit are not shown here.)
