feat: temporary save for backup
kooqooo committed Feb 18, 2024
1 parent 09f482d commit 69c25e6
Showing 1 changed file with 140 additions and 11 deletions.
151 changes: 140 additions & 11 deletions code/retrieval_dpr.py
@@ -5,6 +5,7 @@
import pickle
from contextlib import contextmanager
from collections import deque
from typing import List, Optional, Tuple, Union

import numpy as np
import pandas as pd
@@ -39,17 +40,28 @@ def timer(name):

class DenseRetrieval:

- def __init__(self, args, dataset, num_neg, tokenizer, p_encoder, q_encoder, num_sample=-1):
def __init__(self, args, num_neg, tokenizer, p_encoder, q_encoder, num_sample: int = -1, data_path: Optional[str] = './data', context_path: Optional[str] = "wikipedia_documents.json",):

'''
Set up everything that will be used for training and inference.
'''
self.args = args
- self.dataset = dataset
- if isinstance(dataset, DatasetDict):
- self.dataset = concatenate_datasets(
- [dataset["train"].flatten_indices(),
- dataset["validation"].flatten_indices()])
self.data_path = data_path
self.dataset = load_from_disk(os.path.join(self.data_path, 'train_dataset'))
self.train_dataset = self.dataset['train']
self.valid_dataset = self.dataset['validation']
testdata = load_from_disk(os.path.join(self.data_path, 'test_dataset'))
self.test_dataset = testdata['validation']
del testdata

with open(os.path.join(data_path, context_path), "r", encoding="utf-8") as f:
wiki = json.load(f)

self.contexts = list(
dict.fromkeys([v["text"] for v in wiki.values()])
) # a set would not keep a stable order across runs
print(f"Lengths of unique contexts : {len(self.contexts)}")
self.ids = list(range(len(self.contexts)))
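# Note on the dedup above: dict.fromkeys keeps first-seen order while
# removing duplicates, e.g. list(dict.fromkeys(["b", "a", "b"])) == ["b", "a"],
# whereas list(set([...])) gives no stable order across runs.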

# self.train_dataset = dataset[:num_sample]
# self.valid_dataset = dataset[num_sample]
@@ -58,10 +70,11 @@ def __init__(self, args, dataset, num_neg, tokenizer, p_encoder, q_encoder, num_
self.p_encoder = p_encoder
self.q_encoder = q_encoder
self.pwd = os.getcwd()
- self.save_path = os.path.join(self.pwd, '/models/dpr')
self.save_path = os.path.join(self.pwd, 'models/dpr')

self.prepare_in_batch_negative(num_neg=num_neg)


def prepare_in_batch_negative(self, dataset=None, num_neg=3, tokenizer=None):

if dataset is None:
@@ -110,6 +123,55 @@ def prepare_in_batch_negative(self, dataset=None, num_neg=3, tokenizer=None):
self.passage_dataloader = DataLoader(passage_dataset, batch_size=self.args.per_device_train_batch_size)
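# The middle of prepare_in_batch_negative is elided in this diff. A minimal
# sketch of the standard in-batch-negative layout it builds (an assumption,
# not the committed code): each gold passage is followed by `num_neg`
# randomly drawn corpus passages, so positives sit at index 0 of each group.
#
#   corpus = np.array(self.contexts)
#   p_with_neg = []
#   for positive in dataset["context"]:
#       while True:
#           neg_idxs = np.random.randint(len(corpus), size=num_neg)
#           if positive not in corpus[neg_idxs]:
#               p_with_neg.append(positive)
#               p_with_neg.extend(corpus[neg_idxs])
#               break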


def prepare_pre_batch_negative(self):


q_seqs = self.tokenizer(
self.train_dataset['question'], padding="max_length",
truncation=True, return_tensors='pt'
)
p_seqs = self.tokenizer(
self.train_dataset['context'], padding="max_length",
truncation=True, return_tensors='pt'
)
train_dataset = TensorDataset(
p_seqs['input_ids'], p_seqs['attention_mask'], p_seqs['token_type_ids'],
q_seqs['input_ids'], q_seqs['attention_mask'], q_seqs['token_type_ids']
)
self.train_dataloader = DataLoader(
train_dataset, shuffle=True, batch_size=self.args.per_device_train_batch_size, drop_last=False)
###################################
valid_q_seqs = self.tokenizer(
self.valid_dataset['question'], padding="max_length",
truncation=True, return_tensors='pt'
)
valid_dataset = TensorDataset(
valid_q_seqs['input_ids'], valid_q_seqs['attention_mask'], valid_q_seqs['token_type_ids']
)
self.valid_dataloader = DataLoader(
valid_dataset, batch_size=self.args.per_device_train_batch_size, drop_last=False)
###################################
test_q_seqs = self.tokenizer(
self.test_dataset['question'], padding="max_length",
truncation=True, return_tensors='pt'
)
test_dataset = TensorDataset(
test_q_seqs['input_ids'], test_q_seqs['attention_mask'], test_q_seqs['token_type_ids']
)
self.test_dataloader = DataLoader(
test_dataset, batch_size=self.args.per_device_train_batch_size, drop_last=False)
###################################
wiki_seqs = self.tokenizer(
self.contexts, padding="max_length",
truncation=True, return_tensors='pt'
)
wiki_dataset = TensorDataset(
wiki_seqs['input_ids'], wiki_seqs['attention_mask'], wiki_seqs['token_type_ids']
)
self.wiki_dataloader = DataLoader(
wiki_dataset, batch_size=self.args.per_device_train_batch_size, drop_last=False)
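# Unlike prepare_in_batch_negative, no explicit negatives are tokenized here:
# the dataloaders above carry only questions (plus gold passages for train).
# Presumably the extra negatives are added at train time from `p_queue`, a
# queue of passage embeddings kept over the previous `num_pre_batch` batches
# (see train() below).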


def train(self, args=None, override=False, num_pre_batch=0):

if args is None:
@@ -147,14 +209,14 @@ def train(self, args=None, override=False, num_pre_batch=0):
targets = torch.zeros(batch_size).long() # positive examples all sit at index 0
targets = targets.to(args.device)

- if num_pre_batch is not 0: # In-batch or Pre-batch
if num_pre_batch != 0: # Pre-batch
p_inputs = {
'input_ids': batch[0].to(args.device),
'attention_mask': batch[1].to(args.device),
'token_type_ids': batch[2].to(args.device)
}

- else: # negtive sampling
else: # In-batch negative sampling
p_inputs = {
'input_ids': batch[0].view(batch_size * (self.num_neg + 1), -1).to(args.device),
'attention_mask': batch[1].view(batch_size * (self.num_neg + 1), -1).to(args.device),
@@ -169,7 +231,7 @@

p_outputs = self.p_encoder(**p_inputs) # (batch_size*(num_neg+1), emb_dim)
q_outputs = self.q_encoder(**q_inputs) # (batch_size, emb_dim)
- if num_pre_batch is not 0: # In-batch or Pre-batch
if num_pre_batch != 0: # Pre-batch negative sampling
temp = p_outputs.clone().detach()
p_outputs = torch.cat((p_outputs, *p_queue), dim=0)
p_queue.append(temp)
@@ -180,7 +242,7 @@
loss = F.nll_loss(sim_scores, targets)
tepoch.set_postfix(loss=f'{loss.item()}')

- else: # negative sampling
else: # In-batch negative sampling
# Calculate similarity score & loss
p_outputs = p_outputs.view(batch_size, self.num_neg + 1, -1)
q_outputs = q_outputs.view(batch_size, 1, -1)
@@ -209,6 +271,73 @@ def train(self, args=None, override=False, num_pre_batch=0):
torch.save(self.q_encoder.state_dict(), os.path.join(self.save_path, 'q_encoder_state_dict'))
print('encoder state dicts saved')
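# The pre-batch similarity/loss step is elided above. A minimal sketch of the
# common DPR form (an assumption, not the committed code): each query is
# scored against the current batch's passages plus the queued ones, with
# log-softmax followed by NLL. Note that under this layout the positive for
# query i is column i, so targets are commonly torch.arange(batch_size)
# rather than zeros.
#
#   sim_scores = torch.matmul(q_outputs, p_outputs.T)  # (B, B + queue_size)
#   sim_scores = F.log_softmax(sim_scores, dim=1)
#   loss = F.nll_loss(sim_scores, targets)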


def retrieve(
self, query_or_dataset: Union[str, Dataset], topk: Optional[int] = 1
) -> Union[Tuple[List, List], pd.DataFrame]:

"""
Arguments:
query_or_dataset (Union[str, Dataset]):
Accepts a query as a str or as a Dataset.
For a single query given as a str, similarity is computed via `get_relevant_doc`.
For a Dataset, an HF Dataset containing the queries is expected.
In that case similarity is computed via `get_relevant_doc_bulk`.
topk (Optional[int], optional): Defaults to 1.
Specifies how many of the top passages to use.
Returns:
For a single query -> Tuple(List, List)
For multiple queries -> pd.DataFrame: [description]
Note:
For multiple queries,
queries with a Ground Truth (train/valid) -> the existing Ground Truth passage is returned as well.
Queries without a Ground Truth (test) -> only the retrieved passages are returned.
"""

assert self.p_embedding is not None, "Passage embeddings must be built first."

if isinstance(query_or_dataset, str):
doc_scores, doc_indices = self.get_relevant_doc(query_or_dataset, k=topk)
print("[Search query]\n", query_or_dataset, "\n")

for i in range(topk):
print(f"Top-{i+1} passage with score {doc_scores[i]:4f}")
print(self.contexts[doc_indices[i]])

return (doc_scores, [self.contexts[doc_indices[i]] for i in range(topk)])

elif isinstance(query_or_dataset, Dataset):

# Return the retrieved passages as a pd.DataFrame.
total = []
with timer("query exhaustive search"):
doc_scores, doc_indices = self.get_relevant_doc_bulk(
query_or_dataset["question"], k=topk
)
for idx, example in enumerate(
tqdm(query_or_dataset, desc="Dense retrieval: ")
):
tmp = {
# Return the query and its id.
"question": example["question"],
"id": example["id"],
# Return the id and context of the retrieved passages.
"context": " ".join(
[self.contexts[pid] for pid in doc_indices[idx]]
),
}
if "context" in example.keys() and "answers" in example.keys():
# With validation data, the ground-truth context and answers are returned as well.
tmp["original_context"] = example["context"]
tmp["answers"] = example["answers"]
total.append(tmp)

cqas = pd.DataFrame(total)
return cqas
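# Example usage (hypothetical; the tokenizer and the BERT-style p/q encoders
# are assumed to be set up elsewhere, and the argument values are assumptions):
#
#   retriever = DenseRetrieval(args, num_neg=3, tokenizer=tokenizer,
#                              p_encoder=p_encoder, q_encoder=q_encoder)
#   retriever.train(num_pre_batch=2)
#   scores, passages = retriever.retrieve("What is the capital of Korea?", topk=5)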


def get_relevant_doc(self, query, k=10, args=None, p_encoder=None, q_encoder=None):

with torch.no_grad():
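# The rest of get_relevant_doc is truncated in this diff. A minimal sketch of
# the usual dense scoring step, assuming passage embeddings have already been
# computed into self.p_embedding:
#
#   self.q_encoder.eval()
#   q_seqs = self.tokenizer(query, padding="max_length", truncation=True,
#                           return_tensors="pt").to(args.device)
#   q_emb = self.q_encoder(**q_seqs)                  # (1, emb_dim)
#   scores = torch.matmul(q_emb, self.p_embedding.T)  # (1, num_passages)
#   rank = torch.argsort(scores, dim=1, descending=True).squeeze()
#   return scores.squeeze()[rank[:k]].tolist(), rank[:k].tolist()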
