From d73da5fd39d82022c6a098625a316b93764a9f5a Mon Sep 17 00:00:00 2001 From: YevhenKost Date: Fri, 15 Mar 2024 13:56:58 +0000 Subject: [PATCH 1/3] added batching; fixed circular references; added model loading; added new endpoint --- Dockerfile | 9 ------ config/config.json | 2 +- main.py | 37 +++++++++++++++++---- requirements.txt | 13 ++++++-- src/bert_te.py | 73 +++++++++++++++++++++++++++-------------- src/data.py | 40 +++++++++++++++-------- src/loading_utils.py | 15 +++++++++ src/models.py | 77 +++++++++++++++++++++++++++++++++++++++++--- src/templates.py | 5 +-- 9 files changed, 207 insertions(+), 64 deletions(-) create mode 100644 src/loading_utils.py diff --git a/Dockerfile b/Dockerfile index e551054..b8a2720 100644 --- a/Dockerfile +++ b/Dockerfile @@ -4,15 +4,6 @@ FROM python:3.8.2 RUN pip install --upgrade pip -RUN pip3 install tqdm - - -RUN pip3 install torch -RUN pip3 install numpy -RUN pip3 install transformers -RUN pip3 install Cython - - COPY . /app WORKDIR /app RUN pip install -r requirements.txt diff --git a/config/config.json b/config/config.json index 16d0b35..893ea96 100644 --- a/config/config.json +++ b/config/config.json @@ -1,4 +1,4 @@ { - "model_path": "facebook/bart-large-mnli" + "model_path": "valhalla/distilbart-mnli-12-1" } \ No newline at end of file diff --git a/main.py b/main.py index 9e5bb87..20834d7 100644 --- a/main.py +++ b/main.py @@ -1,8 +1,9 @@ - -from src.bert_te import BertArgumentStructure +from src.bert_te import BertArgumentStructure from src.data import Data from src.utility import handle_errors +from src.loading_utils import load_model + from flask import Flask, request from prometheus_flask_exporter import PrometheusMetrics import logging @@ -13,6 +14,7 @@ app = Flask(__name__) metrics = PrometheusMetrics(app) +model = load_model(config_file_path="config/config.json") @metrics.summary('requests_by_status', 'Request latencies by status', labels={'status': lambda r: r.status_code}) @@ -23,8 +25,10 @@ def bertte(): if request.method == 'POST': file_obj = request.files['file'] - data = Data(file_obj) - result = BertArgumentStructure(file_obj).get_argument_structure() + # data = Data(file_obj) + result = BertArgumentStructure( + file_obj=file_obj, model=model + ).get_argument_structure() return result @@ -34,7 +38,28 @@ def bertte(): The model is fine-tuned to recognize inferences, conflicts, and non-relations. It accepts xIAF as input and returns xIAF as output. This component can be integrated into the argument mining pipeline alongside a segmenter.""" - return info - + return info + + +@handle_errors +@app.route('/bert-te-from-json-to-json', methods=['GET', 'POST']) +def bertte_from_json(): + if request.method == 'POST': + xaif_dict = request.json + + result = BertArgumentStructure( + file_obj=None, model=model + ).get_argument_structure_from_json(xaif_dict=xaif_dict) + + return result + + if request.method == 'GET': + info = """The Inference Identifier is a component of AMF that detects argument relations between propositions. + This implementation utilises the Hugging Face implementation of BERT for textual entailment. + The model is fine-tuned to recognize inferences, conflicts, and non-relations. + It accepts xIAF as input and returns xIAF as output. + This component can be integrated into the argument mining pipeline alongside a segmenter.""" + return info + if __name__ == "__main__": app.run(host="0.0.0.0", port=int("5002"), debug=False) diff --git a/requirements.txt b/requirements.txt index 3b9a991..eec4ca9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,11 @@ -flask -flask_uploads -prometheus_flask_exporter +flask==3.0.2 +flask_uploads==0.2.1 +prometheus_flask_exporter==0.23.0 + +torch==2.2.1 +numpy==1.26.4 +transformers==4.38.2 +Cython==3.0.9 + +tqdm==4.66.2 diff --git a/src/bert_te.py b/src/bert_te.py index 2133053..0a48be4 100644 --- a/src/bert_te.py +++ b/src/bert_te.py @@ -1,20 +1,28 @@ -import json -from src.models import Model from src.data import Data, AIF -from src.templates import BertTEOutput + +from itertools import combinations + + +def divide_chunks(l, n): + # looping till length l + for i in range(0, len(l), n): + yield l[i:i + n] class BertArgumentStructure: - def __init__(self,file_obj): + def __init__(self,file_obj, model, batch_size=8): self.file_obj = file_obj - self.config_file_path = "'config/config.json'" - self.model_path = self.load_config(self.config_file_path) - self.model = Model(self.model_path) - def load_config(self, file_path): - """Load the contents of the config.json file to get the model files.""" - with open(file_path, 'r') as config_file: - config_data = json.load(config_file) - return config_data.get('model_path') + self.model = model + self.batch_size = batch_size + + def get_argument_structure_from_json(self, xaif_dict): + aif = xaif_dict.get('AIF', {}) + + propositions_id_pairs = self.get_propositions_id_pairs(aif) + self.update_node_edge_with_relations(propositions_id_pairs, aif) + + return self.format_output(xaif_dict, aif) + def get_argument_structure(self): """Retrieve the argument structure from the input data.""" @@ -56,18 +64,35 @@ def update_node_edge_with_relations(self, propositions_id_pairs, aif): """ Update the nodes and edges in the AIF structure to reflect the new relations between propositions. """ - checked_pairs = set() - for prop1_node_id, prop1 in propositions_id_pairs.items(): - for prop2_node_id, prop2 in propositions_id_pairs.items(): - if prop1_node_id != prop2_node_id: - pair1 = (prop1_node_id, prop2_node_id) - pair2 = (prop2_node_id, prop1_node_id) - if pair1 not in checked_pairs and pair2 not in checked_pairs: - checked_pairs.add(pair1) - checked_pairs.add(pair2) - prediction = self.model.predict((prop1, prop2)) - AIF.create_entry(aif['nodes'], aif['edges'], prediction, prop1_node_id, prop2_node_id) + + node_ids_combs = list(combinations( + list(propositions_id_pairs.keys()), 2 + )) + + + for batch_node_pairs in divide_chunks(node_ids_combs, self.batch_size): + + batch_proposition_pairs = [ + [propositions_id_pairs[node_id_1], propositions_id_pairs[node_id_2]] + for node_id_1, node_id_2 in batch_node_pairs + ] + batch_preds = self.model.predict_pairs_batch( + proposition_pairs=batch_proposition_pairs + ) + + for node_ids_pair, prediction in zip(batch_node_pairs, batch_preds): + AIF.create_entry( + aif['nodes'], aif['edges'], + prediction, node_ids_pair[0], node_ids_pair[1] + ) def format_output(self, x_aif, aif): """Format the output data.""" - return BertTEOutput.format_output(x_aif['AIF']['nodes'], x_aif['AIF']['edges'], x_aif, aif) + + xaif_output = {} + xaif_output["nodes"] = x_aif['AIF']['nodes'].copy() + xaif_output["edges"] = x_aif['AIF']['edges'].copy() + xaif_output["AIF"] = aif.copy() + return xaif_output + + # return BertTEOutput.format_output(x_aif['AIF']['nodes'], x_aif['AIF']['edges'], x_aif, aif) diff --git a/src/data.py b/src/data.py index 41bf45c..88d3a63 100644 --- a/src/data.py +++ b/src/data.py @@ -40,13 +40,15 @@ def get_file_path(self,): return self.f_name class AIF: - def __init__(self, ): - pass - def is_valid_json_aif(sel,aif_nodes): + + @classmethod + def is_valid_json_aif(cls,aif_nodes): if 'nodes' in aif_nodes and 'locutions' in aif_nodes and 'edges' in aif_nodes: return True return False - def is_json_aif_dialog(self, aif_nodes: list) -> bool: + + @classmethod + def is_json_aif_dialog(cls, aif_nodes: list) -> bool: ''' check if json_aif is dialog ''' @@ -57,7 +59,8 @@ def is_json_aif_dialog(self, aif_nodes: list) -> bool: - def get_next_max_id(self, nodes, n_type): + @classmethod + def get_next_max_id(cls, nodes, n_type): """ Takes a list of nodes (edges) and returns the maximum node/edge ID. Arguments: @@ -93,7 +96,8 @@ def get_next_max_id(self, nodes, n_type): - def get_speaker(self, node_id: int, locutions: List[Dict[str, int]], participants: List[Dict[str, str]]) -> str: + @classmethod + def get_speaker(cls, node_id: int, locutions: List[Dict[str, int]], participants: List[Dict[str, str]]) -> str: """ Takes a node ID, a list of locutions, and a list of participants, and returns the name of the participant who spoke the locution with the given node ID, or "None" if the node ID is not found. @@ -127,7 +131,12 @@ def get_speaker(self, node_id: int, locutions: List[Dict[str, int]], participant else: return ("None None","None") - def create_entry(self,nodes, edges, prediction, index1, index2): + @classmethod + def create_entry(cls, nodes, edges, prediction, index1, index2): + if prediction not in [ + "RA", "CA", "MA" + ]: + return if prediction == "RA": AR_text = "Default Inference" @@ -138,16 +147,17 @@ def create_entry(self,nodes, edges, prediction, index1, index2): elif prediction == "MA": AR_text = "Default Rephrase" AR_type = "MA" - node_id = AIF.get_next_max_id(nodes, 'nodeID') - edge_id = AIF.get_next_max_id(edges, 'edgeID') + node_id = cls.get_next_max_id(nodes, 'nodeID') + edge_id = cls.get_next_max_id(edges, 'edgeID') nodes.append({'text': AR_text, 'type':AR_type,'nodeID': node_id}) edges.append({'fromID': index1, 'toID': node_id,'edgeID':edge_id}) - edge_id = AIF.get_next_max_id(edges, 'edgeID') + edge_id = cls.get_next_max_id(edges, 'edgeID') edges.append({'fromID': node_id, 'toID': index2,'edgeID':edge_id}) - def get_i_node_ya_nodes_for_l_node(self, edges, n_id): + @classmethod + def get_i_node_ya_nodes_for_l_node(cls, edges, n_id): """traverse through edges and returns YA node_ID and I node_ID, given L node_ID""" for entry in edges: if n_id == entry['fromID']: @@ -159,7 +169,8 @@ def get_i_node_ya_nodes_for_l_node(self, edges, n_id): return None, None - def remove_entries(self, l_node_id, nodes, edges, locutions): + @classmethod + def remove_entries(cls, l_node_id, nodes, edges, locutions): """ Removes entries associated with a specific node ID from a JSON dictionary. @@ -171,7 +182,7 @@ def remove_entries(self, l_node_id, nodes, edges, locutions): - (Dict): the edited JSON dictionary with entries associated with the specified node ID removed """ # Remove nodes with the specified node ID - in_id, yn_id = self.get_i_node_ya_nodes_for_l_node(edges, l_node_id) + in_id, yn_id = cls.get_i_node_ya_nodes_for_l_node(edges, l_node_id) edited_nodes = [node for node in nodes if node.get('nodeID') != l_node_id] edited_nodes = [node for node in edited_nodes if node.get('nodeID') != in_id] @@ -186,7 +197,8 @@ def remove_entries(self, l_node_id, nodes, edges, locutions): return edited_nodes, edited_edges, edited_locutions - def get_xAIF_arrays(self, aif_section: dict, xaif_elements: List) -> tuple: + @classmethod + def get_xAIF_arrays(cls, aif_section: dict, xaif_elements: List) -> tuple: """ Extracts values associated with specified keys from the given AIF section dictionary. diff --git a/src/loading_utils.py b/src/loading_utils.py new file mode 100644 index 0000000..16ab6bc --- /dev/null +++ b/src/loading_utils.py @@ -0,0 +1,15 @@ +from src.models import Model +import json + +def load_config(file_path): + """Load the contents of the config.json file to get the model files.""" + with open(file_path, 'r') as config_file: + config_data = json.load(config_file) + return config_data.get('model_path') + +def load_model(config_file_path = "config/config.json"): + + model_path = load_config(config_file_path) + model = Model(model_path) + + return model \ No newline at end of file diff --git a/src/models.py b/src/models.py index 05f2d64..a0abaca 100644 --- a/src/models.py +++ b/src/models.py @@ -1,18 +1,85 @@ - from transformers import BartForSequenceClassification, BartTokenizer +from typing import List, Tuple +from transformers import BatchEncoding class Model: def __init__(self, model_path): self.model_path = model_path self.tokenizer = BartTokenizer.from_pretrained(model_path) self.model = BartForSequenceClassification.from_pretrained(model_path) - self.RA_TRESHOLD = 80 - self.CA_TRESHOLD = 10 + self.RA_THRESHOLD = 80 + self.CA_THRESHOLD = 10 - def predict(self, proposition_pair): + def predict_pair(self, proposition_pair): proposition1, proposition2 = proposition_pair return self._post_process(proposition1, proposition2) + + def _tokenize_pairs_batch(self, proposition_pairs: List[List[str]]) -> Tuple[BatchEncoding]: + + # tokenize in provided order: [[t1, t2], [t3, t4]] + encoded_input_original_pair_order = self.tokenizer.batch_encode_plus( + proposition_pairs, + return_tensors="pt", + padding=True, + truncation=True + ) + + # tokenize in reversed order of text pairs: [[t2, t1], [t4, t3]] + reversed_order_proposition_pairs = [ + text_pair[::-1] for text_pair in proposition_pairs + ] + encoded_input_reversed_pair_order = self.tokenizer.batch_encode_plus( + reversed_order_proposition_pairs, + return_tensors="pt", + padding=True, + truncation=True + ) + + return encoded_input_reversed_pair_order, encoded_input_original_pair_order + + def _make_decision(self, prob: float) -> str: + + if prob > self.RA_THRESHOLD: + arg_rel = "RA" + elif prob < self.CA_THRESHOLD: + arg_rel = "CA" + else: + arg_rel = "None" + + return arg_rel + + def predict_pairs_batch(self, proposition_pairs: List[List[str]]) -> List[str]: + + encoded_input_reversed_pair_order, encoded_input_original_pair_order = self._tokenize_pairs_batch( + proposition_pairs=proposition_pairs + ) + + preds_reversed_pair_order = self.model(**encoded_input_reversed_pair_order) + preds_original_pair_order = self.model(**encoded_input_original_pair_order) + + entail_contradiction_logits_reversed_pair_order = preds_reversed_pair_order.logits[:, [0, 2]].softmax(dim=1) + entail_contradiction_logits_original_pair_order = preds_original_pair_order.logits[:, [0, 2]].softmax(dim=1) + + true_probs_reversed_pair_order = entail_contradiction_logits_reversed_pair_order[:, 1] * 100 + true_probs_original_pair_order = entail_contradiction_logits_original_pair_order[:, 1] * 100 + + max_true_probs = [ + max([prob_rev, prob_orig]) for prob_rev, prob_orig in zip( + true_probs_reversed_pair_order, true_probs_original_pair_order + ) + ] + + preds_decisions = [ + self._make_decision(prob) for prob in max_true_probs + ] + + return preds_decisions + + + + + def _get_prob(self, text1, text2): input_ids = self.tokenizer.encode(text1, text2, return_tensors='pt') logits = self.model(input_ids)[0] @@ -22,7 +89,7 @@ def _get_prob(self, text1, text2): return true_prob def _post_process(self, text1, text2): - true_prob, arg_rel = 0.0, "None" + true_prob1 = self._get_prob(text1, text2) true_prob2 = self._get_prob(text2, text1) true_prob = max(true_prob1, true_prob2) diff --git a/src/templates.py b/src/templates.py index e0ea986..524c524 100644 --- a/src/templates.py +++ b/src/templates.py @@ -3,8 +3,9 @@ class BertTEOutput: @staticmethod def format_output(nodes, edges, aif={}, x_aif={}): + aif['nodes'] = nodes - aif['edges'] = edges - x_aif['AIF'] = aif + aif['edges'] = edges + x_aif['AIF'] = aif.copy() return json.dumps(x_aif) From 04a7e16d012a68a3e67be35199dbe322db458fae Mon Sep 17 00:00:00 2001 From: YevhenKost Date: Mon, 18 Mar 2024 11:18:25 +0000 Subject: [PATCH 2/3] removed numpy from reqs --- requirements.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index eec4ca9..f5ee94a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,7 +3,6 @@ flask_uploads==0.2.1 prometheus_flask_exporter==0.23.0 torch==2.2.1 -numpy==1.26.4 transformers==4.38.2 Cython==3.0.9 From 470a712a4857a5e985c1e1d1ba8c85e19f251afa Mon Sep 17 00:00:00 2001 From: YevhenKost Date: Thu, 28 Mar 2024 12:08:54 +0000 Subject: [PATCH 3/3] fix: empty edge list error --- src/data.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/data.py b/src/data.py index 88d3a63..fced145 100644 --- a/src/data.py +++ b/src/data.py @@ -69,6 +69,9 @@ def get_next_max_id(cls, nodes, n_type): - (int): the maximum node/edge ID in the list of nodes """ + if not len(nodes): + return 0 + max_id, lef_n_id, right_n_id = 0, 0, "" if isinstance(nodes[0][n_type],str): # check if the node id is a text or integer if "_" in nodes[0][n_type]: