support static cachekv quant
RichardWooSJTU committed Dec 12, 2023
1 parent ae59935 commit 1ecd7e2
Showing 5 changed files with 169 additions and 21 deletions.
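Before the per-file diffs, here is a minimal sketch (not part of this commit) of the static cache-KV int8 scheme the changes below wire up. It assumes per-head absmax calibration: the 127 / absmax quant scale and its reciprocal dequant ("out") scale mirror the new CacheScaleLoader in paddlenlp/experimental/model_utils.py, and the ±127 bounds match quant_max_bound / quant_min_bound in FusedMultiTransformerConfig. All names in the snippet are illustrative only.

import numpy as np

def static_cachekv_int8_roundtrip(x, absmax_per_head):
    """Sketch of static cache-KV int8 quantization for one layer's K (or V) cache."""
    # per-head quantization scale: 127 / absmax, as computed by CacheScaleLoader below
    quant_scale = 127.0 / np.asarray(absmax_per_head, dtype="float32")
    # per-head dequantization ("out") scale is simply the reciprocal
    out_scale = 1.0 / quant_scale
    # broadcast the per-head scales over the remaining cache axes
    bshape = (-1,) + (1,) * (x.ndim - 1)
    q = np.clip(np.rint(x * quant_scale.reshape(bshape)), -127.0, 127.0).astype("int8")
    # dequantized approximation the attention kernel works with at decode time
    x_hat = q.astype("float32") * out_scale.reshape(bshape)
    return q, x_hat

# usage: the per-head absmax statistics come from offline calibration (the cache scales JSON)
k_cache = np.random.randn(8, 64, 128).astype("float32")  # [num_heads, seq_len, head_dim]
k_int8, k_approx = static_cachekv_int8_roundtrip(k_cache, np.abs(k_cache).max(axis=(1, 2)))

Dynamic mode uses the same quant/dequant relation, but the scales live in runtime-updated buffers (k_quant_scales / v_quant_scales, k_dequant_scales / v_dequant_scales) rather than being loaded from a calibration file.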
32 changes: 17 additions & 15 deletions llm/predictor.py
@@ -102,9 +102,10 @@ class PredictorArgument:
)
block_attn: bool = field(default=False, metadata={"help": "whether use block attention"})
block_size: int = field(default=64, metadata={"help": "the block size for cache_kvs."})
use_cachekv_int8: bool = field(
default=False,
metadata={"help": "If use_cachekv_int8 set as `True`, dynamic cache kv quantization will be applied"},)
use_cachekv_int8: str = field(
default="None",
metadata={"help": "If use_cachekv_int8 set as `dynamic`, dynamic cache kv quantization will be applied; if set as `static`, static cache kv will be applied"},)

chat_template: str = field(
default=None,
metadata={
@@ -708,7 +709,7 @@ def __init__(
self.architectures = self.model_config.architectures[0].lower()

self.dtype = config.dtype or self.model_config
if config.use_cachekv_int8:
if config.use_cachekv_int8 == "dynamic":
self.cache_kvs = [paddle.zeros(shape, dtype="uint8") for shape in self.cache_kvs_shape]
else:
self.cache_kvs = [paddle.zeros(shape, dtype=self.dtype) for shape in self.cache_kvs_shape]
@@ -734,7 +735,7 @@ def __init__(
pre_cache_mask[:, :, :, self.pre_cache_length:] = paddle.tril(paddle.ones(shape=[config.batch_size, 1, config.src_length, config.src_length], dtype=self.dtype))
self.inputs["src_mask"] = (pre_cache_mask - 1) * 1e4

if config.use_cachekv_int8:
if config.use_cachekv_int8 == "dynamic":
self.k_quant_scales = [
paddle.zeros([self.num_attention_heads], dtype="float32") for _ in range(self.num_layers)
]
@@ -757,7 +758,7 @@ def __init__(
self.inputs["frequency_score"] = paddle.full(shape=[config.batch_size, 1], fill_value=0.0, dtype="float32")
self.inputs["presence_score"] = paddle.full(shape=[config.batch_size, 1], fill_value=0.0, dtype="float32")

if config.use_cachekv_int8:
if config.use_cachekv_int8 == "dynamic":
self.inputs["k_quant_scales"] = self.k_quant_scales
self.inputs["v_quant_scales"] = self.v_quant_scales
self.inputs["k_dequant_scales"] = self.k_dequant_scales
@@ -916,22 +917,22 @@ def __init__(
self.inputs["src_mask"] = (pre_cache_mask - 1) * 1e4

self.cache_kvs = {}
if not config.use_cachekv_int8:
if config.use_cachekv_int8 == "dynamic":
for i in range(len(self.cache_kvs_shape) // 2):
self.cache_kvs["key_caches_{}".format(i)] = paddle.zeros(
self.cache_kvs_shape[2 * i], dtype=config.dtype
)
self.cache_kvs["key_caches_{}".format(i)] = paddle.zeros(self.cache_kvs_shape[2 * i], dtype="uint8")
self.cache_kvs["value_caches_{}".format(i)] = paddle.zeros(
self.cache_kvs_shape[2 * i + 1], dtype=config.dtype
self.cache_kvs_shape[2 * i + 1], dtype="uint8"
)
else:
for i in range(len(self.cache_kvs_shape) // 2):
self.cache_kvs["key_caches_{}".format(i)] = paddle.zeros(self.cache_kvs_shape[2 * i], dtype="uint8")
self.cache_kvs["key_caches_{}".format(i)] = paddle.zeros(
self.cache_kvs_shape[2 * i], dtype=config.dtype
)
self.cache_kvs["value_caches_{}".format(i)] = paddle.zeros(
self.cache_kvs_shape[2 * i + 1], dtype="uint8"
self.cache_kvs_shape[2 * i + 1], dtype=config.dtype
)

if config.use_cachekv_int8:
if config.use_cachekv_int8 == "dynamic":
self.k_quant_scales = [
paddle.zeros([self.num_attention_heads], dtype="float32") for _ in range(self.num_layers)
]
Expand Down Expand Up @@ -1003,7 +1004,7 @@ def __init__(


for i in range(self.num_layers):
if self.config.use_cachekv_int8:
if self.config.use_cachekv_int8 == "dynamic":
self.inputs["k_quant_scales_" + str(i)] = self.k_quant_scales[i]
self.inputs["v_quant_scales_" + str(i)] = self.v_quant_scales[i]
self.inputs["k_dequant_scales_" + str(i)] = self.k_dequant_scales[i]
@@ -1319,6 +1320,7 @@ def create_predictor(
if predictor_args.block_attn:
config.block_size = predictor_args.block_size
config.max_seq_len = predictor_args.src_length
config.use_dynamic_cachekv_quant = predictor_args.use_cachekv_int8 == "dynamic"
from paddlenlp.experimental.transformers import (
LlamaForCausalLMBlockInferenceModel as LlamaInferenceModel,
)
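End of llm/predictor.py. As a quick cross-reference, here is a hedged summary of how the three accepted values of use_cachekv_int8 appear to be interpreted across this diff; the helper below is hypothetical and only restates the branches added above and in the files that follow.

def resolve_cachekv_mode(use_cachekv_int8: str) -> dict:
    """Hypothetical helper; not part of the commit."""
    if use_cachekv_int8 == "dynamic":
        # uint8 caches, per-head quant/dequant scales computed at runtime
        return {"use_dynamic_cachekv_quant": True, "scale_source": "runtime buffers"}
    if use_cachekv_int8 == "static":
        # per-head scales calibrated offline and loaded via CacheScaleLoader
        return {"use_dynamic_cachekv_quant": False, "scale_source": "calibration JSON"}
    # default "None": cache-KV int8 quantization disabled
    return {"use_dynamic_cachekv_quant": False, "scale_source": None}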
28 changes: 28 additions & 0 deletions paddlenlp/experimental/model_utils.py
@@ -396,3 +396,31 @@ def __init__(
self.scale["ffn1_weight_scale"].append(
np.concatenate([self.scale["ffn1_1_weight_scale"][i, :], self.scale["ffn1_2_weight_scale"][i, :]])
)


class CacheScaleLoader:
def __init__(
self,
scale_json_file_path="cache_scales.json",
key_map_dict=None,
num_of_layers=None,
num_heads=None
):
with open(scale_json_file_path) as json_file:
self.scale_dict = json.load(json_file)
self.key_map = key_map_dict
self.scale = {}
for scale_type, key_template in self.key_map.items():
print("scale_type: ", scale_type)
print("key_template: ", key_template)
if ("cache_k" in scale_type):
scale_type_out = "cache_k_out_scale"
else:
scale_type_out = "cache_v_out_scale"
self.scale[scale_type] = np.full([num_of_layers, num_heads], fill_value=-1.0)
self.scale[scale_type_out] = np.full([num_of_layers, num_heads], fill_value=-1.0)

for i in range(num_of_layers):
if key_template.replace("#", str(i)) in self.scale_dict.keys():
self.scale[scale_type][i, :] = [127.0 / num for num in self.scale_dict[key_template.replace("#", str(i))]]
self.scale[scale_type_out][i, :] = [1.0 / self.scale[scale_type][i, j] for j in range(num_heads)]
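A hedged usage sketch for the new CacheScaleLoader. The JSON file contents, the key template, and the layer/head counts below are hypothetical; they are chosen only to match the "#" layer-index substitution and the 127 / absmax convention visible in the class above.

import json
from paddlenlp.experimental.model_utils import CacheScaleLoader

# hypothetical calibration file: one list of per-head absmax values per layer
cache_scales = {
    "llama.layers.0.cachek_matmul.activation_quanter": [2.1, 1.7, 3.0, 2.4],
    "llama.layers.1.cachek_matmul.activation_quanter": [1.9, 2.2, 2.8, 2.0],
}
with open("cache_scales.json", "w") as f:
    json.dump(cache_scales, f)

# hypothetical key map: scale name -> key template, "#" stands for the layer index
key_map = {"cache_k_scale": "llama.layers.#.cachek_matmul.activation_quanter"}

loader = CacheScaleLoader(
    "cache_scales.json",
    key_map_dict=key_map,
    num_of_layers=2,
    num_heads=4,
)
print(loader.scale["cache_k_scale"])      # 127 / absmax, shape [num_of_layers, num_heads]
print(loader.scale["cache_k_out_scale"])  # elementwise reciprocal of the above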
85 changes: 82 additions & 3 deletions paddlenlp/experimental/transformers/fused_transformer_layers.py
@@ -181,6 +181,10 @@ def __init__(
linear_smooth_attrs=None,
ffn2_shift_attrs=None,
ffn2_smooth_attrs=None,
cache_k_scale_attrs=None,
cache_v_scale_attrs=None,
cache_k_out_scale_attrs=None,
cache_v_out_scale_attrs=None,
quant_round_type=0,
quant_max_bound=127.0,
quant_min_bound=-127.0,
@@ -191,6 +195,7 @@ def __init__(
trans_qkvw=True,
ring_id=-1,
kv_num_heads=-1,
use_dynamic_cachekv_quant=True,
):
self.embed_dim = embed_dim
self.num_heads = num_heads
@@ -232,9 +237,15 @@ def __init__(
self.linear_smooth_attrs = linear_smooth_attrs
self.ffn2_shift_attrs = ffn2_shift_attrs
self.ffn2_smooth_attrs = ffn2_smooth_attrs
self.cache_k_scale_attrs = cache_k_scale_attrs
self.cache_v_scale_attrs = cache_v_scale_attrs
self.cache_k_out_scale_attrs = cache_k_out_scale_attrs
self.cache_v_out_scale_attrs = cache_v_out_scale_attrs

self.quant_round_type = quant_round_type
self.quant_max_bound = quant_max_bound
self.quant_min_bound = quant_min_bound
self.use_dynamic_cachekv_quant = use_dynamic_cachekv_quant

self.epsilon = epsilon
self.residual_alpha = residual_alpha
@@ -249,6 +260,8 @@ def __init__(
class FusedMultiTransformerBase(Layer):
def __init__(self, config: FusedMultiTransformerConfig):
super().__init__()

self.config = config

assert config.embed_dim > 0, "Expected embed_dim to be greater than 0, " "but received {}".format(
config.embed_dim
@@ -305,6 +318,8 @@ def __init__(self, config: FusedMultiTransformerConfig):
self.ffn_ln_scales, self.ffn_ln_biases = [], []
self.ffn1_weights, self.ffn1_biases = [], []
self.ffn2_weights, self.ffn2_biases = [], []
self.cache_k_scales, self.cache_v_scales = [], []
self.cache_k_out_scales, self.cache_v_out_scales = [], []

for i in range(self.num_layers):
ln_scale_attr = self.get_attr(config.ln_scale_attrs, i)
@@ -321,6 +336,11 @@ def __init__(self, config: FusedMultiTransformerConfig):
ffn1_bias_attr = self.get_attr(config.ffn1_bias_attrs, i)
ffn2_weight_attr = self.get_attr(config.ffn2_weight_attrs, i)
ffn2_bias_attr = self.get_attr(config.ffn2_bias_attrs, i)

cache_k_scale_attr = self.get_attr(config.cache_k_scale_attrs, i)
cache_v_scale_attr = self.get_attr(config.cache_v_scale_attrs, i)
cache_k_out_scale_attr = self.get_attr(config.cache_k_out_scale_attrs, i)
cache_v_out_scale_attr = self.get_attr(config.cache_v_out_scale_attrs, i)

ln_scale = self.create_parameter(
attr=ln_scale_attr,
@@ -419,6 +439,42 @@ def __init__(self, config: FusedMultiTransformerConfig):
dtype=self._dtype,
is_bias=True,
)

cache_k_scale = None
if cache_k_scale_attr:
cache_k_scale = self.create_parameter(
shape=[config.num_heads],
attr=cache_k_scale_attr,
dtype='float32',
is_bias=False,
)

cache_v_scale = None
if cache_v_scale_attr:
cache_v_scale = self.create_parameter(
shape=[config.num_heads],
attr=cache_v_scale_attr,
dtype='float32',
is_bias=False,
)

cache_k_out_scale = None
if cache_k_out_scale_attr:
cache_k_out_scale = self.create_parameter(
shape=[config.num_heads],
attr=cache_k_out_scale_attr,
dtype='float32',
is_bias=False,
)

cache_v_out_scale = None
if cache_v_out_scale_attr:
cache_v_out_scale = self.create_parameter(
shape=[config.num_heads],
attr=cache_v_out_scale_attr,
dtype='float32',
is_bias=False,
)

# tensor model parallel
if config.nranks > 1:
@@ -444,6 +500,11 @@ def __init__(self, config: FusedMultiTransformerConfig):
self.ffn1_biases.append(ffn1_bias)
self.ffn2_weights.append(ffn2_weight)
self.ffn2_biases.append(ffn2_bias)

self.cache_k_scales.append(cache_k_scale)
self.cache_v_scales.append(cache_v_scale)
self.cache_k_out_scales.append(cache_k_out_scale)
self.cache_v_out_scales.append(cache_v_out_scale)

self._add_parameter(ln_scale)
self._add_parameter(ln_bias)
@@ -458,6 +519,11 @@ def __init__(self, config: FusedMultiTransformerConfig):
self._add_parameter(ffn1_bias)
self._add_parameter(ffn2_weight)
self._add_parameter(ffn2_bias)

self._add_parameter(cache_k_scale)
self._add_parameter(cache_v_scale)
self._add_parameter(cache_k_out_scale)
self._add_parameter(cache_v_out_scale)

self.dropout_rate = config.dropout_rate

@@ -1313,7 +1379,14 @@ def compute_attn(
v_quant_scales = kwargs.get("v_quant_scales", None)
k_dequant_scales = kwargs.get("k_dequant_scales", None)
v_dequant_scales = kwargs.get("v_dequant_scales", None)


if not self.config.use_dynamic_cachekv_quant:
k_quant_scales = self.cache_k_scales
v_quant_scales = self.cache_v_scales
k_dequant_scales = self.cache_k_out_scales
v_dequant_scales = self.cache_v_out_scales



fmha_out = paddle.incubate.nn.functional.block_multihead_attention(
qkv_out,
@@ -1343,7 +1416,7 @@ def compute_attn(
kwargs.get("max_input_length", -1),
kwargs.get("block_size", 64),
self.use_neox_rotary_style,
True,
self.config.use_dynamic_cachekv_quant,
)[0]

out_linear_out = self.compute_out_linear(fmha_out, i)
@@ -1395,6 +1468,12 @@ def compute_attn(
v_quant_scales = kwargs.get("v_quant_scales", None)
k_dequant_scales = kwargs.get("k_dequant_scales", None)
v_dequant_scales = kwargs.get("v_dequant_scales", None)

if not self.config.use_dynamic_cachekv_quant:
k_quant_scales = self.cache_k_scales
v_quant_scales = self.cache_v_scales
k_dequant_scales = self.cache_k_out_scales
v_dequant_scales = self.cache_v_out_scales

# print("self.qkv_out_scales[i]", self.qkv_out_scales[i])
# print("self.qkv_biases[i]", self.qkv_biases[i])
@@ -1426,7 +1505,7 @@ def compute_attn(
kwargs.get("max_input_length", -1),
kwargs.get("block_size", 64),
self.use_neox_rotary_style,
True,
self.config.use_dynamic_cachekv_quant,
quant_round_type=self.quant_round_type,
quant_max_bound=self.quant_max_bound,
quant_min_bound=self.quant_min_bound,
4 changes: 2 additions & 2 deletions paddlenlp/experimental/transformers/generation_utils.py
@@ -440,10 +440,10 @@ def to_static(self, output_path: str, config: dict):
]
else:
precache_kv_spec = None
use_cachekv_int8 = config.get("use_cachekv_int8", False)
use_cachekv_int8 = config.get("use_cachekv_int8", "None")
print("use_cachekv_int8", use_cachekv_int8)

if use_cachekv_int8:
if use_cachekv_int8 == "static" or use_cachekv_int8 == "dynamic":
cachekv_dtype = "uint8"
cache_k_quant_scales = [
paddle.static.InputSpec(
41 changes: 40 additions & 1 deletion paddlenlp/experimental/transformers/llama/modeling.py
@@ -31,7 +31,7 @@
from paddlenlp.utils.log import logger


from paddlenlp.experimental.model_utils import ActScalesLoader, WeightScalesLoader
from paddlenlp.experimental.model_utils import ActScalesLoader, WeightScalesLoader, CacheScaleLoader
from paddlenlp.experimental.transformers.fused_transformer_layers import (
FusedMultiTransformerA8W8,
FusedMultiTransformerBase,
@@ -254,6 +254,18 @@ def __init__(self, config: LlamaConfig):
ffn2_weight_scale_attrs = [
paddle.ParamAttr(name="fusellama.{}.ffn2_weight_scale".format(i)) for i in range(self.num_layers)
]

cache_k_scale_attrs = None
cache_v_scale_attrs = None
cache_k_out_scale_attrs = None
cache_v_out_scale_attrs = None

if config.use_cachekv_int8 == "static":
cache_k_scale_attrs = [paddle.ParamAttr(name="fusellama.{}.cache_k_scale".format(i)) for i in range(self.num_layers)]
cache_v_scale_attrs = [paddle.ParamAttr(name="fusellama.{}.cache_v_scale".format(i)) for i in range(self.num_layers)]
cache_k_out_scale_attrs = [paddle.ParamAttr(name="fusellama.{}.cache_k_out_scale".format(i)) for i in range(self.num_layers)]
cache_v_out_scale_attrs = [paddle.ParamAttr(name="fusellama.{}.cache_v_out_scale".format(i)) for i in range(self.num_layers)]


transformer_config = FusedMultiTransformerConfig(
self.hidden_size,
Expand Down Expand Up @@ -288,9 +300,14 @@ def __init__(self, config: LlamaConfig):
ffn_ln_bias_attrs=ffn_ln_bias_attrs,
ffn1_bias_attrs=ffn1_bias_attrs,
ffn2_bias_attrs=ffn2_bias_attrs,
cache_k_scale_attrs=cache_k_scale_attrs,
cache_v_scale_attrs=cache_v_scale_attrs,
cache_k_out_scale_attrs=cache_k_out_scale_attrs,
cache_v_out_scale_attrs=cache_v_out_scale_attrs,
epsilon=self.epsilon,
norm_type="rmsnorm",
use_neox_rotary_style=True,
use_dynamic_cachekv_quant=config.use_cachekv_int8 == "dynamic",
)


@@ -653,6 +670,7 @@ def set_state_dict(self, state_dict):
scale_map_dict = json.load(json_file)
act_scale_map_dict = scale_map_dict["act_scale"]
weight_scale_map_dict = scale_map_dict["weight_scale"]
cache_scale_map_dict = scale_map_dict["cache_scale"]
# TODO(RichardWooSJTU): support multi-cards

act_scale_json_path = os.path.join(self.quant_model_path, "act_scales.json")
Expand All @@ -669,6 +687,27 @@ def set_state_dict(self, state_dict):
concat_qkv=True,
concat_ffn1=True,
)

if self.config.use_cachekv_int8 == "static":
cache_scale_json_path = os.path.join(self.quant_model_path, "cache_act_scales.json")
cache_scales_loader = CacheScaleLoader(
cache_scale_json_path,
cache_scale_map_dict,
num_of_layers=self.config.num_hidden_layers,
num_heads=self.num_attention_heads // self.config.tensor_parallel_degree,
)
for k, v in cache_scales_loader.scale.items():
for i_layer, weight_scale in enumerate(v):
weight_scale = weight_scale.astype("float32")
if k == "cache_k_scale":
self.decoder.cache_k_scales[i_layer].set_value(weight_scale)
elif k == "cache_v_scale":
self.decoder.cache_v_scales[i_layer].set_value(weight_scale)
elif k == "cache_k_out_scale":
self.decoder.cache_k_out_scales[i_layer].set_value(weight_scale)
else:
self.decoder.cache_v_out_scales[i_layer].set_value(weight_scale)

for k, v in weight_scales_loader.scale.items():
if "qkv_" in k:
for i_layer, weight_scale in enumerate(v):
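Beyond the per-layer cache_act_scales.json consumed by CacheScaleLoader (see the sketch after model_utils.py above), set_state_dict now also expects a cache_scale section in the scale-map JSON. A hedged illustration of that shape follows; every template string is hypothetical and only demonstrates the "#" layer-index placeholder.

# Hypothetical scale-map layout; the existing "act_scale"/"weight_scale" sections are elided.
scale_map_dict = {
    "act_scale": {},      # existing activation-scale templates (unchanged by this commit)
    "weight_scale": {},   # existing weight-scale templates (unchanged by this commit)
    "cache_scale": {
        "cache_k_scale": "llama.layers.#.cachek_matmul.activation_quanter",
        "cache_v_scale": "llama.layers.#.cachev_matmul.activation_quanter",
    },
}
cache_scale_map_dict = scale_map_dict["cache_scale"]  # handed to CacheScaleLoader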
