init #5

Open · wants to merge 4 commits into base: master
2 changes: 2 additions & 0 deletions .gitignore
@@ -8,3 +8,5 @@ tempo/
qualitative/
outputs/
*.ipynb
raw_data/
.idea/
2 changes: 1 addition & 1 deletion evaluation.py
@@ -7,7 +7,7 @@
from utils.data_utils import prepare_dataset, MultiWozDataset
from utils.data_utils import make_slot_meta, domain2id, OP_SET, make_turn_label, postprocessing
from utils.eval_utils import compute_prf, compute_acc, per_domain_join_accuracy
from pytorch_transformers import BertTokenizer, BertConfig
from transformers import BertTokenizer, BertConfig

from model import SomDST
import torch.nn as nn
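Note: this hunk (and the matching ones in model.py, train.py, and utils/ckpt_utils.py) swaps the retired pytorch_transformers package for transformers. For these classes the rename is essentially a drop-in change; a minimal sketch, assuming a transformers install where the classes keep the same from_pretrained interface (the model name below is illustrative):

```python
# Old package (pytorch-transformers 1.0.0, no longer maintained):
#   from pytorch_transformers import BertTokenizer, BertConfig
# New package; the classes keep the same from_pretrained interface:
from transformers import BertTokenizer, BertConfig

config = BertConfig.from_pretrained("bert-base-uncased")       # model name is illustrative
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
print(config.hidden_size, tokenizer.tokenize("hello world"))
```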
5 changes: 3 additions & 2 deletions model.py
@@ -6,7 +6,7 @@

import torch
import torch.nn as nn
from pytorch_transformers.modeling_bert import BertPreTrainedModel, BertModel
from transformers import BertPreTrainedModel, BertModel


class SomDST(BertPreTrainedModel):
@@ -15,7 +15,8 @@ def __init__(self, config, n_op, n_domain, update_id, exclude_domain=False):
self.hidden_size = config.hidden_size
self.encoder = Encoder(config, n_op, n_domain, update_id, exclude_domain)
self.decoder = Decoder(config, self.encoder.bert.embeddings.word_embeddings.weight)
self.apply(self.init_weights)
# self.apply(self.init_weights)
self.init_weights()

def forward(self, input_ids, token_type_ids,
state_positions, attention_mask,
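Note: in pytorch_transformers the init helper had to be applied module by module (`self.apply(self.init_weights)`); in transformers, `PreTrainedModel.init_weights()` takes no arguments and walks the submodules itself, so the direct call here is the expected replacement (recent releases also expose `post_init()` for the same purpose). A minimal sketch of the new-style constructor, with the SomDST-specific arguments omitted and an illustrative extra head:

```python
import torch.nn as nn
from transformers import BertModel, BertPreTrainedModel

class SomDSTSketch(BertPreTrainedModel):
    """Abbreviated constructor: register submodules first, then the transformers-style init call."""
    def __init__(self, config):
        super().__init__(config)
        self.encoder = BertModel(config)                      # registers the BERT weights
        self.classifier = nn.Linear(config.hidden_size, 2)    # illustrative extra head
        self.init_weights()                                   # initializes new weights, ties embeddings

# usage sketch: model = SomDSTSketch.from_pretrained("bert-base-uncased")
```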
6 changes: 3 additions & 3 deletions requirements.txt
@@ -1,3 +1,3 @@
pytorch-transformers==1.0.0
torch==1.3.0a0+24ae9b5
wget==3.2
torch
wget
transformers
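Note: unpinning every dependency keeps the file simple but makes the environment hard to reproduce, and `transformers.AdamW` (used in train.py below) has since been deprecated in favor of `torch.optim.AdamW` in newer releases. If reproducibility matters, a pinned variant could look like the sketch below; the versions are illustrative guesses, not part of this PR:

```
torch==1.7.1
transformers==4.5.1
wget==3.2
```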
39 changes: 20 additions & 19 deletions train.py
@@ -4,24 +4,23 @@
MIT license
"""

from model import SomDST
from pytorch_transformers import BertTokenizer, AdamW, WarmupLinearSchedule, BertConfig
from utils.data_utils import prepare_dataset, MultiWozDataset
from utils.data_utils import make_slot_meta, domain2id, OP_SET, make_turn_label, postprocessing
from utils.eval_utils import compute_prf, compute_acc, per_domain_join_accuracy
from utils.ckpt_utils import download_ckpt, convert_ckpt_compatible
from evaluation import model_evaluation
import argparse
import json
import os
import random

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
import numpy as np
import argparse
import random
import os
import json
import time
from torch.utils.data import DataLoader, RandomSampler
from transformers import BertTokenizer, AdamW, BertConfig
from transformers.optimization import get_linear_schedule_with_warmup

from evaluation import model_evaluation
from model import SomDST
from utils.ckpt_utils import download_ckpt
from utils.data_utils import make_slot_meta, domain2id, OP_SET
from utils.data_utils import prepare_dataset, MultiWozDataset

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

@@ -109,7 +108,7 @@ def worker_init_fn(worker_id):
args.bert_ckpt_path = download_ckpt(args.bert_ckpt_path, args.bert_config_path, 'assets')

ckpt = torch.load(args.bert_ckpt_path, map_location='cpu')
model.encoder.bert.load_state_dict(ckpt)
model.encoder.bert.load_state_dict(ckpt, strict=False)

# re-initialize added special tokens ([SLOT], [NULL], [EOS])
model.encoder.bert.embeddings.word_embeddings.weight.data[1].normal_(mean=0.0, std=0.02)
@@ -127,13 +126,15 @@ def worker_init_fn(worker_id):
]

enc_optimizer = AdamW(enc_optimizer_grouped_parameters, lr=args.enc_lr)
enc_scheduler = WarmupLinearSchedule(enc_optimizer, int(num_train_steps * args.enc_warmup),
t_total=num_train_steps)
enc_scheduler = get_linear_schedule_with_warmup(enc_optimizer,
num_warmup_steps=int(num_train_steps * args.enc_warmup),
num_training_steps=num_train_steps)

dec_param_optimizer = list(model.decoder.parameters())
dec_optimizer = AdamW(dec_param_optimizer, lr=args.dec_lr)
dec_scheduler = WarmupLinearSchedule(dec_optimizer, int(num_train_steps * args.dec_warmup),
t_total=num_train_steps)
dec_scheduler = get_linear_schedule_with_warmup(dec_optimizer,
num_warmup_steps=int(num_train_steps * args.dec_warmup),
num_training_steps=num_train_steps)

if n_gpu > 1:
model = torch.nn.DataParallel(model)
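Note: the scheduler swap is the only non-mechanical change in this file: `WarmupLinearSchedule(optimizer, warmup_steps, t_total=...)` becomes `get_linear_schedule_with_warmup(optimizer, num_warmup_steps=..., num_training_steps=...)`, with the same linear-warmup then linear-decay shape. A self-contained sketch of the equivalence with throwaway values (the real code derives them from `args`):

```python
import torch
from transformers import AdamW, get_linear_schedule_with_warmup

param = torch.nn.Parameter(torch.zeros(1))        # throwaway parameter, just to build an optimizer
optimizer = AdamW([param], lr=1e-4)               # learning rate is illustrative

num_train_steps = 1000                            # illustrative values, not taken from the PR
warmup = 0.1

# pytorch_transformers (old API):
#   scheduler = WarmupLinearSchedule(optimizer, int(num_train_steps * warmup),
#                                    t_total=num_train_steps)
# transformers (new API), as used above:
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(num_train_steps * warmup),
    num_training_steps=num_train_steps)

for _ in range(num_train_steps):
    optimizer.step()
    scheduler.step()                              # one scheduler step per optimizer step
```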
4 changes: 2 additions & 2 deletions utils/ckpt_utils.py
@@ -1,7 +1,7 @@
import wget
import os
import torch
from pytorch_transformers import BertForPreTraining, BertConfig
from transformers import BertForPreTraining, BertConfig


BERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
@@ -40,7 +40,7 @@ def convert_ckpt_compatible(ckpt_path, config_path):

model_config = BertConfig.from_json_file(config_path)
model = BertForPreTraining(model_config)
model.load_state_dict(ckpt)
model.load_state_dict(ckpt, strict=False)
new_ckpt = model.bert.state_dict()

return new_ckpt
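Note: `strict=False` keeps `load_state_dict` from raising on key-name mismatches between the old checkpoint and the transformers-era module, but it also skips those keys silently. A hedged sketch (function name is illustrative, not part of this PR) of how the skipped keys could be surfaced so a bad checkpoint is noticed:

```python
import torch
from transformers import BertConfig, BertForPreTraining

def load_ckpt_verbose(ckpt_path, config_path):
    """Load a raw BERT checkpoint non-strictly and report anything that was skipped."""
    ckpt = torch.load(ckpt_path, map_location='cpu')
    model = BertForPreTraining(BertConfig.from_json_file(config_path))

    # load_state_dict returns a named tuple of (missing_keys, unexpected_keys)
    result = model.load_state_dict(ckpt, strict=False)
    if result.missing_keys:
        print('missing keys (left randomly initialized):', result.missing_keys)
    if result.unexpected_keys:
        print('unexpected keys (ignored):', result.unexpected_keys)
    return model.bert.state_dict()
```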
19 changes: 11 additions & 8 deletions utils/data_utils.py
@@ -154,7 +154,7 @@ def prepare_dataset(data_path, tokenizer, slot_meta,
domain_counter[domain] += 1

dialog_history = []
last_dialog_state = {}
previous_dialog_state = {}  # dialog state carried over from the previous turn
last_uttr = ""
for ti, turn in enumerate(dial_dict["dialogue"]):
turn_domain = turn["domain"]
@@ -166,7 +166,7 @@
turn_dialog_state = fix_general_label_error(turn["belief_state"], False, slot_meta)
last_uttr = turn_uttr

op_labels, generate_y, gold_state = make_turn_label(slot_meta, last_dialog_state,
op_labels, generate_y, gold_state = make_turn_label(slot_meta, previous_dialog_state,
turn_dialog_state,
tokenizer, op_code)
if (ti + 1) == len(dial_dict["dialogue"]):
@@ -176,12 +176,12 @@

instance = TrainingInstance(dial_dict["dialogue_idx"], turn_domain,
turn_id, turn_uttr, ' '.join(dialog_history[-n_history:]),
last_dialog_state, op_labels,
previous_dialog_state, op_labels,
generate_y, gold_state, max_seq_length, slot_meta,
is_last_turn, op_code=op_code)
instance.make_instance(tokenizer)
data.append(instance)
last_dialog_state = turn_dialog_state
previous_dialog_state = turn_dialog_state
return data


@@ -253,16 +253,16 @@ def make_instance(self, tokenizer, max_seq_length=None,
t = tokenizer.tokenize(' '.join(k))
t.extend(['-', '[NULL]'])
state.extend(t)
avail_length_1 = max_seq_length - len(state) - 3
avail_length_1 = max_seq_length - len(state) - 3 # 1 * CLS + 2 * SEP
diag_1 = tokenizer.tokenize(self.dialog_history)
diag_2 = tokenizer.tokenize(self.turn_utter)
avail_length = avail_length_1 - len(diag_2)

if len(diag_1) > avail_length: # truncated
if len(diag_1) > avail_length: # truncates diag_1 (dialog_history)
avail_length = len(diag_1) - avail_length
diag_1 = diag_1[avail_length:]

if len(diag_1) == 0 and len(diag_2) > avail_length_1:
if len(diag_1) == 0 and len(diag_2) > avail_length_1: # truncates diag_2 (turn_utter)
avail_length = len(diag_2) - avail_length_1
diag_2 = diag_2[avail_length:]

@@ -277,21 +277,24 @@
drop_mask = np.array(drop_mask)
word_drop = np.random.binomial(drop_mask.astype('int64'), word_dropout)
diag = [w if word_drop[i] == 0 else '[UNK]' for i, w in enumerate(diag)]

input_ = diag + state
segment = segment + [1]*len(state)
self.input_ = input_

self.segment_id = segment
slot_position = []
for i, t in enumerate(self.input_):
if t == slot_token:
if t == slot_token: # delimiter "[SLOT]"
slot_position.append(i)
self.slot_position = slot_position

input_mask = [1] * len(self.input_)
self.input_id = tokenizer.convert_tokens_to_ids(self.input_)
if len(input_mask) < max_seq_length:
self.input_id = self.input_id + [0] * (max_seq_length-len(input_mask))

# segment ids: 0 over [CLS] + history + [SEP], 1 over turn_utter + [SEP] + state, 0 over padding
self.segment_id = self.segment_id + [0] * (max_seq_length-len(input_mask))
input_mask = input_mask + [0] * (max_seq_length-len(input_mask))

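Note: for reference, the sequence built here is `[CLS] dialog_history [SEP] turn_utter [SEP]` followed by the flattened state (one `[SLOT]`-prefixed slot-value group per slot, with `[NULL]` marking empty values), truncating the history first and then the current utterance when the `max_seq_length - 3` budget is exceeded. A toy sketch of that layout with made-up tokens (the real code uses BERT wordpieces and the full slot set):

```python
# Toy illustration of the assembled encoder input (token strings, not real ids).
history = ['i', 'need', 'a', 'taxi']                              # previous turns
turn    = ['to', 'the', 'north', 'please']                        # current turn
state   = ['[SLOT]', 'taxi', 'destination', '-', '[NULL]']        # one group per slot

diag    = ['[CLS]'] + history + ['[SEP]'] + turn + ['[SEP]']
segment = [0] * (len(history) + 2) + [1] * (len(turn) + 1)        # 0 = history side, 1 = turn side

input_  = diag + state
segment = segment + [1] * len(state)                              # state shares segment id 1

slot_position = [i for i, t in enumerate(input_) if t == '[SLOT]']
print(slot_position)                                              # positions of the [SLOT] tokens

max_seq_length = 32                                               # illustrative
pad = max_seq_length - len(input_)
segment = segment + [0] * pad                                     # padding gets segment id 0
input_mask = [1] * len(input_) + [0] * pad                        # 1 = real token, 0 = padding
```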