"""Load and prepare toxcast data provided by molecule net.""" # pylint: disable=invalid-name
from pathlib import Path
import pandas as pd
from loguru import logger
from molpipeline.any2mol import SmilesToMol
from molpipeline.error_handling import ErrorFilter, FilterReinserter
from molpipeline.mol2any import MolToSmiles
from molpipeline.mol2mol import (
EmptyMoleculeFilter,
FragmentDeduplicator,
MetalDisconnector,
MixtureFilter,
SaltRemover,
StereoRemover,
TautomerCanonicalizer,
Uncharger,
)
from molpipeline.mol2mol.filter import ElementFilter
from molpipeline.pipeline import Pipeline


def get_standardization_pipeline(n_jobs: int = -1) -> Pipeline:
"""Get the standardization pipeline.
Parameters
----------
n_jobs: int, optional (default=-1)
The number of jobs to use for standardization.
In case of -1, all available CPUs are used.
Returns
-------
Pipeline
The standardization pipeline.
"""
element_filter = ElementFilter(
allowed_element_numbers=[
1, # H
3, # Li
5, # B
6, # C
7, # N
8, # O
9, # F
11, # Na
12, # Mg
14, # Si
15, # P
16, # S
17, # Cl
19, # K
20, # Ca
34, # Se
35, # Br
53, # I
],
)
    error_filter = ErrorFilter(filter_everything=True)

    # Set up pipeline
standardization_pipeline = Pipeline(
[
("smi2mol", SmilesToMol()),
("element_filter", element_filter),
("metal_disconnector", MetalDisconnector()),
("salt_remover", SaltRemover()),
("uncharge1", Uncharger()),
("canonical_tautomer", TautomerCanonicalizer()),
("uncharge2", Uncharger()),
("stereo_remover", StereoRemover()),
("fragment_deduplicator", FragmentDeduplicator()),
("mixture_remover", MixtureFilter()),
("empty_molecule_remover", EmptyMoleculeFilter()),
("mol2smi", MolToSmiles()),
("error_filter", error_filter),
("error_replacer", FilterReinserter.from_error_filter(error_filter, None)),
],
n_jobs=n_jobs,
)
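    # Molecules that fail any step are removed by the ErrorFilter and then
    # reinserted as None by the FilterReinserter, keeping the output aligned
    # with the input SMILES list.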
    return standardization_pipeline


def main() -> None:
"""Run the preprocessing procedure."""
# Set up path variables
base_path = Path(__file__).parents[1]
data_path = base_path / "data"
Path(base_path / "logs").mkdir(exist_ok=True)
    logger.add(base_path / "logs" / "01_preprocess_data.log")

    # Set up pipeline
    standardization_pipeline = get_standardization_pipeline(n_jobs=-1)

    # Load data
data_df = pd.read_csv(data_path / "imported_data" / "toxcast_data.csv.gz")
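    # pandas infers gzip compression from the ".csv.gz" suffix.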
logger.info(
f"Loaded data with {data_df.shape[0]} rows and {data_df.shape[1]} columns."
)
data_df["standardized_smiles"] = standardization_pipeline.fit_transform(
data_df["smiles"].tolist()
)
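    # Failed standardizations are returned as None and therefore show up as
    # missing values in the new column.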
logger.warning(
f"Standardization failed for {data_df.standardized_smiles.isna().sum()} SMILeS."
)
logger.warning("Dropping rows with failed standardization.")
    data_df = data_df.dropna(subset=["standardized_smiles"])

    # Define endpoint list
    endpoint_list = data_df.columns.tolist()
endpoint_list.remove("smiles")
endpoint_list.remove("standardized_smiles")
    logger.info(f"Found {len(endpoint_list)} endpoints.")

    # Create sparse dataframe: each row is a (smiles, endpoint) pair with a binary label
sparse_df = data_df.melt(
id_vars="standardized_smiles",
value_vars=endpoint_list,
value_name="label",
var_name="endpoint",
)
# Remove rows with missing labels
sparse_df = sparse_df.query("label.notna()").copy()
    # Cast labels from float to int (they are floats only because of the NaNs)
sparse_df["label"] = sparse_df["label"].astype(int)

    # Check each combination of endpoint and SMILES for conflicting labels
prefinal_list = []
for (endpoint, smiles), grp_df in sparse_df.groupby(
["endpoint", "standardized_smiles"]
):
unique_label = grp_df.label.unique()
# Only keep SMILES with a single label for a given endpoint
if len(unique_label) == 1:
prefinal_list.append(
{"endpoint": endpoint, "smiles": smiles, "label": unique_label[0]}
)
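
    # (smiles, endpoint) pairs with conflicting labels were skipped above and
    # are therefore dropped from the final dataset.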
final_df = pd.DataFrame(prefinal_list)
output_path = data_path / "intermediate_data"
output_path.mkdir(parents=True, exist_ok=True)
final_df.to_csv(output_path / "ml_ready_data.tsv", sep="\t", index=False)
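
    # Summarize the data: one row per endpoint, one column per label value,
    # each cell counting the unique SMILES with that label.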
summary_df = final_df.pivot_table(
index="endpoint",
columns="label",
values="smiles",
aggfunc="nunique",
fill_value=0,
)
summary_df["total"] = summary_df.sum(axis=1)
summary_df.sort_values("total", ascending=False, inplace=True)
    logger.info(f"\n{summary_df}")


if __name__ == "__main__":
main()