temp: temporary save for backup
kooqooo committed Feb 18, 2024
1 parent 620a5c7 commit 09f482d
Showing 3 changed files with 65 additions and 34 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -145,6 +145,7 @@ cython_debug/
wandb
code/assets/*
code/prediction/*
code/models/*
*.ckpt
*.pt
**.csv
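The new code/models/* pattern keeps the trained encoder checkpoints out of version control: the retrieval_dpr.py changes below write the passage and question encoder state dicts under models/dpr. A hedged sketch of loading them back; load_dpr_encoders is a hypothetical helper, and only the directory layout and file names come from the diff below:

import os

import torch

def load_dpr_encoders(p_encoder, q_encoder, save_path=None):
    # Hypothetical helper: restores the state dicts written by DenseRetrieval.train() below.
    if save_path is None:
        save_path = os.path.join(os.getcwd(), 'models', 'dpr')
    p_encoder.load_state_dict(torch.load(os.path.join(save_path, 'p_encoder_state_dict'), map_location='cpu'))
    q_encoder.load_state_dict(torch.load(os.path.join(save_path, 'q_encoder_state_dict'), map_location='cpu'))
    p_encoder.eval()
    q_encoder.eval()
    return p_encoder, q_encoder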
12 changes: 11 additions & 1 deletion code/inference_dpr.py
@@ -117,8 +117,18 @@ def run_dense_retrieval(
# retriever = SparseRetrieval(
# tokenize_fn=tokenize_fn, data_path=data_path, context_path=context_path
# )
training_args.model_name_or_path = 'klue/bert-base'
args = TrainingArguments(
output_dir="dense_retrieval",
evaluation_strategy="epoch",
learning_rate=1e-5,
per_device_train_batch_size=32,
per_device_eval_batch_size=32,
num_train_epochs=5,
weight_decay=0.01
)
retriever = DenseRetrieval(
args=training_args,
args=args,
dataset=datasets,
num_neg=3,
tokenizer=tokenizer,
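This hunk stops reusing the reader's training_args for the retriever and builds a dedicated TrainingArguments inline. A hedged sketch of the resulting call pattern; the helper name build_retriever, the import path, and the encoder construction are assumptions, while DenseRetrieval, its keyword arguments, train(), and get_relevant_doc() come from code/retrieval_dpr.py below:

from transformers import AutoTokenizer, TrainingArguments

from retrieval_dpr import DenseRetrieval  # assumes code/ is on the import path

def build_retriever(datasets, p_encoder, q_encoder):
    # Dedicated hyperparameters for the retriever, independent of the reader's training_args.
    retriever_args = TrainingArguments(
        output_dir="dense_retrieval",
        evaluation_strategy="epoch",
        learning_rate=1e-5,
        per_device_train_batch_size=32,
        per_device_eval_batch_size=32,
        num_train_epochs=5,
        weight_decay=0.01,
    )
    tokenizer = AutoTokenizer.from_pretrained("klue/bert-base")
    return DenseRetrieval(
        args=retriever_args,
        dataset=datasets,  # a DatasetDict is concatenated internally
        num_neg=3,
        tokenizer=tokenizer,
        p_encoder=p_encoder,  # encoder instances are assumed to be built elsewhere
        q_encoder=q_encoder,
    )

# retriever = build_retriever(datasets, p_encoder, q_encoder)
# retriever.train()                         # in-batch / pre-batch negatives, see train() below
# retriever.get_relevant_doc(query, k=10)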
86 changes: 53 additions & 33 deletions code/retrieval_dpr.py
@@ -4,13 +4,14 @@
import time
import pickle
from contextlib import contextmanager
from collections import deque

import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from pprint import pprint
from pyprnt import prnt
from datasets import Dataset, concatenate_datasets, load_from_disk
from datasets import Dataset, DatasetDict, concatenate_datasets, load_from_disk

import torch
from torch.utils.data import DataLoader, TensorDataset
@@ -38,19 +38,26 @@ def timer(name):

class DenseRetrieval:

def __init__(self, args, dataset, num_neg, tokenizer, p_encoder, q_encoder):
def __init__(self, args, dataset, num_neg, tokenizer, p_encoder, q_encoder, num_sample=-1):

'''
Set up everything that will be used for training and inference.
'''

self.args = args
self.dataset = dataset
if isinstance(dataset, DatasetDict):
self.dataset = concatenate_datasets(
[dataset["train"].flatten_indices(),
dataset["validation"].flatten_indices()])

# self.train_dataset = dataset[:num_sample]
# self.valid_dataset = dataset[num_sample]
self.num_neg = num_neg

self.tokenizer = tokenizer
self.p_encoder = p_encoder
self.q_encoder = q_encoder
self.pwd = os.getcwd()
self.save_path = os.path.join(self.pwd, 'models', 'dpr') # encoder state dicts are saved here after training

self.prepare_in_batch_negative(num_neg=num_neg)

@@ -102,7 +110,7 @@ def prepare_in_batch_negative(self, dataset=None, num_neg=3, tokenizer=None):
self.passage_dataloader = DataLoader(passage_dataset, batch_size=self.args.per_device_train_batch_size)


def train(self, args=None):
def train(self, args=None, override=False, num_pre_batch=0):

if args is None:
args = self.args
@@ -128,9 +136,8 @@ def train(self, args=None):
torch.cuda.empty_cache()

train_iterator = tqdm(range(int(args.num_train_epochs)), desc="Epoch")
# for _ in range(int(args.num_train_epochs)):
for _ in train_iterator:

p_queue = deque(maxlen=num_pre_batch)
with tqdm(self.train_dataloader, unit="batch") as tepoch:
for batch in tepoch:

@@ -140,11 +147,19 @@
targets = torch.zeros(batch_size).long() # positive examples always sit in the first position
targets = targets.to(args.device)

p_inputs = {
'input_ids': batch[0].view(batch_size * (self.num_neg + 1), -1).to(args.device),
'attention_mask': batch[1].view(batch_size * (self.num_neg + 1), -1).to(args.device),
'token_type_ids': batch[2].view(batch_size * (self.num_neg + 1), -1).to(args.device)
}
if num_pre_batch != 0: # In-batch or Pre-batch
p_inputs = {
'input_ids': batch[0].to(args.device),
'attention_mask': batch[1].to(args.device),
'token_type_ids': batch[2].to(args.device)
}

else: # negative sampling
p_inputs = {
'input_ids': batch[0].view(batch_size * (self.num_neg + 1), -1).to(args.device),
'attention_mask': batch[1].view(batch_size * (self.num_neg + 1), -1).to(args.device),
'token_type_ids': batch[2].view(batch_size * (self.num_neg + 1), -1).to(args.device)
}

q_inputs = {
'input_ids': batch[3].to(args.device),
@@ -154,42 +169,47 @@

p_outputs = self.p_encoder(**p_inputs) # (batch_size*(num_neg+1), emb_dim)
q_outputs = self.q_encoder(**q_inputs) # (batch_size, emb_dim)
if num_pre_batch != 0: # In-batch or Pre-batch
temp = p_outputs.clone().detach()
p_outputs = torch.cat((p_outputs, *p_queue), dim=0)
p_queue.append(temp)

# Calculate similarity score & loss
sim_scores = torch.matmul(q_outputs, torch.transpose(p_outputs, 0, 1)).squeeze()
sim_scores = F.log_softmax(sim_scores, dim=1)
loss = F.nll_loss(sim_scores, targets)
tepoch.set_postfix(loss=f'{loss.item()}')

# Calculate similarity score & loss
p_outputs = p_outputs.view(batch_size, self.num_neg + 1, -1)
q_outputs = q_outputs.view(batch_size, 1, -1)
else: # negative sampling
# Calculate similarity score & loss
p_outputs = p_outputs.view(batch_size, self.num_neg + 1, -1)
q_outputs = q_outputs.view(batch_size, 1, -1)

sim_scores = torch.bmm(q_outputs, torch.transpose(p_outputs, 1, 2)).squeeze() #(batch_size, num_neg + 1)
sim_scores = sim_scores.view(batch_size, -1)
sim_scores = F.log_softmax(sim_scores, dim=1)
sim_scores = torch.bmm(q_outputs, torch.transpose(p_outputs, 1, 2)).squeeze() #(batch_size, num_neg + 1)
sim_scores = sim_scores.view(batch_size, -1)
sim_scores = F.log_softmax(sim_scores, dim=1)

loss = F.nll_loss(sim_scores, targets)
tepoch.set_postfix(loss=f'{str(loss.item())}')
loss = F.nll_loss(sim_scores, targets)
tepoch.set_postfix(loss=f'{loss.item()}')

loss.backward()
optimizer.step()
scheduler.step()

self.p_encoder.zero_grad()
self.q_encoder.zero_grad()

global_step += 1

torch.cuda.empty_cache()

del p_inputs, q_inputs

if not os.path.exists(self.save_path):
os.makedirs(self.save_path)
torch.save(self.p_encoder.state_dict(), os.path.join(self.save_path, 'p_encoder_state_dict'))
torch.save(self.q_encoder.state_dict(), os.path.join(self.save_path, 'q_encoder_state_dict'))
print('encoder state dicts saved')


def get_relevant_doc(self, query, k=1, args=None, p_encoder=None, q_encoder=None):

if args is None:
args = self.args

if p_encoder is None:
p_encoder = self.p_encoder

if q_encoder is None:
q_encoder = self.q_encoder
def get_relevant_doc(self, query, k=10, args=None, p_encoder=None, q_encoder=None):

with torch.no_grad():
p_encoder.eval()
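The core addition to train() is an optional pre-batch negative scheme: a bounded deque caches the detached passage embeddings from the previous num_pre_batch steps and concatenates them to the current batch's passage embeddings, so each question is scored against extra negatives without re-encoding anything. A stripped-down sketch of just that mechanism, with random tensors standing in for the encoder outputs; the shapes and the diagonal targets are illustrative of standard in-batch negatives, not a copy of the project's exact target layout:

from collections import deque

import torch
import torch.nn.functional as F

batch_size, emb_dim, num_pre_batch = 4, 8, 2
p_queue = deque(maxlen=num_pre_batch)  # passage embeddings cached from earlier steps

for step in range(5):
    p_outputs = torch.randn(batch_size, emb_dim)  # stand-in for p_encoder(**p_inputs)
    q_outputs = torch.randn(batch_size, emb_dim)  # stand-in for q_encoder(**q_inputs)

    # Current passages plus up to num_pre_batch earlier batches as extra negatives.
    candidates = torch.cat((p_outputs, *p_queue), dim=0)
    p_queue.append(p_outputs.clone().detach())  # cache without keeping the autograd graph

    sim_scores = q_outputs @ candidates.T  # (batch_size, num_candidates)
    log_probs = F.log_softmax(sim_scores, dim=1)
    targets = torch.arange(batch_size)  # question i's positive is passage i of the current batch
    loss = F.nll_loss(log_probs, targets)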
