
Commit 89b7974

Major refactoring.
cschaefer26 committed Apr 29, 2021
1 parent d002e00 commit 89b7974
Showing 26 changed files with 100 additions and 115 deletions.
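
At a glance: the commit splits the flat dp package into subpackages (dp.model, dp.preprocessing, dp.training) and turns the top-level preprocess.py and train.py scripts into importable functions. A minimal before/after of the import paths, reconstructed from the renames shown below:

# Old flat layout (before this commit):
#   from dp.model import Model, load_checkpoint
#   from dp.text import Preprocessor
#   from dp.trainer import Trainer

# New subpackage layout (after this commit):
from dp.model.model import Model, load_checkpoint
from dp.preprocessing.text import Preprocessor
from dp.training.trainer import Trainer
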
2 changes: 2 additions & 0 deletions .gitignore
@@ -127,3 +127,5 @@ dmypy.json

# Pyre type checker
.pyre/
+/dp/checkpoints/
+/dp/datasets/
2 changes: 0 additions & 2 deletions config.yaml → dp/configs/config.yaml
@@ -1,7 +1,5 @@

paths:
-  train_file: /Users/cschaefe/datasets/nlp/de_us_phonemes_train.pkl # path train file (.pkl)
-  val_file: /Users/cschaefe/datasets/nlp/de_us_phonemes_val.pkl # optional, path to val file (.pkl)
  checkpoint_dir: checkpoints # directory to store model checkpoints and tensorboard
  data_dir: datasets # directory to store processed data

Empty file added dp/model/__init__.py
Empty file.
6 changes: 3 additions & 3 deletions dp/model.py → dp/model/model.py
@@ -4,10 +4,10 @@
import torch
import torch.nn as nn
from torch.nn import TransformerEncoderLayer, LayerNorm, TransformerEncoder, ModuleList
-from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, pad_sequence
+from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

-from dp.model_utils import get_dedup_tokens, make_len_mask, generate_square_subsequent_mask, PositionalEncoding
-from dp.text import Preprocessor
+from dp.model.utils import get_dedup_tokens, make_len_mask, generate_square_subsequent_mask, PositionalEncoding
+from dp.preprocessing.text import Preprocessor


class Model(torch.nn.Module, ABC):
13 changes: 8 additions & 5 deletions dp/predictor.py → dp/model/predictor.py
@@ -1,14 +1,16 @@
-from typing import Dict, Any, List, Tuple, Iterable
+from typing import Dict, List, Tuple

import torch
from torch.nn.utils.rnn import pad_sequence

-from dp.model import load_checkpoint
-from dp.model_utils import get_len_util_stop
-from dp.text import Preprocessor
+from dp.model.model import load_checkpoint
+from dp.model.utils import get_len_util_stop
+from dp.preprocessing.text import Preprocessor
from dp.utils import batchify, get_sequence_prob




class Prediction:

def __init__(self,
@@ -124,4 +126,5 @@ def _predict_batch(self,
    def from_checkpoint(cls, checkpoint_path: str, device='cpu') -> 'Predictor':
        model, checkpoint = load_checkpoint(checkpoint_path, device=device)
        preprocessor = checkpoint['preprocessor']
-        return Predictor(model=model, preprocessor=preprocessor)
+        return Predictor(model=model, preprocessor=preprocessor)
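
The from_checkpoint classmethod above is the new entry point for wiring a restored model to the preprocessor stored inside its checkpoint. A minimal usage sketch under the new module path (the checkpoint filename is illustrative, not from this commit):

from dp.model.predictor import Predictor

# Hypothetical path; any checkpoint written by the trainer should work here.
predictor = Predictor.from_checkpoint('checkpoints/latest_model.pt', device='cpu')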

File renamed without changes.
9 changes: 4 additions & 5 deletions dp/phonemizer.py
@@ -1,11 +1,10 @@
import re
from itertools import zip_longest
-from typing import Dict, Union, Tuple, List, Set
+from typing import Dict, Union, List, Set

-from dp.model import load_checkpoint
-from dp.predictor import Predictor, Prediction
-from dp.text import Preprocessor
-from dp.utils import get_sequence_prob
+from dp.model.model import load_checkpoint
+from dp.model.predictor import Predictor, Prediction
+from dp.preprocessing.text import Preprocessor

DEFAULT_PUNCTUATION = '().,:?!/'

32 changes: 13 additions & 19 deletions preprocess.py → dp/preprocess.py
@@ -1,34 +1,28 @@
from collections import Counter
-from pathlib import Path
-import argparse
+from pathlib import Path
from random import Random
+from typing import List, Tuple, Iterable

import tqdm

-from dp.text import Preprocessor
-from dp.utils import read_config, pickle_binary, unpickle_binary
+from dp.preprocessing.text import Preprocessor
+from dp.utils import read_config, pickle_binary


-if __name__ == '__main__':
-    parser = argparse.ArgumentParser(description='Preprocessing for DeepPhonemizer')
-    parser.add_argument('--config', '-c', default='config.yaml', help='Points to the config file.')
-    args = parser.parse_args()
-    config = read_config(args.config)
+def preprocess(config_file: str,
+               train_data: List[Tuple[str, Iterable[str], Iterable[str]]],
+               val_data: List[Tuple[str, Iterable[str], Iterable[str]]] = None
+               ) -> None:
+
+    config = read_config(config_file)
    languages = set(config['preprocessing']['languages'])

-    train_file = config['paths']['train_file']
-    val_file = config['paths']['val_file']
+    print(f'Preprocessing, train data: with {len(train_data)} files.')

    data_dir = Path(config['paths']['data_dir'])
    data_dir.mkdir(parents=True, exist_ok=True)

-    print(f'Reading train data from {train_file}')
-    train_data = unpickle_binary(train_file)
    train_data = [r for r in train_data if r[0] in languages]

-    if val_file is not None:
-        print(f'Reading val data from {val_file}')
-        val_data = unpickle_binary(val_file)
+    if val_data is not None:
        val_data = [r for r in val_data if r[0] in languages]
    else:
        n_val = config['preprocessing']['n_val']
@@ -44,7 +38,7 @@
    train_count = Counter()
    val_count = Counter()

-    print('Processing data...')
+    print('Processing train data...')
    train_dataset = []
    for i, (lang, text, phonemes) in enumerate(tqdm.tqdm(train_data, total=len(train_data))):
        tokens = preprocessor((lang, text, phonemes))
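
After this change, preprocess no longer unpickles files itself; callers pass the data in directly. Judging from the new signature and the loop above, each entry is a (language, text, phonemes) tuple, filtered against config['preprocessing']['languages']. A minimal sketch with invented sample rows:

from dp.preprocess import preprocess

# Sample rows are made up for illustration; the language tags must be listed
# under preprocessing.languages in the config, or the rows are filtered out.
train_data = [
    ('en_us', 'young', 'jʌŋ'),
    ('de', 'gruen', 'gʁyːn'),
]

preprocess(config_file='dp/configs/config.yaml', train_data=train_data)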
Empty file added dp/preprocessing/__init__.py
Empty file.
File renamed without changes.
35 changes: 10 additions & 25 deletions train.py → dp/train.py
@@ -1,31 +1,16 @@
-import pickle
-import argparse
-import random
-from collections import Counter
-from random import Random
-from typing import List, Tuple, Dict, Iterable, Any
+from dp.model.model import LstmModel, ForwardTransformer, AutoregressiveTransformer, load_checkpoint
+from dp.preprocessing.text import Preprocessor
+from dp.training.trainer import Trainer
+from dp.utils import read_config

-import tqdm
-import torch
-from dp.dataset import new_dataloader
-from dp.model import LstmModel, ForwardTransformer, AutoregressiveTransformer, load_checkpoint
-from dp.text import SequenceTokenizer, Preprocessor
-from dp.trainer import Trainer
-from dp.utils import read_config, pickle_binary

+def train(config_file: str,
+          checkpoint_file: str = None) -> None:

-if __name__ == '__main__':
-    parser = argparse.ArgumentParser(description='Preprocessing for DeepPhonemizer.')
-    parser.add_argument('--config', '-c', default='config.yaml', help='Points to the config file.')
-    parser.add_argument('--checkpoint', '-cp', default=None, help='Points to the a model file to restore.')
-    parser.add_argument('--path', '-p', help='Points to the a file with data.')
-    args = parser.parse_args()
-
-    config = read_config(args.config)
-
-    if args.checkpoint:
-        print(f'Restoring model from checkpoint: {args.checkpoint}')
-        model, checkpoint = load_checkpoint(args.checkpoint)
+    config = read_config(config_file)
+    if checkpoint_file is not None:
+        print(f'Restoring model from checkpoint: {checkpoint_file}')
+        model, checkpoint = load_checkpoint(checkpoint_file)
        model.train()
        step = checkpoint['step']
        print(f'Loaded model with step: {step}')
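
The old --checkpoint CLI flag survives as the optional checkpoint_file argument of train, so a fresh run and a resumed run look like this (paths illustrative):

from dp.train import train

# Fresh training run:
train(config_file='dp/configs/config.yaml')

# Resume from an existing checkpoint (hypothetical path):
train(config_file='dp/configs/config.yaml', checkpoint_file='checkpoints/latest_model.pt')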
Empty file added dp/training/__init__.py
Empty file.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
15 changes: 7 additions & 8 deletions dp/trainer.py → dp/training/trainer.py
@@ -8,14 +8,13 @@
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.tensorboard import SummaryWriter

-from dp.dataset import new_dataloader
-from dp.decorators import ignore_exception
-from dp.losses import CrossEntropyLoss, CTCLoss
-from dp.metrics import phoneme_error_rate, word_error
-from dp.model import Model
-from dp.model_utils import get_len_util_stop, trim_util_stop
-from dp.predictor import Predictor
-from dp.text import Preprocessor
+from dp.training.dataset import new_dataloader
+from dp.training.decorators import ignore_exception
+from dp.training.losses import CrossEntropyLoss, CTCLoss
+from dp.training.metrics import phoneme_error_rate, word_error
+from dp.model.model import Model
+from dp.model.utils import trim_util_stop
+from dp.preprocessing.text import Preprocessor
from dp.utils import to_device, unpickle_binary


34 changes: 0 additions & 34 deletions predict.py

This file was deleted.

23 changes: 23 additions & 0 deletions run_prediction.py
@@ -0,0 +1,23 @@
+import math
+
+from dp.phonemizer import Phonemizer
+
+if __name__ == '__main__':
+
+    checkpoint_path = 'checkpoints/best_model_no_optim.pt'
+    phonemizer = Phonemizer.from_checkpoint(checkpoint_path)
+
+    text = ['özdemir']
+
+    result = phonemizer.phonemise_list(text, lang='de')
+
+    for text, pred in result.predictions.items():
+        tokens, probs = pred.tokens, pred.token_probs
+        pred_decoded = phonemizer.predictor.phoneme_tokenizer.decode(
+            tokens, remove_special_tokens=False)
+        prob = math.exp(sum([math.log(p) for p in probs]))
+        for o, p in zip(pred_decoded, probs):
+            print(f'{o} {p}')
+        pred_decoded = ''.join(pred_decoded)
+        print(f'{text} {pred_decoded} | {prob}')
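
One detail worth calling out in the script above: the sequence probability is accumulated as a sum of logs and exponentiated once, rather than multiplying the raw token probabilities directly. A tiny self-contained check of the equivalence, with invented toy values:

import math

# Toy per-token probabilities, chosen only for illustration.
probs = [0.9, 0.8, 0.95]

via_logs = math.exp(sum(math.log(p) for p in probs))
direct = 0.9 * 0.8 * 0.95

# Same product; the log-space form avoids underflow for long sequences.
assert abs(via_logs - direct) < 1e-12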

11 changes: 11 additions & 0 deletions run_preprocessing.py
@@ -0,0 +1,11 @@
+from dp.preprocess import preprocess
+from dp.utils import unpickle_binary
+
+if __name__ == '__main__':
+
+    config_file = 'dp/configs/config.yaml'
+    train_data = unpickle_binary('/Users/cschaefe/datasets/nlp/de_us_phonemes_train.pkl')
+    val_data = unpickle_binary('/Users/cschaefe/datasets/nlp/de_us_phonemes_val.pkl')
+
+    preprocess(config_file=config_file, train_data=train_data, val_data=val_data)

7 changes: 7 additions & 0 deletions run_training.py
@@ -0,0 +1,7 @@
+from dp.train import train
+
+if __name__ == '__main__':
+
+    config_file = 'dp/configs/config.yaml'
+    train(config_file=config_file)

2 changes: 1 addition & 1 deletion tests/test_language_tokenizer.py
@@ -1,6 +1,6 @@
import unittest

-from dp.text import LanguageTokenizer
+from dp.preprocessing.text import LanguageTokenizer


class TestSequenceTokenizer(unittest.TestCase):
2 changes: 1 addition & 1 deletion tests/test_metrics.py
@@ -1,6 +1,6 @@
import unittest

-from dp.metrics import word_error, phoneme_error_rate
+from dp.training.metrics import word_error, phoneme_error_rate


class TestWordError(unittest.TestCase):
12 changes: 5 additions & 7 deletions tests/test_phonemizer.py
@@ -1,12 +1,10 @@
import unittest
-from typing import Dict, Any, Tuple, List
-from unittest.mock import Mock, _patch_object, patch
+from typing import List
+from unittest.mock import patch

import torch

-from dp.phonemizer import PhonemizerResult, Phonemizer
-from dp.predictor import Predictor, Prediction
-from dp.text import Preprocessor
+from dp.phonemizer import Phonemizer
+from dp.model.predictor import Predictor, Prediction
+from dp.preprocessing.text import Preprocessor


class PredictorMock:
8 changes: 4 additions & 4 deletions tests/test_predictor.py
@@ -1,12 +1,12 @@
import unittest
from typing import Dict, Any, Tuple
-from unittest.mock import Mock, patch
+from unittest.mock import patch

import torch

-from dp.model import Model
-from dp.predictor import Predictor
-from dp.text import Preprocessor
+from dp.model.model import Model
+from dp.model.predictor import Predictor
+from dp.preprocessing.text import Preprocessor


class ModelMock:
2 changes: 1 addition & 1 deletion tests/test_sequence_tokenizer.py
@@ -1,6 +1,6 @@
import unittest

-from dp.text import SequenceTokenizer
+from dp.preprocessing.text import SequenceTokenizer


class TestSequenceTokenizer(unittest.TestCase):
