Commit

SpecDec updated
hemingkx committed Nov 8, 2023
1 parent a892d35 commit bda1141
Showing 25 changed files with 75 additions and 80 deletions.
File renamed without changes
Binary file removed block_plugins/__pycache__/__init__.cpython-37.pyc
Binary file removed block_plugins/models/__pycache__/GAD.cpython-37.pyc
Binary file removed block_plugins/tasks/__pycache__/__init__.cpython-37.pyc
Binary file added data/.DS_Store
8 changes: 4 additions & 4 deletions encoder_initial.py
@@ -2,7 +2,7 @@


 def model_preparing(model_path, save_path):
-    """only save the AT encoder params to initialize the NAT drafter's encoder"""
+    """only save the AR encoder params to initialize the NAR drafter's encoder"""
     key_l = []
     raw_model = torch.load(model_path)
     for key in raw_model['model']:
@@ -17,7 +17,7 @@ def model_preparing(model_path, save_path):


 def param_checking(model1, model2):
-    """check the parameters of the AT verifier and the NAT drafter"""
+    """check the parameters of the AR verifier and the NAR drafter"""
     key_l1 = []
     key_l2 = []
     raw_model1 = torch.load(model1)
@@ -37,6 +37,6 @@ def param_checking(model1, model2):


 if __name__ == "__main__":
-    AR_path = './checkpoints/wmt14-en-de-base-at-verifier.pt' # the dir that contains AT verifier checkpoint
-    save_path = './checkpoints/initial_checkpoint.pt' # the save dir of your fairseq NAT drafter checkpoints
+    AR_path = './checkpoints/wmt14-en-de-base-at-verifier.pt' # the dir that contains AR verifier checkpoint
+    save_path = './checkpoints/initial_checkpoint.pt' # the save dir of your fairseq NAR drafter checkpoints
     model_preparing(AR_path, save_path)
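As context for the AT/NAT to AR/NAR rewording above: model_preparing strips an autoregressive checkpoint down to its encoder weights so they can seed the NAR drafter's encoder. Below is a minimal sketch of that idea, assuming fairseq's usual checkpoint layout (a 'model' state dict whose keys carry 'encoder.' and 'decoder.' prefixes); the helper name keep_encoder_only is ours, not the repo's.

import torch

def keep_encoder_only(model_path, save_path):
    # Load a fairseq checkpoint and drop every non-encoder parameter so
    # only the AR encoder weights remain to initialize the NAR drafter.
    ckpt = torch.load(model_path, map_location='cpu')
    for key in list(ckpt['model'].keys()):
        if not key.startswith('encoder.'):
            del ckpt['model'][key]
    torch.save(ckpt, save_path)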
32 changes: 16 additions & 16 deletions inference.py
@@ -142,13 +142,13 @@ def forward_decoder(model, input_tokens, encoder_out, incremental_state=None,


 @torch.no_grad()
-def gad_generate(data_lines, model, AR_model, task, block_size, batch_size, device, beta=1, tau=0, max_len=200):
-    # Generalized Aggressive Decoding
+def specdec_generate(data_lines, model, AR_model, task, block_size, batch_size, device, beta=1, tau=0, max_len=200):
+    # Speculative Decoding
     src_dict = task.source_dictionary
     tgt_dict = task.target_dictionary
     data_size = len(data_lines)
     all_results = []
-    logger.info(f'GAD generate')
+    logger.info(f'SpecDec generate')
     start = time.perf_counter()
     for start_idx in tqdm(range(0, data_size, batch_size)):
         batch_size = min(data_size - start_idx, batch_size)
@@ -167,9 +167,9 @@ def gad_generate(data_lines, model, AR_model, task, block_size, batch_size, devi
         start_pos_list = [0] * batch_size
         finish_list = []
         for step in range(0, max_len):
-            prev_output_tokens, start_pos_list = gad_forward(start_pos_list, block_size, batch_size,
-                                                             tgt_dict, prev_output_tokens,
-                                                             encoder_out, AR_encoder_out, model, AR_model, beta, tau)
+            prev_output_tokens, start_pos_list = specdec_forward(start_pos_list, block_size, batch_size,
+                                                                 tgt_dict, prev_output_tokens, encoder_out,
+                                                                 AR_encoder_out, model, AR_model, beta, tau)
             for i, start_pos in enumerate(start_pos_list):
                 if i not in finish_list:
                     if start_pos == -1:
@@ -187,8 +187,8 @@ def gad_generate(data_lines, model, AR_model, task, block_size, batch_size, devi


 @torch.no_grad()
-def gad_forward(start_pos_list, block_size, batch_size, tgt_dict, prev_output_tokens,
-                encoder_out, AR_encoder_out, model, AR_model, beta, tau, max_len=200):
+def specdec_forward(start_pos_list, block_size, batch_size, tgt_dict, prev_output_tokens,
+                    encoder_out, AR_encoder_out, model, AR_model, beta, tau, max_len=200):
     pad_tokens = [[tgt_dict.pad()] * (max_len + block_size) for _ in range(batch_size)]
     for i in range(batch_size):
         pad_tokens[i][:len(prev_output_tokens[i])] = prev_output_tokens[i]
@@ -258,9 +258,9 @@ def gad_forward(start_pos_list, block_size, batch_size, tgt_dict, prev_output_to
     parser.add_argument('--output-path', type=str, default=None,
                         help='path to output file')
     parser.add_argument('--AR-path', type=str, default=None,
-                        help='path to AR model')
+                        help='path to autoregressive model (to be accelerated)')
     parser.add_argument('--strategy', type=str, default='fairseq',
-                        help='decoding strategy, choose from: fairseq, AR, gad')
+                        help='decoding strategy, choose from: fairseq, AR, specdec')
     parser.add_argument('--batch', type=int, default=None,
                         help='batch size')
     parser.add_argument('--block-size', type=int, default=5,
@@ -283,7 +283,7 @@ def gad_forward(start_pos_list, block_size, batch_size, tgt_dict, prev_output_to
     device = torch.device('cuda')

     # NAR drafter
-    if cmd_args.strategy == 'gad':
+    if cmd_args.strategy == 'specdec':
         logger.info("loading model(s) from {}".format(cfg.common_eval.path))
         models, _model_args, _model_task = load_model_ensemble_and_task(filenames=[cfg.common_eval.path], task=task)
         model = models[0].to(device).eval()
@@ -308,11 +308,11 @@ def gad_forward(start_pos_list, block_size, batch_size, tgt_dict, prev_output_to
         logger.info("Decoding Strategy: Simplified AR")
         remove_bpe_results, delta = baseline_generate(bpe_sents, AR_model, _AR_model_task, cmd_args.batch, device)
         logger.info(f'Simplified AR generate: {delta}')
-    elif cmd_args.strategy == 'gad':
-        logger.info("Decoding Strategy: GAD")
-        remove_bpe_results, delta = gad_generate(bpe_sents, model, AR_model, task, cmd_args.block_size, cmd_args.batch,
-                                                 device, beta=cmd_args.beta, tau=cmd_args.tau)
-        logger.info(f'GAD generate: {delta}')
+    elif cmd_args.strategy == 'specdec':
+        logger.info("Decoding Strategy: SpecDec")
+        remove_bpe_results, delta = specdec_generate(bpe_sents, model, AR_model, task, cmd_args.block_size, cmd_args.batch,
+                                                     device, beta=cmd_args.beta, tau=cmd_args.tau)
+        logger.info(f'SpecDec generate: {delta}')
     else:
         logger.info("Decoding Strategy: fairseq")
         remove_bpe_results, delta = fairseq_generate(bpe_sents, cfg, AR_models, _AR_model_task, cmd_args.batch, device)
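The renamed pair specdec_generate and specdec_forward implement the draft-then-verify loop of speculative decoding: the NAR drafter proposes a block of tokens in one shot, the AR verifier keeps the longest acceptable prefix and supplies one token of its own at the first mismatch, and the loop repeats until EOS or max_len. A compact per-sentence sketch follows; draft_block and verify_block are hypothetical stand-ins for the drafter and verifier calls, not functions from this repo.

EOS = 2  # fairseq's default </s> index

def specdec_decode(src_tokens, drafter, verifier, block_size=5, max_len=200):
    tokens = []  # accepted target prefix so far
    while len(tokens) < max_len:
        # NAR drafter proposes block_size tokens in parallel.
        draft = draft_block(drafter, src_tokens, tokens, block_size)
        # AR verifier keeps an accepted prefix of the draft and appends its
        # own token at the first rejection, so every iteration lengthens
        # the output by at least one token.
        tokens = verify_block(verifier, src_tokens, tokens, draft)
        if tokens and tokens[-1] == EOS:
            break
    return tokens

Because several draft tokens are usually accepted per verifier call, the output grows by more than one token per iteration, which is where the speedup over token-by-token AR decoding comes from.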
16 changes: 8 additions & 8 deletions inference.sh
@@ -1,11 +1,11 @@
-data_dir=./data # the dir that contains dict files
-checkpoint_path=./checkpoints/wmt14-en-de-base-nat-drafter-checkpoint.avg10.pt # the dir that contains NAT drafter checkpoint
-AR_checkpoint_path=./checkpoints/wmt14-en-de-base-at-verifier.pt # the dir that contains AT verifier checkpoint
-input_path=./test.en # the dir that contains bpe test files
+data_dir=./data/wmt14.en-de # the dir that contains dict files
+checkpoint_path=/home/xiaheming/data/SpecDec/wmt14-en-de-base-nat-drafter-checkpoint.avg10.pt # the dir that contains NAT drafter checkpoint
+AR_checkpoint_path=/home/xiaheming/data/SpecDec/wmt14-en-de-base-at-verifier.pt # the dir that contains AT verifier checkpoint
+input_path=./data/wmt14.en-de/test.en # the dir that contains bpe test files
 output_path=./output/block.out # the dir for outputs

-strategy='gad' # fairseq, AR, gad
-batch=32
+strategy='specdec' # fairseq, AR, specdec
+batch=1
 beam=5

 beta=5
@@ -16,8 +16,8 @@ src=en
 tgt=de


-python inference.py ${data_dir} --path ${checkpoint_path} \
-    --user-dir block_plugins --task translation_lev_modified --remove-bpe --max-sentences 20 --source-lang ${src} \
+python inference_paper.py ${data_dir} --path ${checkpoint_path} \
+    --user-dir specdec_plugins --task translation_lev_modified --remove-bpe --max-sentences 20 --source-lang ${src} \
     --target-lang ${tgt} --iter-decode-max-iter 0 --iter-decode-eos-penalty 0 --iter-decode-with-beam 1 \
     --gen-subset test --AR-path ${AR_checkpoint_path} --input-path ${input_path} --output-path ${output_path} \
     --block-size ${block_size} --beta ${beta} --tau ${tau} --batch ${batch} --beam ${beam} --strategy ${strategy}
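Two notes for anyone adapting this script: --iter-decode-max-iter 0 makes the drafter produce each block in a single non-iterative pass rather than refining it, and --user-dir specdec_plugins must point at the renamed plugins directory (formerly block_plugins, whose stale .pyc files this commit removes).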
14 changes: 7 additions & 7 deletions inference_drafter.py
@@ -25,12 +25,12 @@ def write_result(results, output_file):


 @torch.no_grad()
-def gad_generate(data_lines, model, task, block_size, device, max_len=200):
+def drafter_generate(data_lines, model, task, block_size, device, max_len=200):
     src_dict = task.source_dictionary
     tgt_dict = task.target_dictionary
     data_size = len(data_lines)
     all_results = []
-    logger.info(f'GAD generate')
+    logger.info(f'Spec-Drafter generate')
     pass_tokens = [0] * max_len
     sent_nums = [0] * max_len
     start = time.perf_counter()
@@ -44,8 +44,8 @@ def gad_generate(data_lines, model, task, block_size, device, max_len=200):
         prev_output_tokens = [tgt_dict.unk()] * block_size
         start_pos = 0
         for step in range(0, max_len):
-            start_pos, prev_output_tokens, pass_token = gad_forward(start_pos, block_size, tgt_dict,
-                                                                    prev_output_tokens, encoder_out, model)
+            start_pos, prev_output_tokens, pass_token = drafter_forward(start_pos, block_size, tgt_dict,
+                                                                        prev_output_tokens, encoder_out, model)
             pass_tokens[step] += pass_token
             sent_nums[step] += 1
             if start_pos == -1:
@@ -74,7 +74,7 @@ def gad_generate(data_lines, model, task, block_size, device, max_len=200):


 @torch.no_grad()
-def gad_forward(start_pos, block_size, tgt_dict, prev_output_tokens, encoder_out, model, max_len=200):
+def drafter_forward(start_pos, block_size, tgt_dict, prev_output_tokens, encoder_out, model, max_len=200):
     output_tokens = torch.tensor([prev_output_tokens]).to(device)
     block_mask = torch.zeros_like(output_tokens).to(output_tokens)
     block_mask[0][start_pos:start_pos + block_size] = 1
@@ -136,8 +136,8 @@ def gad_forward(start_pos, block_size, tgt_dict, prev_output_tokens, encoder_out
     with open(cmd_args.input_path, 'r') as f:
         bpe_sents = [l.strip() for l in f.readlines()]

-    logger.info("Decoding Strategy: GAD")
-    remove_bpe_results, delta = gad_generate(bpe_sents, model, task, cmd_args.block_size, device)
+    logger.info("Decoding Strategy: Spec-Drafter")
+    remove_bpe_results, delta = drafter_generate(bpe_sents, model, task, cmd_args.block_size, device)
     logger.info(f'GAD generate: {delta}')

     if cmd_args.output_path is not None:
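The pass_tokens and sent_nums counters in drafter_generate record, per draft-verify iteration, how many tokens were accepted and how many sentences were still decoding. Dividing their sums gives the mean number of tokens produced per iteration, the main acceptance statistic behind the reported speedup. A small sketch of that reduction, under the assumption that the script aggregates it this way:

max_len = 200
pass_tokens = [0] * max_len  # tokens accepted at each step, summed over sentences
sent_nums = [0] * max_len    # sentences still active at each step
# ... inside the loop: pass_tokens[step] += pass_token; sent_nums[step] += 1

mean_accepted = sum(pass_tokens) / max(sum(sent_nums), 1)
print(f'mean accepted tokens per iteration: {mean_accepted:.2f}')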
34 changes: 17 additions & 17 deletions inference_paper.py
@@ -146,16 +146,16 @@ def forward_decoder(model,


 @torch.no_grad()
-def gad_generate(data_lines, model, AR_model, task, block_size, device, beta=1, tau=0, max_len=200):
-    # Generalized Aggressive Decoding
+def specdec_generate(data_lines, model, AR_model, task, block_size, device, beta=1, tau=0, max_len=200):
+    # Speculative Decoding
     src_dict = task.source_dictionary
     tgt_dict = task.target_dictionary
     encoder_state_ids = []
     for i in range(len(AR_model.decoder.layers)):
         encoder_state_ids.append(AR_model.decoder.layers[i].encoder_attn._incremental_state_id)
     data_size = len(data_lines)
     all_results = []
-    logger.info(f'GAD generate')
+    logger.info(f'SpecDec generate')
     pass_tokens = [0] * max_len
     sent_nums = [0] * max_len
     start = time.perf_counter()
@@ -171,11 +171,11 @@ def gad_generate(data_lines, model, AR_model, task, block_size, device, beta=1,
         prev_output_tokens = [tgt_dict.unk()] * block_size
         start_pos = 0
         for step in range(0, max_len):
-            start_pos, prev_output_tokens, pass_token = gad_forward(incremental_state, encoder_state_ids,
-                                                                    start_pos, block_size, tgt_dict,
-                                                                    prev_output_tokens,
-                                                                    encoder_out, AR_encoder_out, model,
-                                                                    AR_model, beta, tau)
+            start_pos, prev_output_tokens, pass_token = specdec_forward(incremental_state, encoder_state_ids,
+                                                                        start_pos, block_size, tgt_dict,
+                                                                        prev_output_tokens,
+                                                                        encoder_out, AR_encoder_out, model,
+                                                                        AR_model, beta, tau)
             pass_tokens[step] += pass_token
             sent_nums[step] += 1
             if start_pos == -1:
@@ -204,8 +204,8 @@ def gad_generate(data_lines, model, AR_model, task, block_size, device, beta=1,


 @torch.no_grad()
-def gad_forward(incremental_state, encoder_state_ids, start_pos, block_size, tgt_dict, prev_output_tokens,
-                encoder_out, AR_encoder_out, model, AR_model, beta, tau, max_len=200):
+def specdec_forward(incremental_state, encoder_state_ids, start_pos, block_size, tgt_dict, prev_output_tokens,
+                    encoder_out, AR_encoder_out, model, AR_model, beta, tau, max_len=200):
     output_tokens = torch.tensor([prev_output_tokens]).to(device)
     block_mask = torch.zeros_like(output_tokens).to(output_tokens)
     block_mask[0][start_pos:start_pos + block_size] = 1
@@ -263,7 +263,7 @@ def gad_forward(incremental_state, encoder_state_ids, start_pos, block_size, tgt
     parser.add_argument('--AR-path', type=str, default=None,
                         help='path to AR model')
     parser.add_argument('--strategy', type=str, default='fairseq',
-                        help='decoding strategy, choose from: fairseq, AR, gad')
+                        help='decoding strategy, choose from: fairseq, AR, specdec')
     parser.add_argument('--batch', type=int, default=None,
                         help='batch size')
     parser.add_argument('--block-size', type=int, default=5,
@@ -286,7 +286,7 @@ def gad_forward(incremental_state, encoder_state_ids, start_pos, block_size, tgt
     device = torch.device('cuda')

     # NAR drafter
-    if cmd_args.strategy == 'gad':
+    if cmd_args.strategy == 'specdec':
         logger.info("loading model(s) from {}".format(cfg.common_eval.path))
         models, _model_args, _model_task = load_model_ensemble_and_task(filenames=[cfg.common_eval.path], task=task)
         model = models[0].to(device).eval()
@@ -304,11 +304,11 @@ def gad_forward(incremental_state, encoder_state_ids, start_pos, block_size, tgt
         logger.info("Decoding Strategy: Simplified AR")
         remove_bpe_results, delta = baseline_generate(bpe_sents, AR_model, _AR_model_task, device)
         logger.info(f'Simplified AR generate: {delta}')
-    elif cmd_args.strategy == 'gad':
-        logger.info("Decoding Strategy: GAD")
-        remove_bpe_results, delta = gad_generate(bpe_sents, model, AR_model, task, cmd_args.block_size, device,
-                                                 beta=cmd_args.beta, tau=cmd_args.tau)
-        logger.info(f'GAD generate: {delta}')
+    elif cmd_args.strategy == 'specdec':
+        logger.info("Decoding Strategy: SpecDec")
+        remove_bpe_results, delta = specdec_generate(bpe_sents, model, AR_model, task, cmd_args.block_size, device,
+                                                     beta=cmd_args.beta, tau=cmd_args.tau)
+        logger.info(f'SpecDec generate: {delta}')
     else:
         logger.info("Decoding Strategy: fairseq")
         remove_bpe_results, delta = fairseq_generate(bpe_sents, cfg, AR_models, _AR_model_task, cmd_args.batch, device)
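Both inference entry points thread beta and tau into the verifier call. In the SpecDec formulation these parameters relax strict top-1 verification: a drafted token is accepted if it ranks within the verifier's top-beta candidates and its log-probability lies within tau of the verifier's best candidate. The helper below is a hypothetical illustration of that criterion, not code from this repo, and the default values are examples only.

import torch

def accept_draft_token(ar_log_probs, draft_token, beta=5, tau=1.0):
    # ar_log_probs: 1-D tensor of the AR verifier's log-probs over the
    # vocabulary at the drafted position; draft_token: id proposed by
    # the NAR drafter.
    top_scores, top_ids = ar_log_probs.topk(beta)
    in_top_beta = bool((top_ids == draft_token).any())
    within_tau = bool(ar_log_probs[draft_token] >= top_scores[0] - tau)
    return in_top_beta and within_tau

With beta=1 and tau=0 this degenerates to exact top-1 matching, which is why those are the conservative defaults in specdec_generate's signature.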