Skip to content

Commit

Permalink
Use articles listed in the go_term_publication_map table
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosribas committed Dec 16, 2024
1 parent 912fb2c commit c99b4f9
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 20 deletions.
47 changes: 38 additions & 9 deletions training/export_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

EUROPE_PMC = "https://www.ebi.ac.uk/europepmc/webservices/rest"
RATE_LIMIT = 8
NON_RNA_ARTICLE_LIMIT = 2300
NON_RNA_ARTICLE_LIMIT = 3500


async def clean_text(text: str) -> str:
Expand Down Expand Up @@ -102,6 +102,34 @@ async def rfam_articles() -> Set[str]:
return pubmed_ids


async def go_term_publication() -> Set[str]:
"""
Retrieves PubMed IDs (PMIDs) from Go terms.
:return: a set of PMIDs found in the go_term_publication_map table
"""
user = os.getenv("POSTGRES_USER")
password = os.getenv("POSTGRES_PASSWORD")
database = os.getenv("POSTGRES_DATABASE")
host = os.getenv("POSTGRES_HOST")
port = os.getenv("POSTGRES_PORT")

async with create_engine(user=user, database=database, host=host, password=password, port=port) as engine:
async with engine.acquire() as connection:
query = sa.text('''
SELECT DISTINCT refs.pmid
FROM go_term_publication_map map
JOIN rnc_references refs
ON refs.id=map.reference_id
''')

get_data = await connection.execute(query)
rows = await get_data.fetchall()
pubmed_ids = {row["pmid"] for row in rows}

return pubmed_ids


async def manually_annotated_articles(pmids: Set[str]) -> List[Dict[str, int]]:
"""
Retrieves manually annotated articles from the database, excluding those with PubMed IDs
Expand Down Expand Up @@ -157,7 +185,7 @@ async def manually_annotated_articles(pmids: Set[str]) -> List[Dict[str, int]]:
async def non_rna_articles(page):
pubmed_ids = set()
query = f'/search?query=(IN_EPMC:Y AND OPEN_ACCESS:Y AND NOT SRC:PPR AND NOT "rna" ' \
f'AND NOT "mrna" AND NOT "ncrna" AND NOT "lncrna" AND NOT "rrna" AND NOT "sncrna") ' \
f'AND NOT "mrna" AND NOT "ncrna" AND NOT "lncrna" AND NOT "rrna" AND NOT "sncrna" AND NOT "mirna") ' \
f'&sort_cited:y&pageSize=500&cursorMark={page}&format=json'

try:
Expand All @@ -184,15 +212,16 @@ async def main():
semaphore = asyncio.Semaphore(RATE_LIMIT)

# get abstracts from TarBase and Rfam
tarbase, rfam = await asyncio.gather(
tarbase, rfam, go_term = await asyncio.gather(
tarbase_articles(),
rfam_articles()
rfam_articles(),
go_term_publication()
)
tarbase_rfam_pmids = tarbase | rfam
tarbase_rfam_task = [fetch_abstract(pmid, semaphore) for pmid in tarbase_rfam_pmids]
tarbase_rfam_abstracts = await asyncio.gather(*tarbase_rfam_task)
rna_pmids = tarbase | rfam | go_term
rna_task = [fetch_abstract(pmid, semaphore) for pmid in rna_pmids]
rna_abstracts = await asyncio.gather(*rna_task)

for abstract in filter(None, tarbase_rfam_abstracts):
for abstract in filter(None, rna_abstracts):
cleaned_abstract = await clean_text(abstract)
list_of_abstracts.append({"abstract": cleaned_abstract, "rna_related": 1})

Expand All @@ -213,7 +242,7 @@ async def main():
list_of_abstracts.append({"abstract": cleaned_abstract, "rna_related": 0})

# get abstracts of manually annotated articles (extracted from the RNAcentral database)
manually_annotated = await manually_annotated_articles(tarbase_rfam_pmids)
manually_annotated = await manually_annotated_articles(rna_pmids)

# save to CSV
df = pd.DataFrame(manually_annotated + list_of_abstracts)
Expand Down
Binary file modified training/svc_pipeline.pkl
Binary file not shown.
22 changes: 11 additions & 11 deletions training/text_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ def main():
df = pd.read_csv("data.csv")
print(df["rna_related"].value_counts())
# rna_related
# 1 2159
# 0 2159
# 1 3347
# 0 3347

X = df["abstract"]
y = df["rna_related"]
Expand All @@ -44,20 +44,20 @@ def main():
print(f"CNB: {accuracy_score(y_test, predictCNB):.2f}")
print(f"SVC: {accuracy_score(y_test, predictSVC):.2f}")
print(f"RF: {accuracy_score(y_test, predictRF):.2f}")
# MNB: 0.96
# CNB: 0.96
# MNB: 0.94
# CNB: 0.94
# SVC: 0.99
# RF: 0.98
# RF: 0.96

print(classification_report(y_test, predictSVC))
# precision recall f1-score support
# precision recall f1-score support
#
# 0 0.99 1.00 0.99 417
# 1 1.00 0.99 0.99 447
# 0 0.99 0.98 0.98 669
# 1 0.98 0.99 0.99 670
#
# accuracy 0.99 864
# macro avg 0.99 0.99 0.99 864
# weighted avg 0.99 0.99 0.99 864
# accuracy 0.99 1339
# macro avg 0.99 0.99 0.99 1339
# weighted avg 0.99 0.99 0.99 1339

joblib.dump(pipeSVC, "svc_pipeline.pkl")

Expand Down

0 comments on commit c99b4f9

Please sign in to comment.