implemented concurrent calls to node normalizer for performance, using requests sessions for retries and performance, cleaned up comments
EvanDietzMorris committed Jul 23, 2024
1 parent 4bdeb7d commit 7731c4a
Showing 3 changed files with 68 additions and 97 deletions.
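
In short, node-normalization requests are now sent in parallel over batches of CURIEs through a shared requests session that retries transient failures, replacing the hand-rolled retry loop. A minimal sketch of the pattern, with an illustrative endpoint URL and helper names rather than the project's exact code:

import requests
from concurrent.futures import ThreadPoolExecutor
from requests.adapters import HTTPAdapter, Retry

# illustrative endpoint; the real URL is configured elsewhere in the project
NODE_NORM_URL = 'https://nodenormalization-sri.renci.org/get_normalized_nodes'

def make_retrying_session() -> requests.Session:
    # let urllib3 handle transient failures instead of re-implementing a retry loop
    session = requests.Session()
    retries = Retry(total=3, backoff_factor=0.1, status_forcelist=[502, 503, 504, 403, 429])
    session.mount('https://', HTTPAdapter(max_retries=retries))
    session.mount('http://', HTTPAdapter(max_retries=retries))
    return session

def normalize_batch(session: requests.Session, curies: list) -> dict:
    # one POST per batch of curies; raise on HTTP errors so the caller can handle them
    response = session.post(NODE_NORM_URL, json={'curies': curies})
    response.raise_for_status()
    return response.json()

def normalize_all(curies: list, batch_size: int = 1000) -> dict:
    session = make_retrying_session()
    batches = [curies[i:i + batch_size] for i in range(0, len(curies), batch_size)]
    merged_results = {}
    # fan the batches out across worker threads; results come back in input order
    with ThreadPoolExecutor() as executor:
        for batch_result in executor.map(lambda batch: normalize_batch(session, batch), batches):
            merged_results.update(batch_result)
    return merged_results

Sharing one session across the worker threads also reuses pooled HTTP connections between calls, which is part of the performance gain the commit message refers to.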
16 changes: 3 additions & 13 deletions Common/kgx_file_normalizer.py
@@ -5,24 +5,13 @@
from Common.biolink_utils import BiolinkInformationResources, INFORES_STATUS_INVALID, INFORES_STATUS_DEPRECATED
from Common.biolink_constants import SEQUENCE_VARIANT, PRIMARY_KNOWLEDGE_SOURCE, AGGREGATOR_KNOWLEDGE_SOURCES, \
PUBLICATIONS, OBJECT_ID, SUBJECT_ID, PREDICATE, SUBCLASS_OF
from Common.normalization import NormalizationScheme, NodeNormalizer, EdgeNormalizer, EdgeNormalizationResult
from Common.normalization import NormalizationScheme, NodeNormalizer, EdgeNormalizer, EdgeNormalizationResult, \
NormalizationFailedError
from Common.utils import LoggingUtil, chunk_iterator
from Common.kgx_file_writer import KGXFileWriter
from Common.merging import MemoryGraphMerger, DiskGraphMerger


class NormalizationBrokenError(Exception):
def __init__(self, error_message: str, actual_error: Exception=None):
self.error_message = error_message
self.actual_error = actual_error


class NormalizationFailedError(Exception):
def __init__(self, error_message: str, actual_error: Exception=None):
self.error_message = error_message
self.actual_error = actual_error


EDGE_PROPERTIES_THAT_SHOULD_BE_SETS = {AGGREGATOR_KNOWLEDGE_SOURCES, PUBLICATIONS}
NODE_NORMALIZATION_BATCH_SIZE = 1_000_000
EDGE_NORMALIZATION_BATCH_SIZE = 1_000_000
@@ -349,6 +338,7 @@ def normalize_edge_file(self):
# this could happen due to rare cases of normalization splits where one node normalizes to many
if edge_count > 1:
edge_splits += edge_count - 1

graph_merger.merge_edges(normalized_edges)
self.logger.info(f'Processed {number_of_source_edges} edges so far...')

15 changes: 2 additions & 13 deletions Common/load_manager.py
@@ -5,8 +5,8 @@

from Common.data_sources import SourceDataLoaderClassFactory, RESOURCE_HOGS, get_available_data_sources
from Common.utils import LoggingUtil, GetDataPullError
from Common.kgx_file_normalizer import KGXFileNormalizer, NormalizationBrokenError, NormalizationFailedError
from Common.normalization import NormalizationScheme, NodeNormalizer, EdgeNormalizer
from Common.kgx_file_normalizer import KGXFileNormalizer
from Common.normalization import NormalizationScheme, NodeNormalizer, EdgeNormalizer, NormalizationFailedError
from Common.metadata import SourceMetadata
from Common.loader_interface import SourceDataBrokenError, SourceDataFailedError
from Common.supplementation import SequenceVariantSupplementation, SupplementationFailedError
@@ -355,17 +355,6 @@ def normalize_source(self,
normalization_status=SourceMetadata.STABLE,
normalization_info=normalization_info)
return True
except NormalizationBrokenError as broken_error:
error_message = f"{source_id} NormalizationBrokenError: {broken_error.error_message}"
if broken_error.actual_error:
error_message += f" - {broken_error.actual_error}"
self.logger.error(error_message)
source_metadata.update_normalization_metadata(parsing_version,
composite_normalization_version,
normalization_status=SourceMetadata.BROKEN,
normalization_error=error_message,
normalization_time=current_time)
return False
except NormalizationFailedError as failed_error:
error_message = f"{source_id} NormalizationFailedError: {failed_error.error_message}"
if failed_error.actual_error:
134 changes: 63 additions & 71 deletions Common/normalization.py
@@ -3,7 +3,10 @@
import requests
import time

from concurrent.futures import ThreadPoolExecutor
from requests.adapters import HTTPAdapter, Retry
from dataclasses import dataclass

from robokop_genetics.genetics_normalization import GeneticsNormalizer
from Common.biolink_constants import *
from Common.utils import LoggingUtil
@@ -16,7 +19,6 @@
# predicate to use when normalization fails
FALLBACK_EDGE_PREDICATE = 'biolink:related_to'


@dataclass
class NormalizationScheme:
node_normalization_version: str = 'latest'
Expand All @@ -42,6 +44,12 @@ def get_metadata_representation(self):
'strict': self.strict}


class NormalizationFailedError(Exception):
def __init__(self, error_message: str, actual_error: Exception = None):
self.error_message = error_message
self.actual_error = actual_error


class NodeNormalizer:
"""
Class that contains methods relating to node normalization of KGX data.
Expand Down Expand Up @@ -96,98 +104,74 @@ def __init__(self,
self.sequence_variant_normalizer = None
self.variant_node_types = None

def hit_node_norm_service(self, curies, retries=0):
resp: requests.models.Response = requests.post(f'{self.node_norm_endpoint}get_normalized_nodes',
json={'curies': curies,
'conflate': self.conflate_node_types,
'drug_chemical_conflate': self.conflate_node_types,
'description': True})
self.requests_session = self.get_normalization_requests_session()

def hit_node_norm_service(self, curies):
resp = self.requests_session.post(f'{self.node_norm_endpoint}get_normalized_nodes',
json={'curies': curies,
'conflate': self.conflate_node_types,
'drug_chemical_conflate': self.conflate_node_types,
'description': True})
if resp.status_code == 200:
# if successful return the json as an object
return resp.json()
else:
error_message = f'Node norm response code: {resp.status_code}'
if resp.status_code >= 500:
# if 5xx retry 3 times
retries += 1
if retries == 4:
error_message += ', retried 3 times, giving up..'
self.logger.error(error_message)
resp.raise_for_status()
else:
error_message += f', retrying.. (attempt {retries})'
time.sleep(retries * 3)
self.logger.error(error_message)
return self.hit_node_norm_service(curies, retries)
response_json = resp.json()
if response_json:
return response_json
else:
# we should never get a legitimate 4xx response from node norm,
# crash with an error for troubleshooting
if resp.status_code == 422:
error_message += f'(curies: {curies})'
self.logger.error(error_message)
resp.raise_for_status()
error_message = f"Node Normalization service {self.node_norm_endpoint} returned 200 " \
f"but with an empty result for (curies: {curies})"
raise NormalizationFailedError(error_message=error_message)
elif resp.status_code == 422:
# 422 unprocessable entity - we sent something bad to node norm, crash so we can diagnose
error_message = f'Node norm response code: {resp.status_code} (curies: {curies})'
self.logger.error(error_message)
resp.raise_for_status()

def normalize_node_data(self, node_list: list, block_size: int = 1000) -> list:
def normalize_node_data(self, node_list: list, batch_size: int = 1000) -> list:
"""
This method calls the NodeNormalization web service to get the normalized identifier and name of the node.
the data comes in as a node list.
This method calls the NodeNormalization web service and normalizes a list of nodes.
:param node_list: A list with items to normalize
:param block_size: the number of curies in the request
:param node_list: A list of nodes to normalize
:param batch_size: the number of curies to be sent to NodeNormalization at once
:return:
"""

self.logger.debug(f'Start of normalize_node_data. items: {len(node_list)}')

# init the cache - this accumulates all the results from the node norm service
cached_node_norms: dict = {}

# create a unique set of node ids
tmp_normalize: set = set([node['id'] for node in node_list])

# convert the set to a list so we can iterate through it
to_normalize: list = list(tmp_normalize)
# make a list of the node ids
to_normalize: list = [node['id'] for node in node_list]

# init the array index lower boundary
# use indexes and slice to grab batch_size sized chunks of ids from the list
start_index: int = 0

# get the last index of the list
last_index: int = len(to_normalize)

self.logger.debug(f'{last_index} unique nodes found in this group.')

# grab chunks of the data frame
chunks_of_ids = []
while True:
if start_index < last_index:
# define the end index of the slice
end_index: int = start_index + block_size
end_index: int = start_index + batch_size

# force the end index to be the last index to insure no overflow
# force the end index to be no greater than the last index to ensure no overflow
if end_index >= last_index:
end_index = last_index

self.logger.debug(f'Working block {start_index} to {end_index}.')

# collect a slice of records from the data frame
data_chunk: list = to_normalize[start_index: end_index]

# hit the node norm api
normalization_json = self.hit_node_norm_service(curies=data_chunk)
if normalization_json:
# merge the normalization results with what we have gotten so far
cached_node_norms.update(**normalization_json)
else:
# this shouldn't happen but if the API returns an empty dict instead of nulls,
# assume none of the curies normalize
empty_responses = {curie: None for curie in data_chunk}
cached_node_norms.update(empty_responses)
# collect a slice of block_size curies from the full list
chunks_of_ids.append(to_normalize[start_index: end_index])

# move on down the list
start_index += block_size
start_index += batch_size
else:
break

# hit the node norm api with the chunks of curies in parallel
# we could try to optimize the number of max_workers for ThreadPoolExecutor more specifically,
# by default python attempts to find a reasonable # based on os.cpu_count()
with ThreadPoolExecutor() as executor:
# casting to a list here
normalization_results = list(executor.map(self.hit_node_norm_service, chunks_of_ids))
for normalization_json in normalization_results:
# merge the normalization results into one dictionary
cached_node_norms.update(**normalization_json)

# reset the node index
node_idx = 0

@@ -367,6 +351,16 @@ def get_current_node_norm_version(self):
# this shouldn't happen, raise an exception
resp.raise_for_status()

@staticmethod
def get_normalization_requests_session():
s = requests.Session()
retries = Retry(total=3,
backoff_factor=.1,
status_forcelist=[502, 503, 504, 403, 429])
s.mount('https://', HTTPAdapter(max_retries=retries))
s.mount('http://', HTTPAdapter(max_retries=retries))
return s


class EdgeNormalizationResult:
def __init__(self,
@@ -424,10 +418,8 @@ def normalize_edge_data(self,
"""

# find the predicates that have not been normalized yet
predicates_to_normalize = set()
for edge in edge_list:
if edge[PREDICATE] not in self.edge_normalization_lookup:
predicates_to_normalize.add(edge[PREDICATE])
predicates_to_normalize = {edge[PREDICATE] for edge in edge_list
if edge[PREDICATE] not in self.edge_normalization_lookup}

# convert the set to a list so we can iterate through it
predicates_to_normalize_list = list(predicates_to_normalize)
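
A note on the thread pool used in Common/normalization.py: since Python 3.8, ThreadPoolExecutor defaults to min(32, os.cpu_count() + 4) workers, a sensible default for I/O-bound work like these HTTP calls. If the normalization service needs gentler treatment, the pool size can be capped explicitly; a hypothetical variant of the call shown in the diff above:

from concurrent.futures import ThreadPoolExecutor

# hypothetical explicit cap on concurrent requests to the normalization service
with ThreadPoolExecutor(max_workers=8) as executor:
    normalization_results = list(executor.map(self.hit_node_norm_service, chunks_of_ids))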
