Skip to content

Commit

Permalink
add missing funcs
Browse files Browse the repository at this point in the history
  • Loading branch information
yana-xuyan committed May 24, 2022
1 parent af3b83d commit ed768a1
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 17 deletions.
47 changes: 36 additions & 11 deletions task2/data_utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,42 @@ def add_special_tokens_(tokenizer, model, decoder_only=False):
model.resize_token_embeddings(new_num_tokens=orig_num_tokens + num_added_tokens)


def get_lens(sents):
    """Return the length of each sentence in *sents*, in order."""
    return list(map(len, sents))

def get_max_len(sents):
    """Return the length of the longest sentence in *sents*.

    Args:
        sents: iterable of sized sequences (e.g. lists of token ids).

    Returns:
        int: the maximum sentence length, or 0 for an empty *sents*
        (the previous version raised ValueError on empty input).
    """
    # Generator + default=0: no throwaway list, and safe for an empty batch.
    return max((len(sent) for sent in sents), default=0)

def pad_sents(sents, pad_token=0, max_len=512):
    """Pad (and truncate) sentences to a common length.

    The target length is ``min(longest sentence, max_len)``: sentences
    longer than the target are truncated, shorter ones are right-padded
    with *pad_token*.

    Args:
        sents: list of token-id lists.
        pad_token: value appended to short sentences (default 0).
        max_len: hard upper bound on the padded length (default 512).

    Returns:
        (sents_padded, new_len): padded sentences and the effective
        (post-truncation) length of each one.
    """
    # Guard the empty batch: max() below would raise on no lengths.
    if not sents:
        return [], []
    lens = [len(sent) for sent in sents]
    pad_to = min(max(lens), max_len)
    # Clamping to pad_to is equivalent to the old clamp to max_len, since
    # every length already satisfies l <= max(lens).
    new_len = [min(l, pad_to) for l in lens]
    sents_padded = [
        sent[:l] + [pad_token] * (pad_to - l)
        for sent, l in zip(sents, new_len)
    ]
    return sents_padded, new_len

def sort_sents(sents, reverse=True):
    """Sort *sents* in place by sentence length (longest first by default).

    NOTE: mutates the input list; the same list object is returned for
    call-chaining convenience.
    """
    sents.sort(key=len, reverse=reverse)
    return sents


def get_mask(sents, unmask_idx=1, mask_idx=0, max_len=512):
    """Build attention-style masks matching pad_sents' padding scheme.

    Each mask holds *unmask_idx* at real-token positions and *mask_idx* at
    padded positions, with common length ``min(longest sentence, max_len)``.

    Args:
        sents: list of token-id lists.
        unmask_idx: value marking real-token positions (default 1).
        mask_idx: value marking padded positions (default 0).
        max_len: hard upper bound on mask length (default 512).

    Returns:
        list of int lists, one mask per sentence (empty list for an
        empty batch; the previous version raised ValueError there).
    """
    # Guard the empty batch: max() below would raise on no lengths.
    if not sents:
        return []
    lens = [len(sent) for sent in sents]
    pad_to = min(max(lens), max_len)
    mask = []
    for l in lens:
        l = min(l, pad_to)  # truncated sentences are fully "real"
        mask.append([unmask_idx] * l + [mask_idx] * (pad_to - l))
    return mask


def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
"""From fairseq"""
if target.dim() == lprobs.dim() - 1:
Expand Down Expand Up @@ -103,17 +139,6 @@ def decode_pred(pred: EvalPrediction) -> Tuple[List[str], List[str]]:
label_str = tokenizer.batch_decode(labels, skip_special_tokens=True)
pred_str = lmap(str.strip, pred_str)
label_str = lmap(str.strip, label_str)
# with open("save/test/preds.txt", "w") as f:
# for line in pred_str:
# f.write(line+"\n")
# with open("save/test/golds.txt", "w") as f:
# for line in label_str:
# f.write(line+"\n")
# print("-"*80)
# for p, g in zip(pred_str, label_str):
# print("gold", g)
# print("pred", p)
# input()
return pred_str, label_str

def summarization_metrics(pred: EvalPrediction) -> Dict:
Expand Down
19 changes: 13 additions & 6 deletions task2/wow_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,26 @@
import logging
import numpy as np
from tqdm import tqdm
from operator import itemgetter
from typing import Any, Dict, List, Optional, Tuple, Union, Iterator
from typing import Any, Dict, List

from parlai.core.dict import DictionaryAgent
from parlai.core.worlds import create_task

import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import Dataset
import transformers
from transformers import AutoTokenizer

from src.data_utils.utils import pad_sents, get_mask, pad_list_of_sents, get_list_of_mask
from src.data_utils.data_reader import getDataLoader
from src.data_utils.utils import pad_sents, get_mask


def getDataLoader(dataset, bsz, test=False):
    """Wrap *dataset* in a torch DataLoader.

    Args:
        dataset: any torch-compatible dataset.
        bsz: batch size.
        test: when True, keep the original example order (no shuffling);
            training data (test=False) is reshuffled each epoch.

    Returns:
        torch.utils.data.DataLoader over *dataset*.
    """
    # `not test` replaces the redundant `False if test else True` ternary.
    return torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=bsz,
        shuffle=not test,
    )

class DialogReader(Dataset):
def __init__(self,
Expand Down Expand Up @@ -383,7 +390,7 @@ def get_data_from_batch(batch, model_type="decoder_only"):
label_starts = torch.sum(response_masks, 1)
label_idxs = torch.sum(label_masks, 1)

return inputs, masks, kn_sent, kn_mask, topic, topic_masks, \
return inputs, masks, kn_sent, kn_mask, \
labels, label_masks, response_masks, label_starts, label_idxs, None

if __name__ == "__main__":
Expand Down

0 comments on commit ed768a1

Please sign in to comment.