Skip to content

Commit

Permalink
Update model names and paths
Browse files Browse the repository at this point in the history
  • Loading branch information
noahzhy committed Dec 19, 2023
1 parent 3b1d842 commit c103905
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 33 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@ The pipeline is based on the paper "Towards efficient models for real-time deep
The `model/nsnet2` model structure is same as the original model.
The `model/nsnet2_ex` model is a modified version of the original model. Includes the preprocessing and postprocessing steps in the model, but excludes the FFT and IFFT processes.

In addition, the `model/tinySenet` model is a modified version of the original model. The model is implemented with a small number of parameters via tensorflow. Replace the original model GRU with a FastGRNN cell. The quantized tf-lite model get 0.067 ms inference time on Apple M2 chip.
In addition, the `model/tinyNSNet` model is a modified version of the original model. The model is implemented with a small number of parameters via tensorflow. Replace the original model GRU with a FastGRNN cell. The quantized tf-lite model get 0.067 ms inference time on Apple M2 chip.

## Inference Time

| Model | Platform | Inference Time |
| :---: | :---: | :---: |
| tinySenet | Apple M2 | 0.067 ms |
| tinyNSNet | Apple M2 | 0.067 ms |

## Attribution

Expand Down
2 changes: 1 addition & 1 deletion model/tflite_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@


# load model from tflite
interpreter = tf.lite.Interpreter(model_path='save/tinySenet.tflite')
interpreter = tf.lite.Interpreter(model_path='save/tinyNSNet.tflite')
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
Expand Down
16 changes: 8 additions & 8 deletions model/tinySenet.py → model/tinyNSNet.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

sys.path.append(os.path.join(os.path.dirname(__file__), '../'))
from model.rnn import *
from get_flops import try_count_flops
# from get_flops import try_count_flops


# using cpu only
Expand All @@ -17,9 +17,9 @@
hidden_size = 320


class TinySenet(models.Model):
def __init__(self, n_freq_bins=161, hidden_size = 320, rnn_dropout=0.2):
super(TinySenet, self).__init__()
class TinyNSNet(models.Model):
def __init__(self, n_freq_bins=161, hidden_size=320, rnn_dropout=0.2):
super(TinyNSNet, self).__init__()
self.n_freq_bins = n_freq_bins
self.dense0 = Dense(hidden_size, activation='relu')
self.rnn = models.Sequential([], name='rnn')
Expand Down Expand Up @@ -60,16 +60,16 @@ def build(self, input_shape):
x = self.dense2(x)
x = self.dense3(x)
x = self.limitGain(x, inputs)
return models.Model(inputs=inputs, outputs=x, name='tinySenet')
return models.Model(inputs=inputs, outputs=x, name='tinyNSNet')

def call(self, x):
return self.build(x.shape)


# main
if __name__ == '__main__':
model = TinySenet().build((161, 1))
model.save('save/tinySenet.h5')
model = TinyNSNet().build((161, 1))
model.save('save/tinyNSNet.h5')
model.summary()

# count flops
Expand All @@ -90,5 +90,5 @@ def call(self, x):
]

tflite_model = converter.convert()
with open('save/tinySenet.tflite', 'wb') as f:
with open('save/tinyNSNet.tflite', 'wb') as f:
f.write(tflite_model)
57 changes: 35 additions & 22 deletions train.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,43 @@
import os
import torch
import os, sys, time, glob, random

from model.nsnet2 import NSNet2
import tensorflow as tf
from tensorflow.keras import optimizers
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, TensorBoard
from tensorflow.keras.optimizers import *
from tensorflow.keras.losses import *


train_dir = './WAVs/dataset/training'
val_dir = './WAVs/dataset/validation'
from model.tinyNSNet import TinyNSNet

train_cfg = {
'train_dir': train_dir,
'val_dir': val_dir,
'batch_size': 64,
'alpha': 0.35,
}

model_cfg = {
'n_fft': 320,
'hop_len': 160,
'win_len': 320,
}

model = NSNet2(model_cfg)
# adamw
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
# MSE
criterion = torch.nn.MSELoss()
# main
if __name__ == '__main__':
train_dir = './WAVs/dataset/training'
val_dir = './WAVs/dataset/validation'

# train
train_cfg = {
'train_dir': train_dir,
'val_dir': val_dir,
'batch_size': 64,
'alpha': 0.35,
}

model_cfg = {
'n_fft': 320,
'hop_len': 160,
'win_len': 320,
}

model = TinyNSNet().build(input_shape=(161, 1))
# summary
model.summary()
quit()

# adamw
optimizer = Nadam(learning_rate=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
# MSE
criterion = MSELoss()

# train
train(model, optimizer, criterion, train_cfg, model_cfg)
30 changes: 30 additions & 0 deletions train_torch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import os
import torch

from model.nsnet2 import NSNet2


train_dir = './WAVs/dataset/training'
val_dir = './WAVs/dataset/validation'

train_cfg = {
'train_dir': train_dir,
'val_dir': val_dir,
'batch_size': 64,
'alpha': 0.35,
}

model_cfg = {
'n_fft': 320,
'hop_len': 160,
'win_len': 320,
}

model = NSNet2(model_cfg)
# adamw
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
# MSE
criterion = torch.nn.MSELoss()

# train

0 comments on commit c103905

Please sign in to comment.