Skip to content

Commit

Permalink
feat: dense passage retrieval
Browse files Browse the repository at this point in the history
  • Loading branch information
kooqooo committed Feb 18, 2024
1 parent 69c25e6 commit 0669e29
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 18 deletions.
26 changes: 14 additions & 12 deletions code/inference_dpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,9 @@ def run_dense_retrieval(
p_encoder=p_encoder,
q_encoder=q_encoder
)
retriever.train()
num_pre_batch = 0
retriever.train(num_pre_batch=num_pre_batch)
df = retriever.retrieve(query_or_dataset=datasets['validation'], topk=data_args.top_k_retrieval)
# retriever.get_sparse_embedding()

# if data_args.use_faiss:
Expand All @@ -146,15 +148,15 @@ def run_dense_retrieval(
# else:
# df = retriever.retrieve(datasets["validation"], topk=data_args.top_k_retrieval)

# # test data 에 대해선 정답이 없으므로 id question context 로만 데이터셋이 구성됩니다.
# if training_args.do_predict:
# f = Features(
# {
# "context": Value(dtype="string", id=None),
# "id": Value(dtype="string", id=None),
# "question": Value(dtype="string", id=None),
# }
# )
# test data 에 대해선 정답이 없으므로 id question context 로만 데이터셋이 구성됩니다.
if training_args.do_predict:
f = Features(
{
"context": Value(dtype="string", id=None),
"id": Value(dtype="string", id=None),
"question": Value(dtype="string", id=None),
}
)

# # train data 에 대해선 정답이 존재하므로 id question context answer 로 데이터셋이 구성됩니다.
# elif training_args.do_eval:
Expand All @@ -173,8 +175,8 @@ def run_dense_retrieval(
# "question": Value(dtype="string", id=None),
# }
# )
# datasets = DatasetDict({"validation": Dataset.from_pandas(df, features=f)})
# return datasets
datasets = DatasetDict({"validation": Dataset.from_pandas(df, features=f)})
return datasets


def run_mrc(
Expand Down
75 changes: 69 additions & 6 deletions code/retrieval_dpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def __init__(self, args, num_neg, tokenizer, p_encoder, q_encoder, num_sample: i
self.args = args
self.data_path = data_path
self.dataset = load_from_disk(os.path.join(self.data_path, 'train_dataset'))
self.train_dataset = self.dataset['train']
self.train_dataset = self.dataset['train'] if num_sample is -1 else self.dataset['train'][:num_sample]
self.valid_dataset = self.dataset['validation']
testdata = load_from_disk(os.path.join(self.data_path), 'test_dataset')
self.test_dataset = testdata['validation']
Expand Down Expand Up @@ -79,6 +79,8 @@ def prepare_in_batch_negative(self, dataset=None, num_neg=3, tokenizer=None):

if dataset is None:
dataset = self.dataset
dataset = concatenate_datasets([dataset["train"].flatten_indices(),
dataset["validation"].flatten_indices()])

if tokenizer is None:
tokenizer = self.tokenizer
Expand Down Expand Up @@ -123,8 +125,9 @@ def prepare_in_batch_negative(self, dataset=None, num_neg=3, tokenizer=None):
self.passage_dataloader = DataLoader(passage_dataset, batch_size=self.args.per_device_train_batch_size)


def prepare_pre_batch_negative(self):

def prepare_in_batch_negative_for_checkpoint(self):
# reference: https://arxiv.org/abs/2012.12624
# reference: https://github.com/boostcampaitech5/level2_nlp_mrc-nlp-04/blob/main/dpr_retrieval.py

q_seqs = self.tokenizer(
self.train_dataset['question'], padding="max_length",
Expand Down Expand Up @@ -314,7 +317,7 @@ def retrieve(
total = []
with timer("query exhaustive search"):
doc_scores, doc_indices = self.get_relevant_doc_bulk(
query_or_dataset["question"], k=topk
query_or_dataset["question"], k=topk, p_encoder=self.p_encoder, q_encoder=self.q_encoder
)
for idx, example in enumerate(
tqdm(query_or_dataset, desc="Sparse retrieval: ")
Expand All @@ -336,7 +339,23 @@ def retrieve(

cqas = pd.DataFrame(total)
return cqas



def get_q_embs(self):
    """Encode each batch with the question encoder and cache the embeddings.

    Returns:
        list[torch.Tensor]: one CPU tensor of question embeddings per batch,
        also stored on ``self.q_embs``.

    NOTE(review): despite the name, this iterates ``self.passage_dataloader``;
    confirm that dataloader actually yields question token tensors (it is the
    same loader the passage encoder consumes in ``get_relevant_doc_bulk``).
    """
    self.q_encoder.eval()

    self.q_embs = []
    # BUG FIX: the original ``with torch.no_grad():`` wrapped only the
    # ``eval()`` call, so every forward pass below still tracked gradients.
    # The inference loop itself must run under no_grad.
    with torch.no_grad():
        for batch in tqdm(self.passage_dataloader):
            q_inputs = {
                'input_ids': batch[0].to(self.args.device),
                'attention_mask': batch[1].to(self.args.device),
                'token_type_ids': batch[2].to(self.args.device),
            }
            q_emb = self.q_encoder(**q_inputs).to('cpu')
            self.q_embs.append(q_emb)
    return self.q_embs


def get_relevant_doc(self, query, k=10, args=None, p_encoder=None, q_encoder=None):

Expand All @@ -363,7 +382,51 @@ def get_relevant_doc(self, query, k=10, args=None, p_encoder=None, q_encoder=Non

dot_prod_scores = torch.matmul(q_emb, torch.transpose(p_embs, 0, 1))
rank = torch.argsort(dot_prod_scores, dim=1, descending=True).squeeze()
return rank[:k]
doc_score = dot_prod_scores.squeeze()[rank][:k]
doc_indices = rank.tolist()[:k]
return doc_score, doc_indices


def get_relevant_doc_bulk(self, query, k=10, args=None, p_encoder=None, q_encoder=None):
    """Score all passages against all queries and return top-k per query.

    Args:
        query: unused; kept for interface compatibility with
            ``get_relevant_doc``. Question embeddings come from
            ``self.get_q_embs()``.
        k: number of passages to return per query.
        args: namespace providing ``device`` for passage encoding.
        p_encoder: passage encoder, put in eval mode.
        q_encoder: question encoder, put in eval mode.

    Returns:
        tuple[list[list[float]], list[list[int]]]: per-query top-k
        dot-product scores and the matching passage indices.
    """
    with torch.no_grad():
        p_encoder.eval()
        q_encoder.eval()

        # BUG FIX: get_q_embs() returns a *list* of per-batch tensors; the
        # original passed that list straight to torch.matmul (TypeError).
        # Concatenate into a single (num_query, emb_dim) tensor first.
        q_embs = torch.cat(self.get_q_embs(), dim=0)

        p_embs = []
        for batch in self.passage_dataloader:
            batch = tuple(t.to(args.device) for t in batch)
            p_inputs = {
                'input_ids': batch[0],
                'attention_mask': batch[1],
                'token_type_ids': batch[2],
            }
            p_embs.append(p_encoder(**p_inputs).to('cpu'))

        # BUG FIX: the original computed these passage embeddings and then
        # discarded them, scoring against ``self.p_embs`` — an attribute this
        # method never sets. Use the freshly computed embeddings instead.
        # Shape: (num_passage, emb_dim).
        p_embs = torch.cat(p_embs, dim=0).view(len(self.passage_dataloader.dataset), -1)

        # (typo fixed: dot_pord_scores -> dot_prod_scores)
        dot_prod_scores = torch.matmul(q_embs, torch.transpose(p_embs, 0, 1))

        doc_scores = []
        doc_indices = []
        for i in range(dot_prod_scores.shape[0]):
            rank = torch.argsort(dot_prod_scores[i, :], dim=-1, descending=True)
            doc_scores.append(dot_prod_scores[i, :][rank].tolist()[:k])
            doc_indices.append(rank.tolist()[:k])

    return doc_scores, doc_indices



class BertEncoder(BertPreTrainedModel):

Expand Down

0 comments on commit 0669e29

Please sign in to comment.