From 0270e2570d3870a9264c6eb4c76c92b8bf73097d Mon Sep 17 00:00:00 2001
From: Jeronymous
Date: Tue, 16 May 2023 17:03:25 +0200
Subject: [PATCH 01/11] Punctuation service based on RecasePunc

---
 .env_default              |  15 ++
 .env_default_http         |   8 -
 .env_default_task         |  15 --
 Dockerfile                |  29 ++-
 README.md                 |  17 +-
 RELEASE.md                |   3 +
 celery_app/register.py    |   5 +-
 celery_app/tasks.py       |  41 +---
 docker-compose.yml        |   2 +-
 docker-entrypoint.sh      |  13 +-
 http_server/ingress.py    |  45 ++---
 http_server/serving.py    |  23 +++
 punctuation/recasepunc.py | 402 ++++++++++++++++++++++++++++++++++++++
 requirements.txt          |  20 +-
 14 files changed, 501 insertions(+), 137 deletions(-)
 create mode 100644 .env_default
 delete mode 100644 .env_default_http
 delete mode 100644 .env_default_task
 create mode 100644 punctuation/recasepunc.py

diff --git a/.env_default b/.env_default
new file mode 100644
index 0000000..734405d
--- /dev/null
+++ b/.env_default
@@ -0,0 +1,15 @@
+# SERVING PARAMETERS
+SERVICE_MODE=task
+
+# SERVICE PARAMETERS
+SERVICES_BROKER=redis://172.17.0.1:6379
+BROKER_PASS=
+
+# SERVICE DISCOVERY
+SERVICE_NAME=linto-punctuation
+LANGUAGE=fr-FR
+# QUEUE_NAME=(Optional)
+# MODEL_INFO=This model does something
+
+# CONCURRENCY
+CONCURRENCY=2
\ No newline at end of file
diff --git a/.env_default_http b/.env_default_http
deleted file mode 100644
index 6aed7a9..0000000
--- a/.env_default_http
+++ /dev/null
@@ -1,8 +0,0 @@
-# SERVING PARAMETERS
-SERVICE_MODE=http
-
-# SERVICE DISCOVERY
-SERVICE_NAME=MY_PUNCTUATION_SERVICE
-
-# CONCURRENCY
-CONCURRENCY=2
\ No newline at end of file
diff --git a/.env_default_task b/.env_default_task
deleted file mode 100644
index b60f87d..0000000
--- a/.env_default_task
+++ /dev/null
@@ -1,15 +0,0 @@
-# SERVING PARAMETERS
-SERVICE_MODE=task
-
-# SERVICE PARAMETERS
-SERVICES_BROKER=redis://192.168.0.1:6379
-BROKER_PASS=password
-
-# SERVICE DISCOVERY
-SERVICE_NAME=MY_PUNCTUATION_SERVICE
-LANGUAGE=en-US/fr-FR/*
-QUEUE_NAME=(Optionnal)
-MODEL_INFO=This model does something
-
-# CONCURRENCY
-CONCURRENCY=2
\ No newline at end of file
diff --git a/Dockerfile b/Dockerfile
index 7b50705..e4de246 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,15 +1,13 @@
-FROM python:3.8
-LABEL maintainer="rbaraglia@linagora.com"
-ENV PYTHONUNBUFFERED TRUE
-ENV IMAGE_NAME linto-platform-diarization
+FROM python:3.9
+LABEL maintainer="jlouradour@linagora.com"
 
-RUN apt-get update \
-    && apt-get install --no-install-recommends -y \
-        ca-certificates \
-        g++ \
-        openjdk-11-jre-headless \
-        curl \
-        wget
+RUN apt-get update && \
+    apt-get install -y --no-install-recommends \
+        ca-certificates \
+        g++ \
+        openjdk-11-jre-headless \
+        curl \
+        wget
 
 # Rust compiler for tokenizers
 RUN curl https://sh.rustup.rs -sSf | bash -s -- -y
@@ -18,8 +16,8 @@ ENV PATH="/root/.cargo/bin:${PATH}"
 WORKDIR /usr/src/app
 
 # Python dependencies
-COPY requirements.txt ./
-RUN pip install --no-cache-dir -r requirements.txt
+COPY requirements.txt .
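+# Note: the -f index passed to pip below is where the pinned torch build
+# referenced in requirements.txt (e.g. a "+cpu" wheel) is hosted.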
+RUN pip3 install --no-cache-dir -r requirements.txt -f https://download.pytorch.org/whl/torch_stable.html
 
 # Supervisor
 COPY celery_app /usr/src/app/celery_app
@@ -28,13 +26,8 @@ COPY document /usr/src/app/document
 COPY punctuation /usr/src/app/punctuation
 RUN mkdir /usr/src/app/model-store
 RUN mkdir -p /usr/src/app/tmp
-COPY config.properties /usr/src/app/config.properties
-COPY RELEASE.md ./
 COPY docker-entrypoint.sh wait-for-it.sh healthcheck.sh ./
 
-# Grep CURRENT VERSION
-RUN export VERSION=$(awk -v RS='' '/#/ {print; exit}' RELEASE.md | head -1 | sed 's/#//' | sed 's/ //')
-
 ENV PYTHONPATH="${PYTHONPATH}:/usr/src/app/punctuation"
 
 HEALTHCHECK CMD ./healthcheck.sh
diff --git a/README.md b/README.md
index 05e8c51..ee896a4 100644
--- a/README.md
+++ b/README.md
@@ -22,7 +22,7 @@ LinTO-platform-punctuation can either be used as a standalone punctuation servic
 ## Pre-requisites
 
 ### Models
-The punctuation service relies on a trained punctuation prediction model.
+The punctuation service relies on a trained recasing and punctuation prediction model.
 
 We provide homebrew models on [dl.linto.ai](https://dl.linto.ai/downloads/model-distribution/punctuation_models/).
@@ -52,13 +52,13 @@ docker pull registry.linto.ai/lintoai/linto-platform-punctuation:latest
 
 **2- Download the models**
 
-Have the punctuation model (.mar) ready at MODEL_PATH.
+Have the punctuation model ready at MODEL_PATH.
 
 ### HTTP
 
 **1- Fill the .env**
 ```bash
-cp .env_default_http .env
+cp .env_default .env
 ```
 
 Fill the .env with your values.
@@ -73,7 +73,7 @@ Fill the .env with your values.
 ```bash
 docker run --rm \
--v MODEL_PATH:/usr/src/app/model-store/punctuation.mar \
+-v MODEL_PATH:/usr/src/app/model-store/model \
 -p HOST_SERVING_PORT:80 \
 --env-file .env \
 linto-platform-punctuation:latest
 ```
@@ -90,7 +90,7 @@ You need a message broker up and running at MY_SERVICE_BROKER. Instance are typi
 **1- Fill the .env**
 
 ```bash
-cp .env_default_task .env
+cp .env_default .env
 ```
 
 Fill the .env with your values.
@@ -118,7 +118,7 @@ services:
   punctuation-service:
     image: linto-platform-punctuation:latest
     volumes:
-      - /my/path/to/models/punctuation.mar:/usr/src/app/model-store/punctuation.mar
+      - /my/path/to/models/punctuation.mar:/usr/src/app/model-store/model
     env_file: .env
     deploy:
       replicas: 1
@@ -156,7 +156,7 @@ The following information are registered:
     "service_language": $LANGUAGE,
     "queue_name": $QUEUE_NAME,
     "version": "1.2.0", # This repository's version
-    "info": "Bert Based Punctuation model for french punctuation prediction",
+    "info": "Punctuation model for French punctuation prediction",
     "last_alive": 65478213,
     "concurrency": 1
 }
@@ -223,3 +223,6 @@ curl -X POST "http://YOUR_SERVICE:YOUR_PORT/punctuation" -H "accept: applicatio
 
 ## License
 This project is developed under the AGPLv3 License (see LICENSE).
+
+## Acknowledgments
+* [recasepunc](https://github.com/benob/recasepunc): a Python library to train recasing and punctuation models, and to apply them (BSD 3-Clause license).
\ No newline at end of file
diff --git a/RELEASE.md b/RELEASE.md
index de33f05..6848347 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -1,3 +1,6 @@
+# 2.0.0
+- Integration of recasepunc
+
 # 1.1.0
 - Added service registration
 - Updated README
diff --git a/celery_app/register.py b/celery_app/register.py
index 7100440..e343eeb 100644
--- a/celery_app/register.py
+++ b/celery_app/register.py
@@ -26,6 +26,7 @@ def register(is_heartbeat: bool = False) -> bool:
     """
     host, port = os.environ.get("SERVICES_BROKER").split("//")[1].split(":")
     password = os.environ.get("BROKER_PASS", None)
+    if not password: password = None
     r = redis.Redis(
         host=host, port=int(port), db=SERVICE_DISCOVERY_DB, password=password
     )
@@ -59,8 +60,10 @@ def unregister() -> None:
     """Un-register the service"""
     try:
         host, port = os.environ.get("SERVICES_BROKER").split("//")[1].split(":")
+        password = os.environ.get("BROKER_PASS", None)
+        if not password: password = None
         r = redis.Redis(
-            host=host, port=int(port), db=SERVICE_DISCOVERY_DB, password="password"
+            host=host, port=int(port), db=SERVICE_DISCOVERY_DB, password=password
         )
         r.json().delete(f"service:{host_name}")
     except Exception as error:
diff --git a/celery_app/tasks.py b/celery_app/tasks.py
index 9093020..366f4ac 100644
--- a/celery_app/tasks.py
+++ b/celery_app/tasks.py
@@ -1,51 +1,26 @@
-import json
 from typing import Union
 
-import requests
-
 from celery_app.celeryapp import celery
+from punctuation.recasepunc import load_model, generate_predictions
+
+MODEL = load_model()
 
 
 @celery.task(name="punctuation_task", bind=True)
 def punctuation_task(self, text: Union[str, list]):
     """punctuation_task do a synchronous call to the punctuation serving API"""
     self.update_state(state="STARTED")
-    # Fetch model name
-    try:
-        result = requests.get(
-            "http://localhost:8081/models",
-            headers={
-                "accept": "application/json",
-            },
-        )
-        models = json.loads(result.text)
-        model_name = models["models"][0]["modelName"]
-    except Exception as error:
-        raise Exception("Failed to fetch model name") from error
+
+    unique = isinstance(text, str)
 
-    if isinstance(text, str):
+    if unique:
         sentences = [text]
     else:
         sentences = text
 
-    punctuated_sentences = []
-    for i, sentence in enumerate(sentences):
-        self.update_state(state="STARTED", meta={"current": i, "total": len(sentences)})
-
-        result = requests.post(
-            f"http://localhost:8080/predictions/{model_name}",
-            headers={"content-type": "application/octet-stream"},
-            data=sentence.strip().encode("utf-8"),
-        )
-        if result.status_code == 200:
-            punctuated_sentence = result.text
-        else:
-            print("Failed to predict punctuation on sentence: >{sentence}<")
-            punctuated_sentence = sentence
-        punctuated_sentence = punctuated_sentence[0].upper() + punctuated_sentence[1:]
-        punctuated_sentences.append(punctuated_sentence)
+    punctuated_sentences = generate_predictions(MODEL, sentences)
 
     return (
         punctuated_sentences[0]
-        if len(punctuated_sentences) == 1
+        if unique
         else punctuated_sentences
     )
diff --git a/docker-compose.yml b/docker-compose.yml
index 1b4dcc9..bce398d 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -4,7 +4,7 @@ services:
   punctuation-service:
     image: linto-platform-punctuation:latest
     volumes:
-      - /path/to/your/model.mar/usr/src/app/model-store/punctuation.mar
+      - /path/to/your/model.mar:/usr/src/app/model-store/model
     env_file: .env
     deploy:
       replicas: 1
diff --git a/docker-entrypoint.sh b/docker-entrypoint.sh
index 94fe59b..1aeaf02 100755
--- a/docker-entrypoint.sh
+++ b/docker-entrypoint.sh
@@ -2,15 +2,11 @@
 
 echo "RUNNING service"
 
-export VERSION=$(awk -v RS='' '/#/ {print; exit}' RELEASE.md | head -1 | sed 's/#//' | sed 's/ //')
-
 if [ -z "$SERVICE_MODE" ]
 then
     echo "ERROR: Must specify a serving mode: [ http | task ]"
     exit -1
 else
-    # Model serving
-    torchserve --start --ncs --ts-config /usr/src/app/config.properties
     if [ "$SERVICE_MODE" = "http" ]
     then
         echo "Running http server"
@@ -19,25 +15,24 @@ else
     elif [ "$SERVICE_MODE" == "task" ]
     then
         echo "Running celery worker"
-        /usr/src/app/wait-for-it.sh $(echo $SERVICES_BROKER | cut -d'/' -f 3) --timeout=20 --strict -- echo " $SERVICES_BROKER (Service Broker) is up"
+        /usr/src/app/wait-for-it.sh $(echo $SERVICES_BROKER | cut -d'/' -f 3) --timeout=20 --strict -- echo " $SERVICES_BROKER (Service Broker) is up" || exit $?
        # MICRO SERVICE
        ## QUEUE NAME
        QUEUE=$(python -c "from celery_app.register import queue; exit(queue())" 2>&1)
        echo "Service set to $QUEUE"
        ## REGISTRATION
-        python -c "from celery_app.register import register; register()"
+        python -c "from celery_app.register import register; register()" # || exit $?
        echo "Service registered"
        ## WORKER
-        celery --app=celery_app.celeryapp worker -n punctuation_$SERVICE_NAME@%h --queues=$QUEUE -c $CONCURRENCY
+        celery --app=celery_app.celeryapp worker --pool=solo -n punctuation_$SERVICE_NAME@%h --queues=$QUEUE -c $CONCURRENCY
        ## UNREGISTERING
-        python -c "from celery_app.register import unregister; unregister()"
+        python -c "from celery_app.register import unregister; unregister()" || exit $?
        echo "Service unregistered"
     else
         echo "ERROR: Wrong serving command: $SERVICE_MODE"
         exit -1
     fi
-    torchserve --stop
 fi
diff --git a/http_server/ingress.py b/http_server/ingress.py
index 66c40d0..0ea4c63 100644
--- a/http_server/ingress.py
+++ b/http_server/ingress.py
@@ -3,16 +3,22 @@
 import json
 import logging
-import requests
+import time
 
 from confparser import createParser
 from flask import Flask, json, request
-from serving import GunicornServing
+from serving import GeventServing as Serving
 from swagger import setupSwaggerUI
 
 from punctuation import logger
+from punctuation.recasepunc import load_model, generate_predictions
 
 app = Flask("__punctuation-worker__")
 
+logger.info("Loading model")
+tic = time.time()
+MODEL = load_model()
+logger.info("Model loaded in {}s".format(time.time() - tic))
+
 
 @app.route("/healthcheck", methods=["GET"])
 def healthcheck():
@@ -38,35 +44,10 @@ def punctuate():
         if not sentences:
             return "", 200
 
-        # Fetch model name
-        try:
-            result = requests.get(
-                "http://localhost:8081/models",
-                headers={
-                    "accept": "application/json",
-                },
-            )
-            models = json.loads(result.text)
-            model_name = models["models"][0]["modelName"]
-        except:
-            raise Exception("Failed to fetch model name")
-
-        punctuated_sentences = []
-        for sentence in sentences:
-            result = requests.post(
-                "http://localhost:8080/predictions/{}".format(model_name),
-                headers={"content-type": "application/octet-stream"},
-                data=sentence.strip().encode("utf-8"),
-            )
-            if result.status_code == 200:
-                punctuated_sentence = result.text
-                # First letter in capital
-                punctuated_sentence = (
-                    punctuated_sentence[0].upper() + punctuated_sentence[1:]
-                )
-                punctuated_sentences.append(punctuated_sentence)
-            else:
-                raise Exception(result.text)
+        tic = time.time()
+        punctuated_sentences = generate_predictions(MODEL, sentences)
+        logger.info("Prediction done in {}s".format(time.time() - tic))
 
         if return_json:
             return {"punctuated_sentences": punctuated_sentences}, 200
@@ -110,7 +91,7 @@ def server_error(error):
     except Exception as e:
         logger.warning("Could not setup swagger: {}".format(str(e)))
 
-    serving = GunicornServing(
+    serving = Serving(
         app,
         {
             "bind": f"0.0.0.0:{args.service_port}",
diff --git a/http_server/serving.py b/http_server/serving.py
index d2dd7e8..c450fe2 100644
--- a/http_server/serving.py
+++ b/http_server/serving.py
@@ -1,4 +1,7 @@
 import gunicorn.app.base
+import gevent.pywsgi
+import gevent.monkey
+gevent.monkey.patch_all()
 
 
 class GunicornServing(gunicorn.app.base.BaseApplication):
@@ -18,3 +21,23 @@ def load_config(self):
 
     def load(self):
         return self.application
+
+
+class GeventServing():
+
+    def __init__(self, app, options=None):
+        self.options = options or {}
+        self.application = app
+
+    def run(self):
+        bind = self.options.get('bind', "0.0.0.0:8080")
+        workers = self.options.get('workers', 1)
+        listener = bind.split(':')
+        try:
+            assert len(listener) == 2
+            host, port = (listener[0], int(listener[1]))
+        except (AssertionError, ValueError) as err:
+            # fail fast rather than continuing with host/port undefined
+            raise ValueError(f"Invalid bind address {bind}") from err
+
+        server = gevent.pywsgi.WSGIServer((host, port), self.application, spawn=workers, environ={'wsgi.multithread': True, 'wsgi.multiprocess': True})
+        server.serve_forever()
diff --git a/punctuation/recasepunc.py b/punctuation/recasepunc.py
new file mode 100644
index 0000000..9cadcfe
--- /dev/null
+++ b/punctuation/recasepunc.py
@@ -0,0 +1,402 @@
+# coding=utf-8
+
+"""recasepunc file."""
+
+import argparse
+import collections
+import os
+import random
+import sys
+import unicodedata
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+# from mosestokenizer import *
+from tqdm import tqdm
+from transformers import AutoModel, AutoTokenizer, BertTokenizer
+
+default_config = argparse.Namespace(
+    seed=871253,
+    lang='fr',
+    # flavor='flaubert/flaubert_base_uncased',
+    flavor=None,
+    max_length=256,
+    batch_size=16,
+    updates=24000,
+    period=1000,
+    lr=1e-5,
+    dab_rate=0.1,
+    device='cuda',
+    debug=False
+)
+
+default_flavors = {
+    'fr': 'flaubert/flaubert_base_uncased',
+    'en': 'bert-base-uncased',
+    'zh': 'ckiplab/bert-base-chinese',
+    'it': 'dbmdz/bert-base-italian-uncased',
+}
+
+
+class Config(argparse.Namespace):
+    def __init__(self, **kwargs):
+        super().__init__()
+        for key, value in default_config.__dict__.items():
+            setattr(self, key, value)
+        for key, value in kwargs.items():
+            setattr(self, key, value)
+
+        assert self.lang in ['fr', 'en', 'zh', 'it']
+
+        if 'lang' in kwargs and ('flavor' not in kwargs or kwargs['flavor'] is None):
+            self.flavor = default_flavors[self.lang]
+
+        # print(self.lang, self.flavor)
+
+
+def init_random(seed):
+    # make sure everything is deterministic
+    os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
+    torch.use_deterministic_algorithms(True)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    random.seed(seed)
+    np.random.seed(seed)
+
+
+# NOTE: it is assumed in the implementation that y[:,0] is the punctuation label, and y[:,1] is the case label!
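+# Example: with the label sets below, an input like "bonjour comment ça va"
+# may come back as "Bonjour comment ça va ?" (CAPITALIZE predicted on the
+# first token, QUESTION predicted after the last one).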
+
+punctuation = {
+    'O': 0,
+    'COMMA': 1,
+    'PERIOD': 2,
+    'QUESTION': 3,
+    'EXCLAMATION': 4,
+}
+
+punctuation_syms = ['', ',', '.', ' ?', ' !']
+
+case = {
+    'LOWER': 0,
+    'UPPER': 1,
+    'CAPITALIZE': 2,
+    'OTHER': 3,
+}
+
+
+class Model(nn.Module):
+    def __init__(self, flavor, device):
+        super().__init__()
+        self.bert = AutoModel.from_pretrained(flavor)
+        # need a proper way of determining representation size
+        size = self.bert.dim \
+            if hasattr(self.bert, 'dim') else self.bert.config.pooler_fc_size \
+            if hasattr(self.bert.config, 'pooler_fc_size') else self.bert.config.emb_dim \
+            if hasattr(self.bert.config, 'emb_dim') else self.bert.config.hidden_size
+        self.punc = nn.Linear(size, 5)
+        self.case = nn.Linear(size, 4)
+        self.dropout = nn.Dropout(0.3)
+        self.to(device)
+
+    def forward(self, x):
+        output = self.bert(x)
+        representations = self.dropout(F.gelu(output['last_hidden_state']))
+        punc = self.punc(representations)
+        case = self.case(representations)
+        return punc, case
+
+
+def recase(token, label):
+    if label == case['LOWER']:
+        return token.lower()
+    if label == case['CAPITALIZE']:
+        return token.lower().capitalize()
+    if label == case['UPPER']:
+        return token.upper()
+    return token
+
+
+def load_model(checkpoint_path="/usr/src/app/model-store/model", config=None):
+    if config is None:
+        config = default_config
+    if not torch.cuda.is_available():
+        config.device = 'cpu'
+
+    loaded = torch.load(checkpoint_path, map_location=config.device)
+    if 'config' in loaded:
+        config = Config(**loaded['config'])
+
+    if config.flavor is None:
+        config.flavor = default_flavors[config.lang]
+
+    init(config)
+
+    model = Model(config.flavor, config.device)
+    model.load_state_dict(loaded['model_state_dict'])
+
+    config.model = model
+
+    return config
+
+
+def generate_predictions(config, line):
+    if isinstance(line, list):
+        return [generate_predictions(config, l) for l in line]
+
+    model = config.model
+
+    # also drop punctuation that we may generate
+    line = ''.join([c for c in line if c not in mapped_punctuation])
+    output = ''
+    if config.debug:
+        print(line)
+    tokens = [config.cls_token] + config.tokenizer.tokenize(line) + [config.sep_token]
+    if config.debug:
+        print(tokens)
+    previous_label = punctuation['PERIOD']
+    first_time = True
+    was_word = False
+    for start in range(0, len(tokens), config.max_length):
+        instance = tokens[start: start + config.max_length]
+        ids = config.tokenizer.convert_tokens_to_ids(instance)
+        if len(ids) < config.max_length:
+            ids += [config.pad_token_id] * (config.max_length - len(ids))
+        x = torch.tensor([ids]).long().to(config.device)
+        y_scores1, y_scores2 = model(x)
+        y_pred1 = torch.max(y_scores1, 2)[1]
+        y_pred2 = torch.max(y_scores2, 2)[1]
+        for id, token, punc_label, case_label in zip(ids, instance, y_pred1[0].tolist()[:len(instance)],
+                                                     y_pred2[0].tolist()[:len(instance)]):
+            if config.debug:
+                print(id, token, punc_label, case_label, file=sys.stderr)
+            if id in (config.cls_token_id, config.sep_token_id):
+                continue
+            if previous_label is not None and previous_label > 1:
+                if case_label in [case['LOWER'], case['OTHER']]:
+                    case_label = case['CAPITALIZE']
+            previous_label = punc_label
+            # different strategy due to sub-lexical token encoding in Flaubert
+            if config.lang == 'fr':
+                if token.endswith('</w>'):
+                    cased_token = recase(token[:-4], case_label)
+                    if was_word:
+                        output += ' '
+                    output += cased_token + punctuation_syms[punc_label]
+                    was_word = True
+                else:
+                    cased_token = recase(token, case_label)
+                    if was_word:
+                        output += ' '
+                    output += cased_token
+                    was_word = False
+            else:
+                if token.startswith('##'):
+                    cased_token = recase(token[2:], case_label)
+                    output += cased_token
+                else:
+                    cased_token = recase(token, case_label)
+                    if not first_time:
+                        output += ' '
+                    first_time = False
+                    output += cased_token + punctuation_syms[punc_label]
+    if previous_label == 0:
+        output += '.'
+    return output
+
+
+mapped_punctuation = {
+    '.': 'PERIOD',
+    '...': 'PERIOD',
+    ',': 'COMMA',
+    ';': 'COMMA',
+    ':': 'COMMA',
+    '(': 'COMMA',
+    ')': 'COMMA',
+    '?': 'QUESTION',
+    '!': 'EXCLAMATION',
+    '，': 'COMMA',
+    '！': 'EXCLAMATION',
+    '？': 'QUESTION',
+    '；': 'COMMA',
+    '：': 'COMMA',
+    '（': 'COMMA',
+    '(': 'COMMA',
+    '）': 'COMMA',
+    '[': 'COMMA',
+    ']': 'COMMA',
+    '【': 'COMMA',
+    '】': 'COMMA',
+    '└': 'COMMA',
+    '└ ': 'COMMA',
+    '_': 'O',
+    '。': 'PERIOD',
+    '、': 'COMMA', # enumeration comma
+    '、': 'COMMA',
+    '…': 'PERIOD',
+    '—': 'COMMA',
+    '「': 'COMMA',
+    '」': 'COMMA',
+    '．': 'PERIOD',
+    '《': 'O',
+    '》': 'O',
+    '，': 'COMMA',
+    '“': 'O',
+    '”': 'O',
+    '"': 'O',
+    '-': 'O',
+    '－': 'O',
+    '〉': 'COMMA',
+    '〈': 'COMMA',
+    '↑': 'O',
+    '〔': 'COMMA',
+    '〕': 'COMMA',
+}
+
+# modification of the wordpiece tokenizer to keep case information even if vocab is lower cased
+# forked from https://github.com/huggingface/transformers/blob/master/src/transformers/models/bert/tokenization_bert.py
+
+class WordpieceTokenizer:
+    """Runs WordPiece tokenization."""
+
+    def __init__(self, vocab, unk_token, max_input_chars_per_word=100, keep_case=True):
+        self.vocab = vocab
+        self.unk_token = unk_token
+        self.max_input_chars_per_word = max_input_chars_per_word
+        self.keep_case = keep_case
+
+    def tokenize(self, text):
+        """
+        Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform
+        tokenization using the given vocabulary.
+        For example, :obj:`input = "unaffable"` will return as output :obj:`["un", "##aff", "##able"]`.
+        Args:
+            text: A single token or whitespace separated tokens. This should have
+                already been passed through `BasicTokenizer`.
+        Returns:
+            A list of wordpiece tokens.
+ """ + + output_tokens = [] + for token in text.strip().split(): + chars = list(token) + if len(chars) > self.max_input_chars_per_word: + output_tokens.append(self.unk_token) + continue + + is_bad = False + start = 0 + sub_tokens = [] + while start < len(chars): + end = len(chars) + cur_substr = None + while start < end: + substr = "".join(chars[start:end]) + if start > 0: + substr = "##" + substr + # optionaly lowercase substring before checking for inclusion in vocab + if (self.keep_case and substr.lower() in self.vocab) or (substr in self.vocab): + cur_substr = substr + break + end -= 1 + if cur_substr is None: + is_bad = True + break + sub_tokens.append(cur_substr) + start = end + + if is_bad: + output_tokens.append(self.unk_token) + else: + output_tokens.extend(sub_tokens) + return output_tokens + + +# modification of XLM bpe tokenizer for keeping case information when vocab is lowercase +# forked from https://github.com/huggingface/transformers/blob/cd56f3fe7eae4a53a9880e3f5e8f91877a78271c/src/transformers/models/xlm/tokenization_xlm.py +def bpe(self, token): + def to_lower(pair): + # print(' ',pair) + return (pair[0].lower(), pair[1].lower()) + + from transformers.models.xlm.tokenization_xlm import get_pairs + + word = tuple(token[:-1]) + (token[-1] + "",) + if token in self.cache: + return self.cache[token] + pairs = get_pairs(word) + + if not pairs: + return token + "" + + while True: + bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(to_lower(pair), float("inf"))) + # print(bigram) + if to_lower(bigram) not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + except ValueError: + new_word.extend(word[i:]) + break + else: + new_word.extend(word[i:j]) + i = j + + if word[i] == first and i < len(word) - 1 and word[i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + pairs = get_pairs(word) + word = " ".join(word) + if word == "\n ": + word = "\n" + self.cache[token] = word + return word + + +def init(config): + init_random(config.seed) + + if config.lang == 'fr': + config.tokenizer = tokenizer = AutoTokenizer.from_pretrained(config.flavor, do_lower_case=False) + + from transformers.models.xlm.tokenization_xlm import XLMTokenizer + assert isinstance(tokenizer, XLMTokenizer) + + # monkey patch XLM tokenizer + import types + tokenizer.bpe = types.MethodType(bpe, tokenizer) + else: + # warning: needs to be BertTokenizer for monkey patching to work + config.tokenizer = tokenizer = BertTokenizer.from_pretrained(config.flavor, do_lower_case=False) + + # warning: monkey patch tokenizer to keep case information + # from recasing_tokenizer import WordpieceTokenizer + config.tokenizer.wordpiece_tokenizer = WordpieceTokenizer(vocab=tokenizer.vocab, unk_token=tokenizer.unk_token) + + if config.lang == 'fr': + config.pad_token_id = tokenizer.pad_token_id + config.cls_token_id = tokenizer.bos_token_id + config.cls_token = tokenizer.bos_token + config.sep_token_id = tokenizer.sep_token_id + config.sep_token = tokenizer.sep_token + else: + config.pad_token_id = tokenizer.pad_token_id + config.cls_token_id = tokenizer.cls_token_id + config.cls_token = tokenizer.cls_token + config.sep_token_id = tokenizer.sep_token_id + config.sep_token = tokenizer.sep_token + + if not torch.cuda.is_available() and config.device == 'cuda': + print('WARNING: reverting to cpu as cuda is not available', 
file=sys.stderr) + config.device = torch.device(config.device if torch.cuda.is_available() else 'cpu') + diff --git a/requirements.txt b/requirements.txt index 04ce3f5..3499eab 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,18 +1,12 @@ -captum>=0.3.1 celery[redis,auth,msgpack]>=4.4.7 -pandas>=1.1.5 flask>=1.1.2 flask-cors>=3.0.10 flask-swagger-ui>=3.36.0 -pyyaml>=5.4.1 +gevent>=22.10.2 # NOCOMMIT +waitress>=2.1.2 # NOCOMMIT gunicorn>=20.1.0 -numpy>=1.19.5 -sklearn -supervisor>=4.2.2 -transformers==3.0.2 -torch>=1.7.1 -torch-model-archiver>=0.3.0 -torchserve>=0.3.0 -torchtext>=0.8.1 -torchvision>=0.8.2 -redis \ No newline at end of file +git+https://github.com/benob/mosestokenizer.git +numpy==1.19.5 +regex==2021.8.28 +torch==1.9.0+cpu # NOCOMMIT : use cu111? +transformers==4.10.0 From 4cea20a63d472060f7079f36d588ec66e0ca9658 Mon Sep 17 00:00:00 2001 From: Jeronymous Date: Wed, 17 May 2023 08:45:28 +0200 Subject: [PATCH 02/11] prevent from removing hyphen from input --- punctuation/recasepunc.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/punctuation/recasepunc.py b/punctuation/recasepunc.py index 9cadcfe..480f139 100644 --- a/punctuation/recasepunc.py +++ b/punctuation/recasepunc.py @@ -242,8 +242,7 @@ def generate_predictions(config, line): '“': 'O', '”': 'O', '"': 'O', - '-': 'O', - '-': 'O', + #'-': 'O', # hyphen is a word piece '〉': 'COMMA', '〈': 'COMMA', '↑': 'O', From 008ed81052097b85853da62e8e0dbdd4b1b05327 Mon Sep 17 00:00:00 2001 From: Jeronymous Date: Wed, 17 May 2023 09:01:48 +0200 Subject: [PATCH 03/11] glue apostrophes --- punctuation/recasepunc.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/punctuation/recasepunc.py b/punctuation/recasepunc.py index 480f139..4865c05 100644 --- a/punctuation/recasepunc.py +++ b/punctuation/recasepunc.py @@ -3,11 +3,10 @@ """recasepunc file.""" import argparse -import collections import os import random import sys -import unicodedata +import re import numpy as np import torch @@ -201,6 +200,8 @@ def generate_predictions(config, line): output += cased_token + punctuation_syms[punc_label] if previous_label == 0: output += '.' + # Glue apostrophes back to words + output = re.sub(r"(\w) *' *(\w)", r"\1'\2", output) return output mapped_punctuation = { From 732b8c558a45c17e1a805fe7f9131fa5c1328bf7 Mon Sep 17 00:00:00 2001 From: Jeronymous Date: Tue, 13 Jun 2023 13:38:57 +0200 Subject: [PATCH 04/11] Update requirements. Document models --- Dockerfile.cpu | 35 +++++++++++++++++++++++++++++++++++ README.md | 13 ++++++++++++- requirements.cpu.txt | 11 +++++++++++ requirements.txt | 7 +++---- 4 files changed, 61 insertions(+), 5 deletions(-) create mode 100644 Dockerfile.cpu create mode 100644 requirements.cpu.txt diff --git a/Dockerfile.cpu b/Dockerfile.cpu new file mode 100644 index 0000000..5547e06 --- /dev/null +++ b/Dockerfile.cpu @@ -0,0 +1,35 @@ +FROM python:3.9 +LABEL maintainer="jlouradour@linagora.com" + +RUN apt-get update && \ + apt-get install -y --no-install-recommends \ + ca-certificates \ + g++ \ + openjdk-11-jre-headless \ + curl \ + wget + +# Rust compiler for tokenizers +RUN curl https://sh.rustup.rs -sSf | bash -s -- -y +ENV PATH="/root/.cargo/bin:${PATH}" + +WORKDIR /usr/src/app + +# Python dependencies +COPY requirements.cpu.txt . 
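+# requirements.cpu.txt pins a CPU-only torch wheel ("+cpu"), which is
+# resolved through the PyTorch wheel index passed to pip with -f below.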
+RUN pip3 install --no-cache-dir -r requirements.cpu.txt -f https://download.pytorch.org/whl/torch_stable.html
+
+# Supervisor
+COPY celery_app /usr/src/app/celery_app
+COPY http_server /usr/src/app/http_server
+COPY document /usr/src/app/document
+COPY punctuation /usr/src/app/punctuation
+RUN mkdir /usr/src/app/model-store
+RUN mkdir -p /usr/src/app/tmp
+COPY docker-entrypoint.sh wait-for-it.sh healthcheck.sh ./
+
+ENV PYTHONPATH="${PYTHONPATH}:/usr/src/app/punctuation"
+
+HEALTHCHECK CMD ./healthcheck.sh
+
+ENV TEMP=/usr/src/app/tmp
+ENTRYPOINT ["./docker-entrypoint.sh"]
diff --git a/README.md b/README.md
index ee896a4..0b0816d 100644
--- a/README.md
+++ b/README.md
@@ -24,7 +24,18 @@ LinTO-platform-punctuation can either be used as a standalone punctuation servic
 ### Models
 The punctuation service relies on a trained recasing and punctuation prediction model.
 
-We provide homebrew models on [dl.linto.ai](https://dl.linto.ai/downloads/model-distribution/punctuation_models/).
+Some models trained on [Common Crawl](http://data.statmt.org/cc-100/) are available on [recasepunc](https://github.com/benob/recasepunc) for the following languages:
+* French
+  * [fr-txt.large.19000](https://github.com/benob/recasepunc/releases/download/0.3/fr-txt.large.19000)
+  * [fr.22000](https://github.com/benob/recasepunc/releases/download/0.3/fr.22000)
+* English
+  * [en.23000](https://github.com/benob/recasepunc/releases/download/0.3/en.23000)
+* Italian
+  * [it.22000](https://github.com/CoffeePerry/recasepunc/releases/download/v0.1.0/it.22000)
+* Chinese
+  * [zh.24000](https://github.com/benob/recasepunc/releases/download/0.3/zh.24000)
+
+
 
 ### Docker
 The punctuation service requires docker up and running.
diff --git a/requirements.cpu.txt b/requirements.cpu.txt
new file mode 100644
index 0000000..a24a7b6
--- /dev/null
+++ b/requirements.cpu.txt
@@ -0,0 +1,11 @@
+celery[redis,auth,msgpack]>=4.4.7
+flask>=1.1.2
+flask-cors>=3.0.10
+flask-swagger-ui>=3.36.0
+gevent>=22.10.2
+gunicorn>=20.1.0
+git+https://github.com/benob/mosestokenizer.git@169bd3a504fe20a3e51b9a7af3f0ca359c2d36c9
+numpy==1.19.5
+regex==2021.8.28
+torch==1.9.0+cpu
+transformers==4.10.0
diff --git a/requirements.txt b/requirements.txt
index 3499eab..edaaa92 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,11 +2,10 @@ celery[redis,auth,msgpack]>=4.4.7
 flask>=1.1.2
 flask-cors>=3.0.10
 flask-swagger-ui>=3.36.0
-gevent>=22.10.2 # NOCOMMIT
-waitress>=2.1.2 # NOCOMMIT
+gevent>=22.10.2
 gunicorn>=20.1.0
-git+https://github.com/benob/mosestokenizer.git
+git+https://github.com/benob/mosestokenizer.git@169bd3a504fe20a3e51b9a7af3f0ca359c2d36c9
 numpy==1.19.5
 regex==2021.8.28
-torch==1.9.0+cpu # NOCOMMIT : use cu111?
+torch==1.9.0
 transformers==4.10.0

From 8748f8bca701b7e43c09686971bc756e832e065e Mon Sep 17 00:00:00 2001
From: Jeronymous
Date: Tue, 13 Jun 2023 16:00:20 +0200
Subject: [PATCH 05/11] Make punctuation recovery deterministic.

And start to implement something for (insertion) disfluencies
---
 punctuation/recasepunc.py | 55 +++++++++++++++++++++++++++++++++++++--
 1 file changed, 53 insertions(+), 2 deletions(-)

diff --git a/punctuation/recasepunc.py b/punctuation/recasepunc.py
index 4865c05..c7c6120 100644
--- a/punctuation/recasepunc.py
+++ b/punctuation/recasepunc.py
@@ -59,6 +59,9 @@ def init_random(seed):
     # make sure everything is deterministic
     os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
     torch.use_deterministic_algorithms(True)
+    set_seed(seed)
+
+def set_seed(seed):
     torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
     random.seed(seed)
     np.random.seed(seed)
@@ -138,14 +141,19 @@ def load_model(checkpoint_path="/usr/src/app/model-store/model", config=None):
     return config
 
 
-def generate_predictions(config, line):
+def generate_predictions(config, line, ignore_disfluencies=False):
     if isinstance(line, list):
         return [generate_predictions(config, l) for l in line]
 
     model = config.model
+    set_seed(config.seed)
 
     # also drop punctuation that we may generate
     line = ''.join([c for c in line if c not in mapped_punctuation])
+    if ignore_disfluencies:
+        line = collapse_whitespace(line)
+        line = re.sub(r"(\w) *' *(\w)", r"\1'\2", line) # glue apostrophes to words
+        disfluencies, line = remove_simple_disfluences(line)
     output = ''
     if config.debug:
         print(line)
@@ -202,6 +210,10 @@ def generate_predictions(config, line):
         output += '.'
     # Glue apostrophes back to words
     output = re.sub(r"(\w) *' *(\w)", r"\1'\2", output)
+
+    if ignore_disfluencies:
+        output = collapse_whitespace(output)
+        output = reconstitute_text(output, disfluencies)
     return output
 
 mapped_punctuation = {
@@ -227,7 +239,7 @@ def generate_predictions(config, line):
     '【': 'COMMA',
     '】': 'COMMA',
     '└': 'COMMA',
-    '└ ': 'COMMA',
+    #'└ ': 'COMMA',
     '_': 'O',
     '。': 'PERIOD',
     '、': 'COMMA', # enumeration comma
@@ -251,6 +263,10 @@ def generate_predictions(config, line):
     '〕': 'COMMA',
 }
 
+def collapse_whitespace(text):
+    return re.sub(r'\s+', ' ', text).strip()
+
+
 # modification of the wordpiece tokenizer to keep case information even if vocab is lower cased
 # forked from https://github.com/huggingface/transformers/blob/master/src/transformers/models/bert/tokenization_bert.py
@@ -400,3 +416,38 @@ def init(config):
         print('WARNING: reverting to cpu as cuda is not available', file=sys.stderr)
     config.device = torch.device(config.device if torch.cuda.is_available() else 'cpu')
 
+def remove_simple_disfluences(text, language=None):
+    if language is None:
+        # Get language from environment
+        language = os.environ.get("LANGUAGE","")
+    language = language.lower()[:2]
+    disfluencies = DISFLUENCIES.get(language, [])
+    all_hits = []
+    for disfluency in disfluencies:
+        all_hits += re.finditer(r" *"+disfluency+r" *", text)
+    all_hits = sorted(all_hits, key=lambda x: x.start())
+    to_be_inserted = [(hit.start(), hit.group()) for hit in all_hits]
+    new_text = text
+    for hit in all_hits[::-1]:
+        new_text = new_text[:hit.start()] + " " + new_text[hit.end():]
+    return to_be_inserted, new_text
+
+punctuation_regex = r"["+re.escape("".join(mapped_punctuation.keys()))+r"]"
+
+def reconstitute_text(text, to_be_inserted):
+    if len(to_be_inserted) == 0:
+        return text
+    pos_punc = [s.start() for s in re.finditer(punctuation_regex, text)]
+    for start, token in to_be_inserted:
+        start += len([p for p in pos_punc if p < start])
+        text = text[:start] + token.rstrip(" ") + text[start:]
+        print(text)
+    return text
+
+
+DISFLUENCIES = {
+    "fr": [
+        "euh",
+        "heu",
+    ]
+}
\ No newline at end of file

From 6ff64144cecb28c8b0102e2257e8859a4536c808 Mon Sep 17 00:00:00 2001
From: Jeronymous
Date: Wed, 14 Jun 2023 09:15:49 +0200
Subject: [PATCH 06/11] note for later

---
 punctuation/recasepunc.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/punctuation/recasepunc.py b/punctuation/recasepunc.py
index c7c6120..9d43f25 100644
--- a/punctuation/recasepunc.py
+++ b/punctuation/recasepunc.py
@@ -151,6 +151,7 @@ def generate_predictions(config, line, ignore_disfluencies=False):
     # also drop punctuation that we may generate
     line = ''.join([c for c in line if c not in mapped_punctuation])
     if ignore_disfluencies:
+        # TODO: fix when there are several disfluencies in a row ("euh euh")
         line = collapse_whitespace(line)
         line = re.sub(r"(\w) *' *(\w)", r"\1'\2", line) # glue apostrophes to words
         disfluencies, line = remove_simple_disfluences(line)
@@ -424,7 +425,7 @@ def remove_simple_disfluences(text, language=None):
     disfluencies = DISFLUENCIES.get(language, [])
     all_hits = []
     for disfluency in disfluencies:
-        all_hits += re.finditer(r" *"+disfluency+r" *", text)
+        all_hits += re.finditer(r" *\b"+disfluency+r"\b *", text)
     all_hits = sorted(all_hits, key=lambda x: x.start())
     to_be_inserted = [(hit.start(), hit.group()) for hit in all_hits]
     new_text = text

From 1cf3fdf05c8a64bffc89009cd153f241b6b9a80b Mon Sep 17 00:00:00 2001
From: Jeronymous
Date: Wed, 14 Jun 2023 09:17:52 +0200
Subject: [PATCH 07/11] Docker image build for recasepunc

---
 Jenkinsfile | 21 +++++++++++++++++++++
 1 file changed, 21 insertions(+)

diff --git a/Jenkinsfile b/Jenkinsfile
index 5f9b7f0..12c53f4 100644
--- a/Jenkinsfile
+++ b/Jenkinsfile
@@ -47,5 +47,26 @@ pipeline {
                 }
             }
         }
+
+        stage('Docker build for recasepunc branch'){
+            when{
+                branch 'recasepunc'
+            }
+            steps {
+                echo 'Publishing recasepunc'
+                script {
+                    image = docker.build(env.DOCKER_HUB_REPO, "-f Dockerfile.cpu .")
+                    VERSION = sh(
+                        returnStdout: true,
+                        script: "awk -v RS='' '/#/ {print; exit}' RELEASE.md | head -1 | sed 's/#//' | sed 's/ //'"
+                    ).trim()
+                    docker.withRegistry('https://registry.hub.docker.com', env.DOCKER_HUB_CRED) {
+                        image.push("${VERSION}")
+                        image.push('recasepunc-latest')
+                    }
+                }
+            }
+        }
+
     }// end stages
 }

From fe99c1e99402a23fa4d5261cacf2d120f33650f0 Mon Sep 17 00:00:00 2001
From: Jeronymous
Date: Wed, 14 Jun 2023 10:08:31 +0200
Subject: [PATCH 08/11] Version upgrades should only occur for new releases on
 the master branch

---
 Jenkinsfile | 1 -
 1 file changed, 1 deletion(-)

diff --git a/Jenkinsfile b/Jenkinsfile
index 12c53f4..4a2e51f 100644
--- a/Jenkinsfile
+++ b/Jenkinsfile
@@ -61,7 +61,6 @@ pipeline {
                         script: "awk -v RS='' '/#/ {print; exit}' RELEASE.md | head -1 | sed 's/#//' | sed 's/ //'"
                     ).trim()
                     docker.withRegistry('https://registry.hub.docker.com', env.DOCKER_HUB_CRED) {
-                        image.push("${VERSION}")
                         image.push('recasepunc-latest')
                     }
                 }

From af35c5d8b65b2391b7dd4981d18ffa84bc4a4d6c Mon Sep 17 00:00:00 2001
From: Jeronymous
Date: Mon, 11 Mar 2024 13:33:16 +0100
Subject: [PATCH 09/11] add doc and image for GPU capabilities

---
 Dockerfile       | 2 +-
 Jenkinsfile      | 2 +-
 README.md        | 9 +++++++--
 RELEASE.md       | 3 +++
 requirements.txt | 2 +-
 5 files changed, 13 insertions(+), 5 deletions(-)

diff --git a/Dockerfile b/Dockerfile
index e4de246..1cedf91 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -5,8 +5,8 @@ RUN apt-get update && \
     apt-get install -y --no-install-recommends \
         ca-certificates \
         g++ \
-        openjdk-11-jre-headless \
         curl \
+        libtinfo5 \
         wget
 
 # Rust compiler for tokenizers
diff --git a/Jenkinsfile b/Jenkinsfile
index 4a2e51f..3a63ebd 100644
--- a/Jenkinsfile
+++ b/Jenkinsfile
@@ -55,7 +55,7 @@ pipeline {
             steps {
                 echo 'Publishing recasepunc'
                 script {
-                    image = docker.build(env.DOCKER_HUB_REPO, "-f Dockerfile.cpu .")
+                    image = docker.build(env.DOCKER_HUB_REPO, "-f Dockerfile .")
                     VERSION = sh(
                         returnStdout: true,
                         script: "awk -v RS='' '/#/ {print; exit}' RELEASE.md | head -1 | sed 's/#//' | sed 's/ //'"
diff --git a/README.md b/README.md
index 0b0816d..9c45550 100644
--- a/README.md
+++ b/README.md
@@ -40,6 +40,9 @@ Some models trained on [Common Crawl](http://data.statmt.org/cc-100/) are availa
 ### Docker
 The punctuation service requires docker up and running.
 
+For GPU capabilities, you also need to install
+[nvidia-container-toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html).
+
 ### (micro-service) Service broker
 The punctuation service's only entry point in job mode is tasks posted on a REDIS message broker using [Celery](https://github.com/celery/celery).
@@ -63,7 +66,7 @@ docker pull registry.linto.ai/lintoai/linto-platform-punctuation:latest
 
 **2- Download the models**
 
-Have the punctuation model ready at MODEL_PATH.
+Have the punctuation model ready at `<MODEL_PATH>`.
 
 ### HTTP
@@ -84,12 +87,14 @@ Fill the .env with your values.
 ```bash
 docker run --rm \
--v MODEL_PATH:/usr/src/app/model-store/model \
+-v <MODEL_PATH>:/usr/src/app/model-store/model \
 -p HOST_SERVING_PORT:80 \
 --env-file .env \
 linto-platform-punctuation:latest
 ```
 
+Also add ```--gpus all``` as an option to enable GPU capabilities.
+
 This will run a container providing an http API bound to the host HOST_SERVING_PORT port.
diff --git a/RELEASE.md b/RELEASE.md
index 6848347..ee8dfe6 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -1,3 +1,6 @@
+# 2.0.1
+- newer image for recasepunc, with GPU support
+
 # 2.0.0
 - Integration of recasepunc
 
diff --git a/requirements.txt b/requirements.txt
index edaaa92..9082873 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,7 @@
 celery[redis,auth,msgpack]>=4.4.7
 flask>=1.1.2
 flask-cors>=3.0.10
-flask-swagger-ui>=3.36.0
+flask-swagger-ui==3.36.0
 gevent>=22.10.2
 gunicorn>=20.1.0
 git+https://github.com/benob/mosestokenizer.git@169bd3a504fe20a3e51b9a7af3f0ca359c2d36c9

From 24af147c06d343f44ec43f676e2a4fabc37dc96d Mon Sep 17 00:00:00 2001
From: Jeronymous
Date: Mon, 23 Sep 2024 13:25:54 +0200
Subject: [PATCH 10/11] Remove version number that was never released

---
 RELEASE.md | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/RELEASE.md b/RELEASE.md
index ee8dfe6..e49b832 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -1,7 +1,4 @@
 # 2.0.1
-- newer image for recasepunc, with GPU support
-
-# 2.0.0
 - Integration of recasepunc
 
 # 1.1.0

From 4ea698167a855fa4c34d5f4387d0540b8361c50b Mon Sep 17 00:00:00 2001
From: Jeronymous
Date: Mon, 23 Sep 2024 13:28:31 +0200
Subject: [PATCH 11/11] fix version number

---
 RELEASE.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/RELEASE.md b/RELEASE.md
index e49b832..6848347 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -1,4 +1,4 @@
-# 2.0.1
+# 2.0.0
 - Integration of recasepunc
 
 # 1.1.0