From 3036559081a58b2e459ca5edc7f91ba0db9e2d7d Mon Sep 17 00:00:00 2001 From: Jongwon Lee Date: Mon, 19 Apr 2021 19:26:30 +0900 Subject: [PATCH 1/4] init --- .gitignore | 2 ++ evaluation.py | 2 +- model.py | 2 +- requirements.txt | 1 + train.py | 35 +++++++++++++++++------------------ utils/ckpt_utils.py | 2 +- utils/data_utils.py | 8 ++++---- 7 files changed, 27 insertions(+), 25 deletions(-) diff --git a/.gitignore b/.gitignore index f393884..172c75b 100644 --- a/.gitignore +++ b/.gitignore @@ -8,3 +8,5 @@ tempo/ qualitative/ outputs/ *.ipynb +raw_data/ +.idea/ diff --git a/evaluation.py b/evaluation.py index 8ccbb5e..4cb685c 100644 --- a/evaluation.py +++ b/evaluation.py @@ -7,7 +7,7 @@ from utils.data_utils import prepare_dataset, MultiWozDataset from utils.data_utils import make_slot_meta, domain2id, OP_SET, make_turn_label, postprocessing from utils.eval_utils import compute_prf, compute_acc, per_domain_join_accuracy -from pytorch_transformers import BertTokenizer, BertConfig +from transformers import BertTokenizer, BertConfig from model import SomDST import torch.nn as nn diff --git a/model.py b/model.py index 7665c91..2c42097 100644 --- a/model.py +++ b/model.py @@ -6,7 +6,7 @@ import torch import torch.nn as nn -from pytorch_transformers.modeling_bert import BertPreTrainedModel, BertModel +from transformers import BertPreTrainedModel, BertModel class SomDST(BertPreTrainedModel): diff --git a/requirements.txt b/requirements.txt index 4c538d6..bad0946 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ pytorch-transformers==1.0.0 torch==1.3.0a0+24ae9b5 wget==3.2 +transformers \ No newline at end of file diff --git a/train.py b/train.py index fc0c021..ed2af2f 100644 --- a/train.py +++ b/train.py @@ -4,24 +4,23 @@ MIT license """ -from model import SomDST -from pytorch_transformers import BertTokenizer, AdamW, WarmupLinearSchedule, BertConfig -from utils.data_utils import prepare_dataset, MultiWozDataset -from utils.data_utils import make_slot_meta, domain2id, OP_SET, make_turn_label, postprocessing -from utils.eval_utils import compute_prf, compute_acc, per_domain_join_accuracy -from utils.ckpt_utils import download_ckpt, convert_ckpt_compatible -from evaluation import model_evaluation +import argparse +import json +import os +import random +import numpy as np import torch import torch.nn as nn -from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler -import numpy as np -import argparse -import random -import os -import json -import time +from torch.utils.data import DataLoader, RandomSampler +from transformers import BertTokenizer, AdamW, BertConfig +from transformers.optimization import get_linear_schedule_with_warmup +from evaluation import model_evaluation +from model import SomDST +from utils.ckpt_utils import download_ckpt +from utils.data_utils import make_slot_meta, domain2id, OP_SET +from utils.data_utils import prepare_dataset, MultiWozDataset device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') @@ -127,13 +126,13 @@ def worker_init_fn(worker_id): ] enc_optimizer = AdamW(enc_optimizer_grouped_parameters, lr=args.enc_lr) - enc_scheduler = WarmupLinearSchedule(enc_optimizer, int(num_train_steps * args.enc_warmup), - t_total=num_train_steps) + enc_scheduler = get_linear_schedule_with_warmup(enc_optimizer, int(num_train_steps * args.enc_warmup), + t_total=num_train_steps) dec_param_optimizer = list(model.decoder.parameters()) dec_optimizer = AdamW(dec_param_optimizer, lr=args.dec_lr) - dec_scheduler = 
WarmupLinearSchedule(dec_optimizer, int(num_train_steps * args.dec_warmup), - t_total=num_train_steps) + dec_scheduler = get_linear_schedule_with_warmup(dec_optimizer, int(num_train_steps * args.dec_warmup), + t_total=num_train_steps) if n_gpu > 1: model = torch.nn.DataParallel(model) diff --git a/utils/ckpt_utils.py b/utils/ckpt_utils.py index 6902d38..04ef767 100644 --- a/utils/ckpt_utils.py +++ b/utils/ckpt_utils.py @@ -1,7 +1,7 @@ import wget import os import torch -from pytorch_transformers import BertForPreTraining, BertConfig +from transformers import BertForPreTraining, BertConfig BERT_PRETRAINED_MODEL_ARCHIVE_MAP = { diff --git a/utils/data_utils.py b/utils/data_utils.py index b62806c..de94c7f 100644 --- a/utils/data_utils.py +++ b/utils/data_utils.py @@ -154,7 +154,7 @@ def prepare_dataset(data_path, tokenizer, slot_meta, domain_counter[domain] += 1 dialog_history = [] - last_dialog_state = {} + previous_dialog_state = {} # previous dialog_state last_uttr = "" for ti, turn in enumerate(dial_dict["dialogue"]): turn_domain = turn["domain"] @@ -166,7 +166,7 @@ def prepare_dataset(data_path, tokenizer, slot_meta, turn_dialog_state = fix_general_label_error(turn["belief_state"], False, slot_meta) last_uttr = turn_uttr - op_labels, generate_y, gold_state = make_turn_label(slot_meta, last_dialog_state, + op_labels, generate_y, gold_state = make_turn_label(slot_meta, previous_dialog_state, turn_dialog_state, tokenizer, op_code) if (ti + 1) == len(dial_dict["dialogue"]): @@ -176,12 +176,12 @@ def prepare_dataset(data_path, tokenizer, slot_meta, instance = TrainingInstance(dial_dict["dialogue_idx"], turn_domain, turn_id, turn_uttr, ' '.join(dialog_history[-n_history:]), - last_dialog_state, op_labels, + previous_dialog_state, op_labels, generate_y, gold_state, max_seq_length, slot_meta, is_last_turn, op_code=op_code) instance.make_instance(tokenizer) data.append(instance) - last_dialog_state = turn_dialog_state + previous_dialog_state = turn_dialog_state return data From 8d898ab05e69c96f1efcb13bbc02c46e06edb356 Mon Sep 17 00:00:00 2001 From: Jongwon Lee Date: Mon, 19 Apr 2021 20:21:04 +0900 Subject: [PATCH 2/4] doc: add comments --- utils/data_utils.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/utils/data_utils.py b/utils/data_utils.py index de94c7f..d32a71a 100644 --- a/utils/data_utils.py +++ b/utils/data_utils.py @@ -253,16 +253,16 @@ def make_instance(self, tokenizer, max_seq_length=None, t = tokenizer.tokenize(' '.join(k)) t.extend(['-', '[NULL]']) state.extend(t) - avail_length_1 = max_seq_length - len(state) - 3 + avail_length_1 = max_seq_length - len(state) - 3 # 1 * CLS + 2 * SEP diag_1 = tokenizer.tokenize(self.dialog_history) diag_2 = tokenizer.tokenize(self.turn_utter) avail_length = avail_length_1 - len(diag_2) - if len(diag_1) > avail_length: # truncated + if len(diag_1) > avail_length: # truncates diag_1 (dialog_history) avail_length = len(diag_1) - avail_length diag_1 = diag_1[avail_length:] - if len(diag_1) == 0 and len(diag_2) > avail_length_1: + if len(diag_1) == 0 and len(diag_2) > avail_length_1: # truncates diag_2 (turn_utter) avail_length = len(diag_2) - avail_length_1 diag_2 = diag_2[avail_length:] @@ -277,6 +277,7 @@ def make_instance(self, tokenizer, max_seq_length=None, drop_mask = np.array(drop_mask) word_drop = np.random.binomial(drop_mask.astype('int64'), word_dropout) diag = [w if word_drop[i] == 0 else '[UNK]' for i, w in enumerate(diag)] + input_ = diag + state segment = segment + [1]*len(state) self.input_ = input_ @@ 
-284,7 +285,7 @@ def make_instance(self, tokenizer, max_seq_length=None, self.segment_id = segment slot_position = [] for i, t in enumerate(self.input_): - if t == slot_token: + if t == slot_token: # delimiter "[SLOT]" slot_position.append(i) self.slot_position = slot_position @@ -292,6 +293,8 @@ def make_instance(self, tokenizer, max_seq_length=None, self.input_id = tokenizer.convert_tokens_to_ids(self.input_) if len(input_mask) < max_seq_length: self.input_id = self.input_id + [0] * (max_seq_length-len(input_mask)) + + # 000 (history), 11111 (turn_utter), 00000 (mask) self.segment_id = self.segment_id + [0] * (max_seq_length-len(input_mask)) input_mask = input_mask + [0] * (max_seq_length-len(input_mask)) From 81b904b6bfad8f1d9027e583e7292947394edbf8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=EC=9D=B4=EC=A2=85=EC=9B=90/AI=EA=B0=9C=EB=B0=9C=EA=B7=B8?= =?UTF-8?q?=EB=A3=B9=28=EB=AC=B4=EC=84=A0=29/Staff=20Engineer/=EC=82=BC?= =?UTF-8?q?=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Tue, 20 Apr 2021 21:09:44 +0900 Subject: [PATCH 3/4] update for strict loading --- model.py | 3 ++- requirements.txt | 5 ++--- train.py | 2 +- utils/ckpt_utils.py | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/model.py b/model.py index 2c42097..b880d08 100644 --- a/model.py +++ b/model.py @@ -15,7 +15,8 @@ def __init__(self, config, n_op, n_domain, update_id, exclude_domain=False): self.hidden_size = config.hidden_size self.encoder = Encoder(config, n_op, n_domain, update_id, exclude_domain) self.decoder = Decoder(config, self.encoder.bert.embeddings.word_embeddings.weight) - self.apply(self.init_weights) + # self.apply(self.init_weights) + self.init_weights() def forward(self, input_ids, token_type_ids, state_positions, attention_mask, diff --git a/requirements.txt b/requirements.txt index bad0946..1e470d9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,3 @@ -pytorch-transformers==1.0.0 -torch==1.3.0a0+24ae9b5 -wget==3.2 +torch +wget transformers \ No newline at end of file diff --git a/train.py b/train.py index ed2af2f..370e71c 100644 --- a/train.py +++ b/train.py @@ -108,7 +108,7 @@ def worker_init_fn(worker_id): args.bert_ckpt_path = download_ckpt(args.bert_ckpt_path, args.bert_config_path, 'assets') ckpt = torch.load(args.bert_ckpt_path, map_location='cpu') - model.encoder.bert.load_state_dict(ckpt) + model.encoder.bert.load_state_dict(ckpt, strict=False) # re-initialize added special tokens ([SLOT], [NULL], [EOS]) model.encoder.bert.embeddings.word_embeddings.weight.data[1].normal_(mean=0.0, std=0.02) diff --git a/utils/ckpt_utils.py b/utils/ckpt_utils.py index 04ef767..a09a9c9 100644 --- a/utils/ckpt_utils.py +++ b/utils/ckpt_utils.py @@ -40,7 +40,7 @@ def convert_ckpt_compatible(ckpt_path, config_path): model_config = BertConfig.from_json_file(config_path) model = BertForPreTraining(model_config) - model.load_state_dict(ckpt) + model.load_state_dict(ckpt, strict=False) new_ckpt = model.bert.state_dict() return new_ckpt From ab51a40df4ed5e4f393f805eda6655a019fa5020 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=EC=9D=B4=EC=A2=85=EC=9B=90/AI=EA=B0=9C=EB=B0=9C=EA=B7=B8?= =?UTF-8?q?=EB=A3=B9=28=EB=AC=B4=EC=84=A0=29/Staff=20Engineer/=EC=82=BC?= =?UTF-8?q?=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Tue, 20 Apr 2021 21:42:59 +0900 Subject: [PATCH 4/4] fix: args for fn: get_linear_schedule --- train.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/train.py b/train.py index 370e71c..ad80d30 100644 --- a/train.py +++ b/train.py @@ -126,13 +126,15 @@ def 
worker_init_fn(worker_id): ] enc_optimizer = AdamW(enc_optimizer_grouped_parameters, lr=args.enc_lr) - enc_scheduler = get_linear_schedule_with_warmup(enc_optimizer, int(num_train_steps * args.enc_warmup), - t_total=num_train_steps) + enc_scheduler = get_linear_schedule_with_warmup(enc_optimizer, + num_warmup_steps=int(num_train_steps * args.enc_warmup), + num_training_steps=num_train_steps) dec_param_optimizer = list(model.decoder.parameters()) dec_optimizer = AdamW(dec_param_optimizer, lr=args.dec_lr) - dec_scheduler = get_linear_schedule_with_warmup(dec_optimizer, int(num_train_steps * args.dec_warmup), - t_total=num_train_steps) + dec_scheduler = get_linear_schedule_with_warmup(dec_optimizer, + num_warmup_steps=int(num_train_steps * args.dec_warmup), + num_training_steps=num_train_steps) if n_gpu > 1: model = torch.nn.DataParallel(model)
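
Note on the scheduler change (patches 1 and 4): pytorch-transformers 1.0.0 exposed linear warmup as the WarmupLinearSchedule class taking warmup_steps and t_total, while transformers provides the same linear-warmup/linear-decay schedule through get_linear_schedule_with_warmup, whose keyword arguments are num_warmup_steps and num_training_steps. There is no t_total keyword, which is why the call left in patch 1 would raise a TypeError and patch 4 corrects it. A minimal sketch of the replacement call, with a placeholder model and step counts standing in for the values train.py computes (torch.optim.AdamW is used here so the sketch runs on any recent setup; the patched train.py imports AdamW from transformers):

    import torch
    from torch.optim import AdamW
    from transformers import get_linear_schedule_with_warmup

    # Placeholder model and hyper-parameters, purely for illustration.
    model = torch.nn.Linear(768, 768)
    optimizer = AdamW(model.parameters(), lr=4e-5)
    num_train_steps = 10000   # stands in for the value derived from the training data
    warmup = 0.1              # stands in for args.enc_warmup / args.dec_warmup

    # Old API (pytorch-transformers 1.0.0):
    #   scheduler = WarmupLinearSchedule(optimizer, int(num_train_steps * warmup),
    #                                    t_total=num_train_steps)
    # New API (transformers):
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(num_train_steps * warmup),
        num_training_steps=num_train_steps)

    # In the training loop: optimizer.step() followed by scheduler.step().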
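
Note on strict loading (patch 3): checkpoints saved in the pytorch-transformers era do not necessarily line up key-for-key with the current transformers BertModel (newer versions register extra buffers such as embeddings.position_ids, for example), so load_state_dict is switched to strict=False; the embedding rows used for the added special tokens ([SLOT], [NULL], [EOS]) are re-initialized separately, as the weight.data[1] line kept in patch 1's train.py context shows. A small self-contained sketch of that pattern, with a toy config so it runs quickly; index 1 appears in the patch context, indices 2 and 3 are assumed here for the remaining tokens:

    import torch
    from transformers import BertConfig, BertModel

    # Toy config purely for illustration; train.py loads its own bert_config.json.
    config = BertConfig(hidden_size=64, num_hidden_layers=2,
                        num_attention_heads=2, intermediate_size=128)
    bert = BertModel(config)

    # Stand-in for torch.load(args.bert_ckpt_path, map_location='cpu').
    ckpt = BertModel(config).state_dict()

    # strict=False tolerates missing/unexpected keys; inspect what was skipped.
    missing, unexpected = bert.load_state_dict(ckpt, strict=False)
    print('missing:', missing, 'unexpected:', unexpected)

    # Re-initialize the embedding rows repurposed for the added special tokens.
    for token_id in (1, 2, 3):
        bert.embeddings.word_embeddings.weight.data[token_id].normal_(mean=0.0, std=0.02)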