Commit

batch implementation and model checkpoints updated
hemingkx committed May 10, 2022
1 parent 04a5885 commit 6c89c93
Showing 11 changed files with 70,306 additions and 214 deletions.
File renamed without changes.
File renamed without changes.
File renamed without changes.
34,979 changes: 34,979 additions & 0 deletions data/wmt16.en-ro/dict.en.txt

Large diffs are not rendered by default.

34,979 changes: 34,979 additions & 0 deletions data/wmt16.en-ro/dict.ro.txt

Large diffs are not rendered by default.

27 changes: 27 additions & 0 deletions data/wmt16.en-ro/get_data.sh
@@ -0,0 +1,27 @@
"""
#!/usr/bin/env python3 -u
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
"""

mkdir data

# WMT16 EN-RO
cd data
mkdir wmt16.en-ro
cd wmt16.en-ro
gdown https://drive.google.com/uc?id=1YrAwCEuktG-iDVxtEW-FE72uFTLc5QMl
tar -zxvf wmt16.tar.gz
mv wmt16/en-ro/train/corpus.bpe.en train.en
mv wmt16/en-ro/train/corpus.bpe.ro train.ro
mv wmt16/en-ro/dev/dev.bpe.en valid.en
mv wmt16/en-ro/dev/dev.bpe.ro valid.ro
mv wmt16/en-ro/test/test.bpe.en test.en
mv wmt16/en-ro/test/test.bpe.ro test.ro
rm wmt16.tar.gz
rm -r wmt16
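
After the script finishes, a quick sanity check can confirm the layout. A minimal sketch, assuming fairseq is installed, the script was run from the repository root, and the dict files are in fairseq's plain "token count" dictionary format:

# Sanity-check sketch for the prepared WMT16 EN-RO data (assumptions above).
from fairseq.data import Dictionary

src_dict = Dictionary.load('data/wmt16.en-ro/dict.en.txt')
with open('data/wmt16.en-ro/test.en') as f:
    test_lines = [line.strip() for line in f]
print(f'dict size: {len(src_dict)}, test sentences: {len(test_lines)}')

# encode_line mirrors how the inference scripts in this commit binarize inputs
ids = src_dict.encode_line(test_lines[0], add_if_not_exist=False).long()
print(ids.tolist())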
282 changes: 165 additions & 117 deletions inference.py

Large diffs are not rendered by default.

20 changes: 11 additions & 9 deletions inference.sh
@@ -4,18 +4,20 @@ AR_checkpoint_path=./checkpoints/wmt14-en-de-base-at-verifier.pt # the dir that
input_path=./test.en # the dir that contains bpe test files
output_path=./output/block.out # the dir for outputs

BATCH=256
BEAM=1
strategy='gad' # fairseq, AR, gad
batch=32
beam=5

beta=5
tau=3.0
block_size=25
strategy='block' # 'fairseq', 'AR', 'block'

src=en
tgt=de


python inference.py $data_dir --path $checkpoint_path --user-dir block_plugins --task translation_lev_modified \
--remove-bpe --max-sentences 20 --source-lang ${src} --target-lang ${tgt} --iter-decode-max-iter 0 \
--iter-decode-eos-penalty 0 --iter-decode-with-beam 1 --gen-subset test --strategy ${strategy} \
--AR-path $AR_checkpoint_path --beam $BEAM --input-path $input_path --output-path $output_path --batch $BATCH \
--block-size ${block_size}

python inference.py ${data_dir} --path ${checkpoint_path} \
--user-dir block_plugins --task translation_lev_modified --remove-bpe --max-sentences 20 --source-lang ${src} \
--target-lang ${tgt} --iter-decode-max-iter 0 --iter-decode-eos-penalty 0 --iter-decode-with-beam 1 \
--gen-subset test --AR-path ${AR_checkpoint_path} --input-path ${input_path} --output-path ${output_path} \
--block-size ${block_size} --beta ${beta} --tau ${tau} --batch ${batch} --beam ${beam} --strategy ${strategy}
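
The new flags expose GAD's verification hyperparameters: block_size is the length of each block the NAR drafter proposes, beta is how many top verifier candidates count as acceptable at a position, and tau is a score tolerance; a candidate survives only if its score is within tau of the verifier's top-1 score. A toy illustration of that rule with hypothetical scores (the real computation lives in forward_decoder in inference_paper.py below):

# Toy illustration of the top-beta / tolerance-tau acceptance rule.
# Scores here are made up; in the model they come from the AR decoder.
beta, tau = 5, 3.0
topk_scores = [-0.1, -1.2, -2.9, -3.5, -6.0]  # sorted, best first
topk_tokens = [17, 42, 99, 7, 23]

accepted = [t for s, t in zip(topk_scores, topk_tokens)
            if topk_scores[0] - s <= tau]  # within tau of the best score
print(accepted)  # [17, 42, 99] -- the last two fall outside the tolerance

In the actual implementation, rejected candidates are overwritten with -1 rather than filtered out, so each position keeps a fixed-width list of beta candidates.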
186 changes: 134 additions & 52 deletions pass_count.py → inference_paper.py
@@ -28,32 +28,82 @@ def write_result(results, output_file):


@torch.no_grad()
def AR_forward_decoder(model,
input_tokens,
encoder_out: Dict[str, List[Tensor]],
incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
parallel_forward_start_pos=None,
temperature: float = 1.0,
use_log_softmax=True,
beta: int = 1,
tau: float = 0.0):
def fairseq_generate(data_lines, cfg, models, task, batch_size, device):
# fairseq original decoding implementation
src_dict = task.source_dictionary
tgt_dict = task.target_dictionary
generator = task.build_generator(models, cfg.generation)
data_size = len(data_lines)
all_results = []
logger.info(f'Fairseq generate batch {batch_size}')
start = time.perf_counter()
for start_idx in tqdm(range(0, data_size, batch_size)):
batch_lines = [line for line in data_lines[start_idx: min(start_idx + batch_size, data_size)]]
batch_ids = [src_dict.encode_line(sentence, add_if_not_exist=False).long() for sentence in batch_lines]
lengths = torch.LongTensor([t.numel() for t in batch_ids])
batch_dataset = task.build_dataset_for_inference(batch_ids, lengths)
batch = batch_dataset.collater(batch_dataset)
batch = utils.apply_to_sample(lambda t: t.to(device), batch)
translations = generator.generate(models, batch, prefix_tokens=None)
results = []
for id, hypos in zip(batch["id"].tolist(), translations):
results.append((id, hypos))
batched_hypos = [hypos for _, hypos in sorted(results, key=lambda x: x[0])]
all_results.extend([tgt_dict.string(hypos[0]['tokens']) for hypos in batched_hypos])
delta = time.perf_counter() - start
remove_bpe_results = [line.replace('@@ ', '') for line in all_results]
return remove_bpe_results, delta
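
fairseq_generate is the commit's new batched path: it binarizes batch_size lines at a time, lets fairseq's generator translate the whole batch, and sorts the hypotheses back by sample id because the collater may reorder sentences (typically by length). A hedged usage sketch, with variable names matching the __main__ block at the bottom of this file:

# Hypothetical driver for fairseq_generate; cfg, models, and task are assumed
# to come from convert_namespace_to_omegaconf, load_model_ensemble_and_task,
# and tasks.setup_task as in __main__ below.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
with open('test.en') as f:
    bpe_sents = [l.strip() for l in f]
hyps, secs = fairseq_generate(bpe_sents, cfg, models, task, batch_size=32, device=device)
print(f'{len(hyps)} translations in {secs:.2f}s ({len(hyps) / secs:.1f} sents/s)')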


@torch.no_grad()
def baseline_forward_decoder(model,
input_tokens,
encoder_out: Dict[str, List[Tensor]],
incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
parallel_forward_start_pos=None,
temperature: float = 1.0):
decoder_out = model.decoder.forward(input_tokens,
encoder_out=encoder_out,
incremental_state=incremental_state,
parallel_forward_start_pos=parallel_forward_start_pos)
decoder_out_tuple = (decoder_out[0].div_(temperature), decoder_out[1])
if use_log_softmax:
probs = model.get_normalized_probs(decoder_out_tuple, log_probs=True, sample=None)
else:
probs = decoder_out_tuple[0]
topk_scores, indexes = torch.topk(probs, beta, dim=-1)
topk_scores = topk_scores[0].tolist()
indexes = indexes[0].tolist()
for i in range(len(topk_scores)):
for j, s in enumerate(topk_scores[i]):
if topk_scores[i][0] - s > tau:
indexes[i][j] = -1
return indexes
pred_tokens = torch.argmax(decoder_out_tuple[0], dim=-1).squeeze(0)
return pred_tokens


@torch.no_grad()
def baseline_generate(data_lines, model, task, device, max_len=200):
# simplified AR greedy decoding
src_dict = task.source_dictionary
tgt_dict = task.target_dictionary
data_size = len(data_lines)
all_results = []
logger.info(f'Baseline generate')
start = time.perf_counter()
for start_idx in tqdm(range(0, data_size)):
bpe_line = data_lines[start_idx]
src_tokens = src_dict.encode_line(bpe_line, add_if_not_exist=False).long()
net_input = {'src_tokens': src_tokens.unsqueeze(0).to(device),
'src_lengths': torch.LongTensor([src_tokens.numel()]).to(device)}
encoder_out = model.encoder.forward_torchscript(net_input)
incremental_state = torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]],
torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {}))
tokens = [tgt_dict.eos()]

for step in range(0, max_len):
cur_input_tokens = torch.tensor([tokens]).to(device).long()
pred_token = baseline_forward_decoder(model,
cur_input_tokens,
encoder_out,
incremental_state).item()
if pred_token == tgt_dict.eos():
break
else:
tokens.append(pred_token)
all_results.append(tgt_dict.string(tokens[1:]))
delta = time.perf_counter() - start
remove_bpe_results = [line.replace('@@ ', '') for line in all_results]
return remove_bpe_results, delta
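
baseline_generate decodes one sentence at a time with greedy argmax, reusing cached decoder states through incremental_state; note that, following fairseq's convention, generation is primed with the eos token as BOS. A hedged usage sketch matching the 'AR' strategy branch of __main__:

# Hypothetical call, mirroring the 'AR' strategy branch in __main__:
hyps, secs = baseline_generate(bpe_sents, AR_model, task, device, max_len=200)
print(f'simplified AR: {len(hyps)} sentences in {secs:.2f}s')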


def cut_incremental_state(incremental_state, keep_len, encoder_state_ids):
@@ -68,38 +68,118 @@ def cut_incremental_state(incremental_state, keep_len, encoder_state_ids):
incremental_state[n][k] = incremental_state[n][k][:, :keep_len]


def block_generate(data_lines, model, AR_model, task, block_size, device, max_len=200, beta=1, tau=0):
@torch.no_grad()
def forward_decoder(model,
input_tokens,
encoder_out: Dict[str, List[Tensor]],
incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
parallel_forward_start_pos=None,
temperature: float = 1.0,
beta: int = 1,
tau: float = 0.0):
decoder_out = model.decoder.forward(input_tokens,
encoder_out=encoder_out,
incremental_state=incremental_state,
parallel_forward_start_pos=parallel_forward_start_pos)
decoder_out_tuple = (decoder_out[0].div_(temperature), decoder_out[1])
topk_scores, indexes = torch.topk(decoder_out_tuple[0], beta, dim=-1)
topk_scores = topk_scores[0].tolist()
indexes = indexes[0].tolist()
for i in range(len(topk_scores)):
for j, s in enumerate(topk_scores[i]):
if topk_scores[i][0] - s > tau:
indexes[i][j] = -1
return indexes
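
forward_decoder returns one list of beta token ids per draft position (plus one extra position), with candidates rejected by the tau tolerance replaced by -1. gad_forward then scans for the first position where the NAR draft token is not among the verifier's surviving candidates; everything before that bifurcation point is accepted. A toy sketch with hypothetical values:

# Hypothetical verifier output for a 4-token draft with beta=2; -1 marks
# candidates rejected by the tau tolerance.
draft = [17, 42, 8, 99]
AR_topk = [[17, 5], [42, -1], [31, -1], [99, -1], [2, -1]]  # one extra position

bifurcation = len(draft)
for i, (tok, cands) in enumerate(zip(draft, AR_topk[:-1])):
    if tok not in cands:  # draft token not among the accepted candidates
        bifurcation = i
        break
print(bifurcation)  # 2 -- only the first two draft tokens are accepted

The accepted prefix is then extended with the verifier's own top-1 token at the bifurcation position, so each pass adds at least one token; this is what makes vanilla GAD (beta=1, tau=0) lossless with respect to greedy AR decoding.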


def gad_generate(data_lines, model, AR_model, task, block_size, device, beta=1, tau=0, max_len=200):
# Generalized Aggressive Decoding
src_dict = task.source_dictionary
tgt_dict = task.target_dictionary
encoder_state_ids = []
for i in range(len(AR_model.decoder.layers)):
encoder_state_ids.append(AR_model.decoder.layers[i].encoder_attn._incremental_state_id)
data_size = len(data_lines)
all_results = []
logger.info(f'Block generate')
logger.info(f'GAD generate')
pass_tokens = [0] * max_len
sent_nums = [0] * max_len
start = time.perf_counter()
for start_idx in tqdm(range(0, data_size)):
bpe_line = data_lines[start_idx]

src_tokens = src_dict.encode_line(bpe_line, add_if_not_exist=False).long()
net_input = {'src_tokens': src_tokens.unsqueeze(0).to(device),
'src_lengths': torch.LongTensor([src_tokens.numel()]).to(device)}
AR_encoder_out = AR_model.encoder.forward_torchscript(net_input)
encoder_out = model.encoder.forward_torchscript(net_input)

incremental_state = torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]],
torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {}))

prev_output_tokens = [tgt_dict.unk()] * block_size
start_pos = 0
for step in range(0, max_len):
start_pos, prev_output_tokens, pass_token = block_forward(incremental_state, encoder_state_ids,
start_pos, block_size, tgt_dict,
prev_output_tokens,
encoder_out, AR_encoder_out, model,
AR_model, beta, tau)
start_pos, prev_output_tokens, pass_token = gad_forward(incremental_state, encoder_state_ids,
start_pos, block_size, tgt_dict,
prev_output_tokens,
encoder_out, AR_encoder_out, model,
AR_model, beta, tau)
pass_tokens[step] += pass_token
sent_nums[step] += 1
if start_pos == -1:
@@ -127,8 +199,8 @@ def block_generate(data_lines, model, AR_model, task, block_size, device, max_le
return remove_bpe_results, delta
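
Relative to inference.py, this renamed script (formerly pass_count.py) also accumulates pass_tokens and sent_nums, i.e. how many tokens are accepted at each decoding pass, presumably the per-pass acceptance statistic reported in the paper. A hedged usage sketch matching the 'gad' strategy branch of __main__, with hyperparameter values taken from inference.sh:

# Hypothetical call, mirroring the 'gad' strategy branch in __main__;
# block_size/beta/tau values follow the settings in inference.sh.
hyps, secs = gad_generate(bpe_sents, model, AR_model, task, block_size=25,
                          device=device, beta=5, tau=3.0)
print(f'GAD: {len(hyps)} sentences in {secs:.2f}s')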


def block_forward(incremental_state, encoder_state_ids, start_pos, block_size, tgt_dict, prev_output_tokens,
encoder_out, AR_encoder_out, model, AR_model, beta, tau, max_len=200):
def gad_forward(incremental_state, encoder_state_ids, start_pos, block_size, tgt_dict, prev_output_tokens,
encoder_out, AR_encoder_out, model, AR_model, beta, tau, max_len=200):
output_tokens = torch.tensor([prev_output_tokens]).to(device)
_scores, _tokens = model.decoder(
normalize=False,
@@ -139,17 +211,14 @@ def block_forward(incremental_state, encoder_state_ids, start_pos, block_size, t
prev_output_tokens[start_pos:start_pos + block_size] = _tokens[0].tolist()[start_pos:start_pos + block_size]

cut_incremental_state(incremental_state, keep_len=start_pos, encoder_state_ids=encoder_state_ids)

cur_span_input_tokens = torch.tensor([[tgt_dict.eos()] + prev_output_tokens]).to(device)
AR_topk_tokens = AR_forward_decoder(AR_model,
cur_span_input_tokens,
AR_encoder_out,
incremental_state,
use_log_softmax=False,
parallel_forward_start_pos=start_pos,
beta=beta,
tau=tau)

AR_topk_tokens = forward_decoder(AR_model,
cur_span_input_tokens,
AR_encoder_out,
incremental_state,
parallel_forward_start_pos=start_pos,
beta=beta,
tau=tau)

bifurcation = block_size
for i, (token, AR_topk_token) in enumerate(zip(prev_output_tokens[start_pos:], AR_topk_tokens[:-1][:])):
@@ -162,7 +231,7 @@ def block_forward(incremental_state, encoder_state_ids, start_pos, block_size, t

pass_token = 0
find_eos = False
for i, o in enumerate(prev_output_tokens[start_pos:start_pos + bifurcation] + [AR_topk_tokens[bifurcation][0]]):
for i, o in enumerate(next_output_tokens[start_pos:start_pos + bifurcation + 1]):
if o == tgt_dict.eos() or i + start_pos == max_len:
next_output_tokens = next_output_tokens[0:start_pos + i]
start_pos = -1
@@ -184,7 +253,11 @@ def block_forward(incremental_state, encoder_state_ids, start_pos, block_size, t
parser.add_argument('--output-path', type=str, default=None,
help='path to output file')
parser.add_argument('--AR-path', type=str, default=None,
help='path to AT verifier model')
help='path to AR model')
parser.add_argument('--strategy', type=str, default='fairseq',
help='decoding strategy, choose from: fairseq, AR, gad')
parser.add_argument('--batch', type=int, default=None,
help='batch size')
parser.add_argument('--block-size', type=int, default=5,
help='block size')
parser.add_argument('--beta', type=int, default=1,
@@ -197,20 +270,19 @@ def block_forward(incremental_state, encoder_state_ids, start_pos, block_size, t

cfg = convert_namespace_to_omegaconf(cmd_args)

# Load dataset splits
task = tasks.setup_task(cfg.task)

# NAR drafter
logger.info("loading model(s) from {}".format(cfg.common_eval.path))
# load model
models, _model_args, _model_task = load_model_ensemble_and_task(filenames=[cfg.common_eval.path],
task=task)
models, _model_args, _model_task = load_model_ensemble_and_task(filenames=[cfg.common_eval.path], task=task)

if cmd_args.cpu:
device = torch.device('cpu')
else:
device = torch.device('cuda')
model = models[0].to(device).eval()

# AR verifier
AR_model = None
AR_models = None
_AR_model_task = None
@@ -223,9 +295,19 @@ def block_forward(incremental_state, encoder_state_ids, start_pos, block_size, t
with open(cmd_args.input_path, 'r') as f:
bpe_sents = [l.strip() for l in f.readlines()]

logger.info("Decoding Strategy: Block")
remove_bpe_results, delta = block_generate(bpe_sents, model, AR_model, task, cmd_args.block_size, device, beta=cmd_args.beta, tau=cmd_args.tau)
logger.info(f'Block generate: {delta}')
if cmd_args.strategy == 'AR':
logger.info("Decoding Strategy: Simplified AR")
remove_bpe_results, delta = baseline_generate(bpe_sents, AR_model, _AR_model_task, device)
logger.info(f'Simplified AR generate: {delta}')
elif cmd_args.strategy == 'gad':
logger.info("Decoding Strategy: GAD")
remove_bpe_results, delta = gad_generate(bpe_sents, model, AR_model, task, cmd_args.block_size, device,
beta=cmd_args.beta, tau=cmd_args.tau)
logger.info(f'GAD generate: {delta}')
else:
logger.info("Decoding Strategy: fairseq")
remove_bpe_results, delta = fairseq_generate(bpe_sents, cfg, AR_models, _AR_model_task, cmd_args.batch, device)
logger.info(f'Fairseq generate batch {cmd_args.batch}, beam {cfg.generation.beam}: {delta}')

if cmd_args.output_path is not None:
write_result(remove_bpe_results, cmd_args.output_path)
19 changes: 0 additions & 19 deletions pass_count.sh

This file was deleted.

