Commit 1c1a14e: fix cachekv
RichardWooSJTU committed on Dec 12, 2023 (parent: 0a378f6)
Showing 4 changed files with 14 additions and 13 deletions.
4 changes: 2 additions & 2 deletions llm/predictor.py
@@ -709,7 +709,7 @@ def __init__(
         self.architectures = self.model_config.architectures[0].lower()

         self.dtype = config.dtype or self.model_config
-        if config.use_cachekv_int8 == "dynamic":
+        if config.use_cachekv_int8 == "dynamic" or config.use_cachekv_int8 == "static":
             self.cache_kvs = [paddle.zeros(shape, dtype="uint8") for shape in self.cache_kvs_shape]
         else:
             self.cache_kvs = [paddle.zeros(shape, dtype=self.dtype) for shape in self.cache_kvs_shape]
@@ -917,7 +917,7 @@ def __init__(
         self.inputs["src_mask"] = (pre_cache_mask - 1) * 1e4

         self.cache_kvs = {}
-        if config.use_cachekv_int8 == "dynamic":
+        if config.use_cachekv_int8 == "dynamic" or config.use_cachekv_int8 == "static":
             for i in range(len(self.cache_kvs_shape) // 2):
                 self.cache_kvs["key_caches_{}".format(i)] = paddle.zeros(self.cache_kvs_shape[2 * i], dtype="uint8")
                 self.cache_kvs["value_caches_{}".format(i)] = paddle.zeros(
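Both predictor.py hunks apply the same fix: the uint8 KV-cache buffers were previously allocated only when use_cachekv_int8 was "dynamic", so the "static" mode fell through to the full-precision branch and allocated caches in the model dtype. A minimal sketch of the selection logic after this change (the helper below is illustrative, not the predictor's actual API, and the shapes are made up):

    import paddle

    def allocate_cache_kvs(cache_kvs_shape, use_cachekv_int8, dtype="float16"):
        # Any int8 cache mode ("dynamic" or "static") stores the KV cache in
        # uint8 buffers; otherwise the cache keeps the model's compute dtype.
        if use_cachekv_int8 in ("dynamic", "static"):
            return [paddle.zeros(shape, dtype="uint8") for shape in cache_kvs_shape]
        return [paddle.zeros(shape, dtype=dtype) for shape in cache_kvs_shape]

    # Illustrative per-layer shape only: [2, batch, kv_heads, max_len, head_dim].
    caches = allocate_cache_kvs([[2, 1, 8, 1024, 128]] * 2, "static")
    print(caches[0].dtype)  # paddle.uint8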
7 changes: 4 additions & 3 deletions llm/run_dygraph.sh
@@ -18,7 +18,7 @@ export PYTHONPATH=$(dirname $(pwd)):$PYTHONPATH
 
 export FLAGS_call_stack_level=2
 export GLOG_logtostderr=true
-export GLOG_v=0
+export GLOG_v=1
 
 export FLAGS_control_flow_use_new_executor=1
 export FLAGS_new_executor_serial_run=1
@@ -27,7 +27,7 @@ export FLAGS_fraction_of_gpu_memory_to_use=0.92
 
 model_dir=${1:-"checkpoints/llama65b_ptq"}
 src_len=${2:-1024}
-dec_len=${3:-1024}
+dec_len=${3:-100}
 quant_type=${4:-"a8w8"}
 # quant_type=${4:-"None"}
 
@@ -45,5 +45,6 @@ python -m paddle.distributed.launch \
     --batch_size 2 \
     --inference_model \
     --quant_type ${quant_type} \
-    --block_attn
+    --block_attn \
+    --use_cachekv_int8 static
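run_dygraph.sh now exercises this path end to end by passing --use_cachekv_int8 static (along with more verbose logging via GLOG_v=1 and a shorter default decode length). As a rough illustration of what static KV-cache int8 quantization means, and not the fused kernel's actual packing or scale convention, per-head scales calibrated offline map cache values into an 8-bit buffer and back:

    import numpy as np

    def quantize_kv(x, scale):
        # x: [num_heads, seq_len, head_dim]; scale: per-head abs-max from calibration.
        # Map to the signed int8 range, then shift into a uint8 cache buffer
        # (the offset packing here is purely illustrative).
        q = np.clip(np.round(x / scale[:, None, None] * 127.0), -127, 127)
        return (q + 128).astype(np.uint8)

    def dequantize_kv(q, scale):
        # Inverse mapping back to float for the attention computation.
        return (q.astype(np.float32) - 128.0) * scale[:, None, None] / 127.0

    k = np.random.randn(8, 16, 128).astype(np.float32)   # [heads, seq, head_dim]
    scale = np.abs(k).reshape(8, -1).max(axis=1)          # per-head calibration scale
    err = np.abs(k - dequantize_kv(quantize_kv(k, scale), scale)).max()
    print(err)  # small quantization error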

@@ -443,7 +443,7 @@ def __init__(self, config: FusedMultiTransformerConfig):
         cache_k_scale = None
         if cache_k_scale_attr:
             cache_k_scale = self.create_parameter(
-                shape=[config.num_heads ],
+                shape=[self.num_heads],
                 attr=cache_k_scale_attr,
                 dtype='float32',
                 is_bias=False,
@@ -452,7 +452,7 @@ def __init__(self, config: FusedMultiTransformerConfig):
         cache_v_scale = None
         if cache_v_scale_attr:
             cache_v_scale = self.create_parameter(
-                shape=[config.num_heads ],
+                shape=[self.num_heads],
                 attr=cache_v_scale_attr,
                 dtype='float32',
                 is_bias=False,
@@ -461,7 +461,7 @@ def __init__(self, config: FusedMultiTransformerConfig):
         cache_k_out_scale = None
         if cache_k_out_scale_attr:
             cache_k_out_scale = self.create_parameter(
-                shape=[config.num_heads ],
+                shape=[self.num_heads],
                 attr=cache_k_out_scale_attr,
                 dtype='float32',
                 is_bias=False,
@@ -470,7 +470,7 @@ def __init__(self, config: FusedMultiTransformerConfig):
         cache_v_out_scale = None
         if cache_v_out_scale_attr:
             cache_v_out_scale = self.create_parameter(
-                shape=[config.num_heads ],
+                shape=[self.num_heads],
                 attr=cache_v_out_scale_attr,
                 dtype='float32',
                 is_bias=False,
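All four hunks in this file change the cache-scale parameter shape from config.num_heads to self.num_heads. Assuming self.num_heads is the per-device head count after the tensor-parallel split (the usual convention in these fused layers, though not shown in the diff), this keeps the per-head scale parameters sized to the heads a rank actually owns; a sketch of the distinction with made-up numbers:

    # Illustrative only: why the parameter shape should follow the per-device
    # head count rather than the global config value when heads are sharded.
    def per_device_num_heads(global_num_heads: int, tensor_parallel_degree: int) -> int:
        assert global_num_heads % tensor_parallel_degree == 0
        return global_num_heads // tensor_parallel_degree

    # e.g. 64 attention heads sharded over 8 GPUs: each rank holds 8 heads, so its
    # cache_k_scale / cache_v_scale parameters should have shape [8], not [64].
    print(per_device_num_heads(64, 8))  # 8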
8 changes: 4 additions & 4 deletions paddlenlp/experimental/transformers/llama/modeling.py
@@ -707,13 +707,13 @@ def set_state_dict(self, state_dict):
             for i_layer, weight_scale in enumerate(v):
                 weight_scale = weight_scale.astype("float32")
                 if k == "cache_k_scale":
-                    self.decoder.cache_k_scales[i_layer].set_value(weight_scale)
+                    self.transformer_block.cache_k_scales[i_layer].set_value(weight_scale)
                 elif k == "cache_v_scale":
-                    self.decoder.cache_v_scales[i_layer].set_value(weight_scale)
+                    self.transformer_block.cache_v_scales[i_layer].set_value(weight_scale)
                 elif k == "cache_k_out_scale":
-                    self.decoder.cache_k_out_scales[i_layer].set_value(weight_scale)
+                    self.transformer_block.cache_k_out_scales[i_layer].set_value(weight_scale)
                 else:
-                    self.decoder.cache_v_out_scales[i_layer].set_value(weight_scale)
+                    self.transformer_block.cache_v_out_scales[i_layer].set_value(weight_scale)
 
         for k, v in weight_scales_loader.scale.items():
             if "qkv_" in k:
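In the experimental LLaMA inference model the fused layers live under self.transformer_block, so the old self.decoder references presumably failed when loading static cachekv scales. A tiny, hypothetical miniature of the loading pattern fixed above, with simplified stand-ins for the real parameters:

    class _Block:
        # Stand-in for transformer_block: one scale slot per layer.
        def __init__(self, num_layers):
            self.cache_k_scales = [None] * num_layers
            self.cache_v_scales = [None] * num_layers

    def load_cache_scales(block, scales):
        # Route per-layer calibration scales onto the module that owns them.
        for name, per_layer in scales.items():
            target = block.cache_k_scales if name == "cache_k_scale" else block.cache_v_scales
            for i, s in enumerate(per_layer):
                target[i] = float(s)

    block = _Block(num_layers=2)
    load_cache_scales(block, {"cache_k_scale": [0.1, 0.2], "cache_v_scale": [0.3, 0.4]})
    print(block.cache_k_scales)  # [0.1, 0.2]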
