-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
66 lines (52 loc) · 1.89 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from speech_to_text import get_audio
import os
import torch
import random
import evaluate
import transformers
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional
from dataclasses import dataclass
from time import perf_counter
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset, disable_progress_bar
from transformers import (
AutoConfig,
AutoTokenizer,
AutoModelForSeq2SeqLM,
DataCollatorForSeq2Seq,
Seq2SeqTrainingArguments,
Seq2SeqTrainer,
EarlyStoppingCallback
)
# Loading finetuned model
class finetuned_model:
    """Load a fine-tuned seq2seq checkpoint together with its tokenizer.

    The default checkpoint path suggests an mT5-small model fine-tuned for
    English -> Spanish translation (inferred from the directory name only).

    Attributes:
        fine_tuned_model_checkpoint: path the tokenizer and model were loaded from.
        tokenizer: tokenizer returned by ``AutoTokenizer.from_pretrained``.
        model: model returned by ``AutoModelForSeq2SeqLM.from_pretrained``.
    """

    # Previously hard-coded inside __init__; kept as the default so existing
    # ``finetuned_model()`` calls load exactly the same checkpoint.
    DEFAULT_CHECKPOINT = './training/data/english_to_spanish/mt5-small_en-sp/checkpoint-4500'

    def __init__(self, checkpoint=DEFAULT_CHECKPOINT):
        """Load tokenizer and model from ``checkpoint``.

        Args:
            checkpoint: local directory (or hub id) accepted by
                ``from_pretrained``; defaults to ``DEFAULT_CHECKPOINT``.
        """
        self.fine_tuned_model_checkpoint = checkpoint
        self.tokenizer = AutoTokenizer.from_pretrained(self.fine_tuned_model_checkpoint)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(self.fine_tuned_model_checkpoint)
        print('Model loaded')
# Defining the generator for translations
def generate_translation(model, tokenizer, example, max_length=20):
    """Translate ``example`` with ``model``, print source and prediction, and return it.

    Args:
        model: seq2seq model exposing ``.device`` and ``.generate(...)``.
        tokenizer: tokenizer that is callable with ``return_tensors="pt"`` and
            exposes ``.decode(ids, skip_special_tokens=...)``.
        example: source text to translate. ``None`` or blank text is skipped.
        max_length: maximum length of the generated token sequence, forwarded
            to ``model.generate``. The default (20) matches the old hard-coded value.

    Returns:
        The decoded prediction string, or ``''`` when there was nothing to translate.
    """
    source = example
    # Speech transcription may yield nothing; tokenizing None would raise and
    # tokenizing blank text only makes the model hallucinate output.
    if source is None or (isinstance(source, str) and not source.strip()):
        print('No input text to translate')
        return ''

    # Encode as tensors so the attention mask is produced and forwarded
    # together with the input ids.
    encoded = tokenizer(source, return_tensors="pt")
    encoded = {name: tensor.to(model.device) for name, tensor in encoded.items()}

    generated_ids = model.generate(**encoded, max_length=max_length)
    # Only one sequence was encoded, so only the first output row is decoded.
    prediction = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

    print('English source: ', source)
    print('Spanish prediction: ', prediction)
    return prediction
def _ask_yes_no(prompt):
    """Prompt the user and return True only for a 'y'/'yes' style answer.

    Matching ignores case and surrounding whitespace. End of input (EOF) is
    treated as "no", so piping or Ctrl-D ends the loop cleanly instead of
    raising.
    """
    try:
        answer = input(prompt)
    except EOFError:
        return False
    return answer.strip().lower().startswith('y')


def main():
    """Run the interactive speech -> Spanish translation loop."""
    # The checkpoint is mT5 (see its path), not Flan; name the local neutrally.
    translator = finetuned_model()
    # The "ready" answer used to be read and ignored; now "n" actually stops.
    while _ask_yes_no('Are you ready to start? (Press y/n): '):
        # NOTE(review): assumes get_audio returns the transcribed text (or a
        # falsy value when nothing was recognized) — confirm in speech_to_text.
        text_audio = get_audio('test')
        if text_audio:
            generate_translation(translator.model, translator.tokenizer, text_audio)
        else:
            print('No speech was transcribed; nothing to translate.')
        if not _ask_yes_no('Do you want to translate another phrase? (Press y/n): '):
            break
    print('End of translation')


# Guard so importing this module does not load the model or block on input().
if __name__ == "__main__":
    main()