Add snip_momentum structured pruning example with 80% sparsity ratio (#…
ftian1 authored May 8, 2023
1 parent 40e33a4 commit 2ec4be7
Showing 2 changed files with 264 additions and 0 deletions.
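Note: S_DENSE_RATIO=0.2 in the script below is the fraction of weights kept dense, so the resulting sparsity is 1 - 0.2 = 0.8, the 80% sparsity ratio named in the commit title.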
103 changes: 103 additions & 0 deletions compression/bert/bash_script/pruning_sparse_snip_momentum.sh
@@ -0,0 +1,103 @@
#!/bin/bash
DIR=`pwd`
export CUDA_VISIBLE_DEVICES=0
TASK_NAME=mnli # options: mnli sst2 stsb qqp rte cola mrpc qnli
STAGE=one_stage
LRATE=5e-5
EPOCH=10
WARMUP_EPOCH=1
BATCH_SIZE_PER_GPU=32
NAME="pruning_sparse"
SAVE_PATH=./out/${NAME}/
mkdir -p ${SAVE_PATH}

###Layer Reduction
LAYER_REDUCTION_ENABLE="false"
FP16_ENABLE="false"

###weight quantization
WEIGHT_QUANT_ENABLE="false"
Q_GROUP=64
W_BIT1=4
W_BIT2=2
###activation quantization
ACTIVATION_QUANT_ENABLE="false"
A_BIT1=8
A_BIT2=4
#############pruning
###sparse_pruning (structured pruning)
SPARSE_PRUNING_ENABLE="true" #<=============================================================
SPARSE_PRUNING_BLOCK_PATTERN="\"4x1\""
SPARSE_PRUNING_OFFSET_STRIDE=1000
SPARSE_PRUNING_OFFSET=1000
SPARSE_PRUNING_OFFSET_END=51000
SPARSE_PRUNING_EXCLUDED_MODULES="[\"classifier\", \"pooler\"]"
S_DENSE_RATIO=0.2 #<=============================================================
###row_pruning (structured pruning)
ROW_PRUNING_ENABLE="false"
R_DENSE_RATIO=0.6
###head_pruning
HEAD_PRUNING_ENABLE="false"
H_DENSE_RATIO=0.6

template_json="config/ds_config_structural_pruning_TEMPLATE.json"
config_json="config/ds_config_structural_${NAME}.json"


if [ "${FP16_ENABLE}" = "true" ]; then
QuantW_FORWARD="false"
else
QuantW_FORWARD="true"
fi
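# NOTE: the plain SPARSE_PRUNING_OFFSET substitution must run after the
# _STRIDE and _END substitutions, since the shorter placeholder is a prefix
# of the longer ones and would otherwise clobber them.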
sed "s/LAYER_REDUCTION_ENABLE/${LAYER_REDUCTION_ENABLE}/" ${template_json} \
| sed "s/WEIGHT_QUANT_ENABLE/${WEIGHT_QUANT_ENABLE}/" \
| sed "s/Q_GROUP/${Q_GROUP}/" \
| sed "s/W_BIT1/${W_BIT1}/" \
| sed "s/W_BIT2/${W_BIT2}/" \
| sed "s/ACTIVATION_QUANT_ENABLE/${ACTIVATION_QUANT_ENABLE}/" \
| sed "s/A_BIT1/${A_BIT1}/" \
| sed "s/A_BIT2/${A_BIT2}/" \
| sed "s/SPARSE_PRUNING_ENABLE/${SPARSE_PRUNING_ENABLE}/" \
| sed "s/SPARSE_PRUNING_BLOCK_PATTERN/${SPARSE_PRUNING_BLOCK_PATTERN}/" \
| sed "s/SPARSE_PRUNING_OFFSET_STRIDE/${SPARSE_PRUNING_OFFSET_STRIDE}/" \
| sed "s/SPARSE_PRUNING_OFFSET_END/${SPARSE_PRUNING_OFFSET_END}/" \
| sed "s/SPARSE_PRUNING_OFFSET/${SPARSE_PRUNING_OFFSET}/" \
| sed "s/SPARSE_PRUNING_EXCLUDED_MODULES/${SPARSE_PRUNING_EXCLUDED_MODULES}/" \
| sed "s/S_DENSE_RATIO/${S_DENSE_RATIO}/" \
| sed "s/ROW_PRUNING_ENABLE/${ROW_PRUNING_ENABLE}/" \
| sed "s/R_DENSE_RATIO/${R_DENSE_RATIO}/" \
| sed "s/HEAD_PRUNING_ENABLE/${HEAD_PRUNING_ENABLE}/" \
| sed "s/H_DENSE_RATIO/${H_DENSE_RATIO}/" \
| sed "s/FP16_ENABLE/${FP16_ENABLE}/" \
| sed "s/QuantW_FORWARD/${QuantW_FORWARD}/" \
| sed "s/BATCH_SIZE_PER_GPU/${BATCH_SIZE_PER_GPU}/" \
> ${config_json}

CONFIG=${config_json}
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% if users provide no model of their own, use the following script %%%%%%%%%%%%%%%%%%%%
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% it first downloads the Hugging Face model and then compresses it %%%%%%%%%%%%%%%%%%%%
MODEL=yoshitomo-matsubara/bert-base-uncased-${TASK_NAME} ## for both student and teacher
run_cmd="python -m torch.distributed.launch --nproc_per_node=1 \
--master_port 6618 \
run_glue_no_trainer.py \
--seed 42 \
--distill_method ${STAGE} \
--model_name_or_path ${MODEL} \
--task_name $TASK_NAME \
--max_length 128 \
--pad_to_max_length \
--per_device_train_batch_size ${BATCH_SIZE_PER_GPU} \
--per_device_eval_batch_size 64 \
--learning_rate $LRATE \
--num_train_epochs ${EPOCH} \
--num_warmup_epochs ${WARMUP_EPOCH} \
--eval_step 1000 \
--deepspeed_config ${CONFIG} \
--deepspeed \
--save_best_model --clean_best_model \
--gradient_accumulation_steps 1 \
--output_dir ${SAVE_PATH} | tee -a ${SAVE_PATH}/train.log"

echo ${run_cmd}
eval ${run_cmd}
set +x
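For readers new to snip_momentum: conceptually, the method scores each 4x1 block of a weight matrix by the aggregated sensitivity of its elements and prunes the lowest-scoring blocks until only a dense_ratio fraction survives. Below is a minimal PyTorch sketch of that block-masking idea, assuming a saliency of |weight * momentum| summed per block; the function and variable names are illustrative, not DeepSpeed's internal API.

import torch

def block_scores(weight, momentum, block=(4, 1)):
    # SNIP-with-momentum style saliency per element, then summed per block.
    saliency = (weight * momentum).abs()
    rows, cols = weight.shape
    br, bc = block
    return saliency.reshape(rows // br, br, cols // bc, bc).sum(dim=(1, 3))

def block_mask(weight, momentum, dense_ratio=0.2, block=(4, 1)):
    # Keep the top dense_ratio fraction of blocks; zero out the rest.
    scores = block_scores(weight, momentum, block)
    n = scores.numel()
    k = max(1, int(n * dense_ratio))  # number of blocks kept dense
    threshold = scores.flatten().kthvalue(n - k + 1).values
    keep = (scores >= threshold).float()
    br, bc = block
    # Expand the per-block mask back to element granularity.
    return keep.repeat_interleave(br, dim=0).repeat_interleave(bc, dim=1)

# With dense_ratio=0.2, roughly 80% of all 4x1 blocks are zeroed out.

Assuming the usual DeepSpeedExamples layout, the script itself runs from compression/bert/ as: bash bash_script/pruning_sparse_snip_momentum.sh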
161 changes: 161 additions & 0 deletions compression/bert/config/ds_config_structural_pruning_TEMPLATE.json
@@ -0,0 +1,161 @@
{
"train_batch_size": 32,
"train_micro_batch_size_per_gpu": BATCH_SIZE_PER_GPU,
"steps_per_print": 200,
"zero_optimization": {
"stage": 0
},
"fp16": {
"enabled": FP16_ENABLE
},
"gradient_clipping": 1.0,
"prescale_gradients": true,
"wall_clock_breakdown": false,
"compression_training": {
"layer_reduction": {
"enabled": LAYER_REDUCTION_ENABLE,
"keep_number_layer": 5,
"module_name_prefix": "bert.encoder.layer",
"teacher_layer": [
2,
4,
6,
8,
10
],
"other_module_name": [
"bert.pooler",
"bert.embeddings",
"classifier"
]
},
"weight_quantization": {
"shared_parameters": {
"enabled": WEIGHT_QUANT_ENABLE,
"quantizer_kernel": false,
"schedule_offset": 0,
"quantize_groups": Q_GROUP,
"quantize_verbose": false,
"quantization_type": "symmetric",
"quantize_weight_in_forward": QuantW_FORWARD,
"rounding": "nearest",
"fp16_mixed_quantize": {
"enabled": false,
"quantize_change_ratio": 0.1
}
},
"different_groups": {
"wq1": {
"params": {
"start_bits": W_BIT1,
"target_bits": W_BIT1,
"quantization_period": 0
},
"modules": [
"attention.self",
"word_embeddings"
]
},
"wq2": {
"params": {
"start_bits": W_BIT2,
"target_bits": W_BIT2,
"quantization_period": 0
},
"modules": [
"output.dense",
"intermediate"
]
}
}
},
"activation_quantization": {
"shared_parameters": {
"enabled": ACTIVATION_QUANT_ENABLE,
"quantization_type": "symmetric",
"range_calibration": "dynamic",
"schedule_offset": 0
},
"different_groups": {
"aq1": {
"params": {
"bits": A_BIT1
},
"modules": [
"attention.self"
]
},
"aq2": {
"params": {
"bits": A_BIT2
},
"modules": [
"output.dense",
"intermediate"
]
}
}
},
"sparse_pruning": {
"shared_parameters": {
"enabled": SPARSE_PRUNING_ENABLE,
"schedule_offset": SPARSE_PRUNING_OFFSET,
"schedule_offset_end": SPARSE_PRUNING_OFFSET_END,
"schedule_offset_stride": SPARSE_PRUNING_OFFSET_STRIDE,
"method": "snip_momentum",
"block_pattern": SPARSE_PRUNING_BLOCK_PATTERN,
"dense_ratio": S_DENSE_RATIO,
"excluded_modules": SPARSE_PRUNING_EXCLUDED_MODULES
},
"different_groups": {
}
},
"row_pruning": {
"shared_parameters": {
"enabled": ROW_PRUNING_ENABLE,
"schedule_offset": 2000,
"method": "topk"
},
"different_groups": {
"rp1": {
"params": {
"dense_ratio": R_DENSE_RATIO
},
"modules": [
"intermediate.dense"
],
"related_modules": [
[
"layer.\\w+.output.dense"
]
]
}
}
},
"head_pruning": {
"shared_parameters": {
"enabled": HEAD_PRUNING_ENABLE,
"schedule_offset": 2000,
"method": "topk",
"num_heads": 12
},
"different_groups": {
"rp1": {
"params": {
"dense_ratio": H_DENSE_RATIO
},
"modules": [
"attention.output.dense"
],
"related_modules": [
[
"self.query",
"self.key",
"self.value"
]
]
}
}
}
}
}
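For orientation, the sparse_pruning section above starts pruning at step 1000 (schedule_offset), refreshes the masks every 1000 steps (schedule_offset_stride), and reaches the final 80% sparsity by step 51000 (schedule_offset_end). A sketch of what such a gradual ramp could look like, assuming the common AGP-style cubic schedule; the exact curve is an implementation detail of DeepSpeed's compression library, so treat this as illustrative.

def sparsity_at_step(step, start=1000, end=51000, stride=1000, final=0.8):
    """Illustrative cubic (AGP-style) ramp; DeepSpeed's actual curve may differ."""
    if step < start:
        return 0.0
    if step >= end:
        return final
    # Masks are only refreshed every `stride` steps.
    t = ((step - start) // stride) * stride
    frac = t / (end - start)
    return final * (1.0 - (1.0 - frac) ** 3)

# e.g. sparsity_at_step(26000) ~= 0.70, approaching 0.8 near step 51000.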
