Dev/batching #2

Open · wants to merge 3 commits into base: main
9 changes: 0 additions & 9 deletions Dockerfile
@@ -4,15 +4,6 @@ FROM python:3.8.2
 
 RUN pip install --upgrade pip
 
-RUN pip3 install tqdm
-
-
-RUN pip3 install torch
-RUN pip3 install numpy
-RUN pip3 install transformers
-RUN pip3 install Cython
-
-
 COPY . /app
 WORKDIR /app
 RUN pip install -r requirements.txt
2 changes: 1 addition & 1 deletion config/config.json
@@ -1,4 +1,4 @@
 {
-    "model_path": "facebook/bart-large-mnli"
+    "model_path": "valhalla/distilbart-mnli-12-1"
 }
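The config swap points the service at `valhalla/distilbart-mnli-12-1`, a distilled variant of `bart-large-mnli` that is smaller and faster to load. As a minimal sketch (not the repo's own `Model` wrapper in `src/models.py`, which may differ), such an MNLI checkpoint is typically loaded with Hugging Face transformers like this:

```python
# Sketch only: loading an MNLI-style sequence-classification checkpoint
# with Hugging Face transformers; the repo's Model class may wrap this
# differently.
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_path = "valhalla/distilbart-mnli-12-1"  # value from config/config.json
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSequenceClassification.from_pretrained(model_path)
```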

37 changes: 31 additions & 6 deletions main.py
@@ -1,8 +1,9 @@
 
 from src.bert_te import BertArgumentStructure
 from src.data import Data
 from src.utility import handle_errors
+from src.loading_utils import load_model
 
 from flask import Flask, request
 from prometheus_flask_exporter import PrometheusMetrics
 import logging
@@ -13,6 +14,7 @@
 
 app = Flask(__name__)
 metrics = PrometheusMetrics(app)
+model = load_model(config_file_path="config/config.json")
 
 @metrics.summary('requests_by_status', 'Request latencies by status',
                  labels={'status': lambda r: r.status_code})
@@ -23,8 +25,10 @@
 def bertte():
     if request.method == 'POST':
         file_obj = request.files['file']
-        data = Data(file_obj)
-        result = BertArgumentStructure(file_obj).get_argument_structure()
+        # data = Data(file_obj)
+        result = BertArgumentStructure(
+            file_obj=file_obj, model=model
+        ).get_argument_structure()
 
         return result
 
@@ -34,7 +38,28 @@ def bertte():
         The model is fine-tuned to recognize inferences, conflicts, and non-relations.
         It accepts xIAF as input and returns xIAF as output.
         This component can be integrated into the argument mining pipeline alongside a segmenter."""
-        return info
+
+        return info
+
+
+@handle_errors
+@app.route('/bert-te-from-json-to-json', methods=['GET', 'POST'])
+def bertte_from_json():
+    if request.method == 'POST':
+        xaif_dict = request.json
+
+        result = BertArgumentStructure(
+            file_obj=None, model=model
+        ).get_argument_structure_from_json(xaif_dict=xaif_dict)
+
+        return result
+
+    if request.method == 'GET':
+        info = """The Inference Identifier is a component of AMF that detects argument relations between propositions.
+        This implementation utilises the Hugging Face implementation of BERT for textual entailment.
+        The model is fine-tuned to recognize inferences, conflicts, and non-relations.
+        It accepts xIAF as input and returns xIAF as output.
+        This component can be integrated into the argument mining pipeline alongside a segmenter."""
+        return info
 
 if __name__ == "__main__":
     app.run(host="0.0.0.0", port=int("5002"), debug=False)
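Two things change here: the model is now built once at module import (`model = load_model(...)`) and shared across requests instead of being reloaded per call, and a new JSON route accepts xAIF directly rather than a file upload. A sketch of exercising the new route once the service is running; the payload shape is inferred from the fields the diff touches (an `AIF` dict with `nodes`, `edges`, `locutions`) and is not a complete xAIF schema:

```python
# Hypothetical client call for the new /bert-te-from-json-to-json route.
# The payload below is a minimal stand-in for real xAIF, not a full schema.
import requests

xaif = {
    "AIF": {
        "nodes": [
            {"nodeID": 0, "text": "It is raining.", "type": "I"},
            {"nodeID": 1, "text": "The streets are wet.", "type": "I"},
        ],
        "edges": [],
        "locutions": [],
    }
}

resp = requests.post("http://localhost:5002/bert-te-from-json-to-json", json=xaif)
print(resp.json())  # xAIF with any detected RA/CA/MA relations added
```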
12 changes: 9 additions & 3 deletions requirements.txt
@@ -1,4 +1,10 @@
-flask
-flask_uploads
-prometheus_flask_exporter
+flask==3.0.2
+flask_uploads==0.2.1
+prometheus_flask_exporter==0.23.0
+
+torch==2.2.1
+transformers==4.38.2
+Cython==3.0.9
+
+tqdm==4.66.2
 
73 changes: 49 additions & 24 deletions src/bert_te.py
@@ -1,20 +1,28 @@
 import json
 from src.models import Model
 from src.data import Data, AIF
 from src.templates import BertTEOutput
 
+from itertools import combinations
+
+
+def divide_chunks(l, n):
+    # yield successive n-sized chunks of l
+    for i in range(0, len(l), n):
+        yield l[i:i + n]
+
 class BertArgumentStructure:
-    def __init__(self,file_obj):
+    def __init__(self,file_obj, model, batch_size=8):
         self.file_obj = file_obj
-        self.config_file_path = "'config/config.json'"
-        self.model_path = self.load_config(self.config_file_path)
-        self.model = Model(self.model_path)
-
-    def load_config(self, file_path):
-        """Load the contents of the config.json file to get the model files."""
-        with open(file_path, 'r') as config_file:
-            config_data = json.load(config_file)
-            return config_data.get('model_path')
+        self.model = model
+        self.batch_size = batch_size
+
+    def get_argument_structure_from_json(self, xaif_dict):
+        aif = xaif_dict.get('AIF', {})
+
+        propositions_id_pairs = self.get_propositions_id_pairs(aif)
+        self.update_node_edge_with_relations(propositions_id_pairs, aif)
+
+        return self.format_output(xaif_dict, aif)
+
 
     def get_argument_structure(self):
         """Retrieve the argument structure from the input data."""
@@ -56,18 +64,35 @@ def update_node_edge_with_relations(self, propositions_id_pairs, aif):
         """
         Update the nodes and edges in the AIF structure to reflect the new relations between propositions.
         """
-        checked_pairs = set()
-        for prop1_node_id, prop1 in propositions_id_pairs.items():
-            for prop2_node_id, prop2 in propositions_id_pairs.items():
-                if prop1_node_id != prop2_node_id:
-                    pair1 = (prop1_node_id, prop2_node_id)
-                    pair2 = (prop2_node_id, prop1_node_id)
-                    if pair1 not in checked_pairs and pair2 not in checked_pairs:
-                        checked_pairs.add(pair1)
-                        checked_pairs.add(pair2)
-                        prediction = self.model.predict((prop1, prop2))
-                        AIF.create_entry(aif['nodes'], aif['edges'], prediction, prop1_node_id, prop2_node_id)
+        node_ids_combs = list(combinations(
+            list(propositions_id_pairs.keys()), 2
+        ))
+
+        for batch_node_pairs in divide_chunks(node_ids_combs, self.batch_size):
+
+            batch_proposition_pairs = [
+                [propositions_id_pairs[node_id_1], propositions_id_pairs[node_id_2]]
+                for node_id_1, node_id_2 in batch_node_pairs
+            ]
+            batch_preds = self.model.predict_pairs_batch(
+                proposition_pairs=batch_proposition_pairs
+            )
+
+            for node_ids_pair, prediction in zip(batch_node_pairs, batch_preds):
+                AIF.create_entry(
+                    aif['nodes'], aif['edges'],
+                    prediction, node_ids_pair[0], node_ids_pair[1]
+                )
 
     def format_output(self, x_aif, aif):
         """Format the output data."""
-        return BertTEOutput.format_output(x_aif['AIF']['nodes'], x_aif['AIF']['edges'], x_aif, aif)
+        xaif_output = {}
+        xaif_output["nodes"] = x_aif['AIF']['nodes'].copy()
+        xaif_output["edges"] = x_aif['AIF']['edges'].copy()
+        xaif_output["AIF"] = aif.copy()
+        return xaif_output
+
+        # return BertTEOutput.format_output(x_aif['AIF']['nodes'], x_aif['AIF']['edges'], x_aif, aif)
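This is the core of the batching change: `itertools.combinations` yields each unordered pair of node IDs exactly once, replacing the manual `checked_pairs` set, and `divide_chunks` groups the pairs so the model makes one batched call per chunk instead of one call per pair. A self-contained sketch of the pattern, with a stub in place of `Model.predict_pairs_batch`:

```python
# Self-contained sketch of the pairing + chunking pattern above.
# fake_predict_batch stands in for the repo's Model.predict_pairs_batch.
from itertools import combinations

def divide_chunks(l, n):
    # yield successive n-sized chunks of l
    for i in range(0, len(l), n):
        yield l[i:i + n]

def fake_predict_batch(pairs):
    return ["NO" for _ in pairs]  # placeholder relation labels

props = {"n1": "It rains.", "n2": "Streets are wet.", "n3": "Bring an umbrella."}
pairs = list(combinations(props.keys(), 2))  # ('n1','n2'), ('n1','n3'), ('n2','n3')

for batch in divide_chunks(pairs, 2):
    texts = [[props[a], props[b]] for a, b in batch]
    preds = fake_predict_batch(texts)
    for (a, b), pred in zip(batch, preds):
        print(a, b, pred)
```

Since `combinations` emits each pair in only one order, this matches the old dedup behaviour while also assuming relation prediction is symmetric, exactly as the removed `checked_pairs` logic did.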
43 changes: 29 additions & 14 deletions src/data.py
@@ -40,13 +40,15 @@ def get_file_path(self,):
         return self.f_name
 
 class AIF:
-    def __init__(self, ):
-        pass
-    def is_valid_json_aif(sel,aif_nodes):
+
+    @classmethod
+    def is_valid_json_aif(cls,aif_nodes):
         if 'nodes' in aif_nodes and 'locutions' in aif_nodes and 'edges' in aif_nodes:
             return True
         return False
-    def is_json_aif_dialog(self, aif_nodes: list) -> bool:
+
+    @classmethod
+    def is_json_aif_dialog(cls, aif_nodes: list) -> bool:
         ''' check if json_aif is dialog
         '''
@@ -57,7 +59,8 @@ def is_json_aif_dialog(self, aif_nodes: list) -> bool:
 
 
 
-    def get_next_max_id(self, nodes, n_type):
+    @classmethod
+    def get_next_max_id(cls, nodes, n_type):
         """
         Takes a list of nodes (edges) and returns the maximum node/edge ID.
         Arguments:
@@ -66,6 +69,9 @@ def get_next_max_id(cls, nodes, n_type):
         - (int): the maximum node/edge ID in the list of nodes
         """
 
+        if not len(nodes):
+            return 0
+
         max_id, lef_n_id, right_n_id = 0, 0, ""
         if isinstance(nodes[0][n_type],str): # check if the node id is a text or integer
             if "_" in nodes[0][n_type]:
Expand Down Expand Up @@ -93,7 +99,8 @@ def get_next_max_id(self, nodes, n_type):



def get_speaker(self, node_id: int, locutions: List[Dict[str, int]], participants: List[Dict[str, str]]) -> str:
@classmethod
def get_speaker(cls, node_id: int, locutions: List[Dict[str, int]], participants: List[Dict[str, str]]) -> str:
"""
Takes a node ID, a list of locutions, and a list of participants, and returns the name of the participant who spoke the locution with the given node ID, or "None"
if the node ID is not found.
@@ -127,7 +134,12 @@ def get_speaker(cls, node_id, locutions, participants):
         else:
             return ("None None","None")
 
-    def create_entry(self,nodes, edges, prediction, index1, index2):
+    @classmethod
+    def create_entry(cls, nodes, edges, prediction, index1, index2):
+        if prediction not in [
+            "RA", "CA", "MA"
+        ]:
+            return
 
         if prediction == "RA":
             AR_text = "Default Inference"
@@ -138,16 +150,17 @@ def create_entry(cls, nodes, edges, prediction, index1, index2):
         elif prediction == "MA":
             AR_text = "Default Rephrase"
             AR_type = "MA"
-        node_id = AIF.get_next_max_id(nodes, 'nodeID')
-        edge_id = AIF.get_next_max_id(edges, 'edgeID')
+        node_id = cls.get_next_max_id(nodes, 'nodeID')
+        edge_id = cls.get_next_max_id(edges, 'edgeID')
         nodes.append({'text': AR_text, 'type':AR_type,'nodeID': node_id})
         edges.append({'fromID': index1, 'toID': node_id,'edgeID':edge_id})
-        edge_id = AIF.get_next_max_id(edges, 'edgeID')
+        edge_id = cls.get_next_max_id(edges, 'edgeID')
         edges.append({'fromID': node_id, 'toID': index2,'edgeID':edge_id})
 
 
 
-    def get_i_node_ya_nodes_for_l_node(self, edges, n_id):
+    @classmethod
+    def get_i_node_ya_nodes_for_l_node(cls, edges, n_id):
         """traverse through edges and returns YA node_ID and I node_ID, given L node_ID"""
         for entry in edges:
             if n_id == entry['fromID']:
@@ -159,7 +172,8 @@ def get_i_node_ya_nodes_for_l_node(cls, edges, n_id):
         return None, None
 
 
-    def remove_entries(self, l_node_id, nodes, edges, locutions):
+    @classmethod
+    def remove_entries(cls, l_node_id, nodes, edges, locutions):
         """
         Removes entries associated with a specific node ID from a JSON dictionary.
 
@@ -171,7 +185,7 @@ def remove_entries(cls, l_node_id, nodes, edges, locutions):
         - (Dict): the edited JSON dictionary with entries associated with the specified node ID removed
         """
         # Remove nodes with the specified node ID
-        in_id, yn_id = self.get_i_node_ya_nodes_for_l_node(edges, l_node_id)
+        in_id, yn_id = cls.get_i_node_ya_nodes_for_l_node(edges, l_node_id)
         edited_nodes = [node for node in nodes if node.get('nodeID') != l_node_id]
         edited_nodes = [node for node in edited_nodes if node.get('nodeID') != in_id]
 
@@ -186,7 +200,8 @@ def remove_entries(cls, l_node_id, nodes, edges, locutions):
         return edited_nodes, edited_edges, edited_locutions
 
 
-    def get_xAIF_arrays(self, aif_section: dict, xaif_elements: List) -> tuple:
+    @classmethod
+    def get_xAIF_arrays(cls, aif_section: dict, xaif_elements: List) -> tuple:
         """
         Extracts values associated with specified keys from the given AIF section dictionary.
 
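Beyond converting `AIF` to classmethods (it held no state; its `__init__` was an empty `pass`), two behavioural fixes ride along: `get_next_max_id` now returns 0 for an empty list instead of failing on `nodes[0]`, and `create_entry` returns early for labels outside RA/CA/MA instead of hitting an `UnboundLocalError` on `AR_text`. For a supported label, `create_entry` splices a relation node between the two propositions roughly as follows (an illustration, with exact IDs depending on `get_next_max_id`):

```python
# Illustration of AIF.create_entry(nodes, edges, "RA", 0, 1); the dict
# shapes match the append calls in the diff, the IDs are hypothetical.
nodes = [
    {"nodeID": 0, "text": "It rains.", "type": "I"},
    {"nodeID": 1, "text": "Streets are wet.", "type": "I"},
]
edges = []

# After the call, one RA node and two connecting edges are appended:
# nodes += [{'text': 'Default Inference', 'type': 'RA', 'nodeID': <new>}]
# edges += [{'fromID': 0, 'toID': <new>, 'edgeID': <e1>},
#           {'fromID': <new>, 'toID': 1, 'edgeID': <e2>}]
```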
15 changes: 15 additions & 0 deletions src/loading_utils.py
@@ -0,0 +1,15 @@
+from src.models import Model
+import json
+
+def load_config(file_path):
+    """Load the contents of the config.json file to get the model files."""
+    with open(file_path, 'r') as config_file:
+        config_data = json.load(config_file)
+        return config_data.get('model_path')
+
+def load_model(config_file_path = "config/config.json"):
+
+    model_path = load_config(config_file_path)
+    model = Model(model_path)
+
+    return model
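This new module hoists the config-reading and model-construction logic that previously lived inside `BertArgumentStructure.__init__`, so the checkpoint is loaded once per process rather than once per request. A usage sketch mirroring what `main.py` now does:

```python
# Usage sketch mirroring main.py: build the Model once, reuse everywhere.
from src.loading_utils import load_model
from src.bert_te import BertArgumentStructure

model = load_model(config_file_path="config/config.json")  # one-time cost

# Each request then constructs a cheap wrapper around the shared model:
structure = BertArgumentStructure(file_obj=None, model=model, batch_size=8)
```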