Skip to content

Commit

Permalink
fix issues
Browse files Browse the repository at this point in the history
  • Loading branch information
FlorianLeRoyKili committed Jun 15, 2022
1 parent 1163c0f commit a09b340
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 19 deletions.
12 changes: 6 additions & 6 deletions kiliautoml/utils/mapper/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def __init__(
job: JobT,
job_name: str,
assets_repository, # check type
predictions: List[Any],
predictions: Any,
focus_class: Optional[List[str]],
):
self.job = job
Expand Down Expand Up @@ -226,7 +226,7 @@ def _get_assignments_and_lens_with_labels(self):

label_id_array = [self.cat2id[label] for label in self.labels]
prediction_true_class = [
self.predictions[enum][item] for enum, item in enumerate(label_id_array)
self.predictions[enum, item] for enum, item in enumerate(label_id_array)
]
predicted_order = np.argsort(self.predictions, axis=1)
predicted_class = predicted_order[:, -1]
Expand Down Expand Up @@ -274,14 +274,14 @@ def _get_custom_tooltip(self):
# with labels available
if len(self.lens_names) == 5:
return custom_tooltip_picture(
np.column_stack((self.lens[:, 1], self.lens[:, 3])),
np.column_stack((self.lens[:, 0], self.lens[:, 2])),
pict_data_type="img_list",
image_list=self.data,
)
# without labels available
else:
return custom_tooltip_picture(
self.lens[:, 1],
self.lens[:, 0],
pict_data_type="img_list",
image_list=self.data,
)
Expand All @@ -290,13 +290,13 @@ def _get_custom_tooltip(self):
# with labels available
if len(self.lens_names) == 5:
return custom_tooltip_text(
np.column_stack((self.lens[:, 1], self.lens[:, 3])),
np.column_stack((self.lens[:, 0], self.lens[:, 2])),
data=self.data,
)
# without labels available
else:
return custom_tooltip_text(
self.lens[:, 1],
self.lens[:, 0],
data=self.data,
)
else:
Expand Down
29 changes: 16 additions & 13 deletions mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import List, Optional

import click
import numpy as np
import pandas as pd
from kili.client import Kili
from tabulate import tabulate
Expand Down Expand Up @@ -162,34 +163,36 @@ def main(
ncol = first_line.count(",") + 1
nrows = len(next_lines) + 1

if ncol == len(job["content"]["categories"]):
index_col = None
if ncol == len(job["content"]["categories"]) + 1:
index_col = 0
header = None
elif ncol == len(job["content"]["categories"]):
index_col = None
if nrows == len(assets):
header = None
elif ncol == len(assets) + 1:
header = 0
else:
raise ValueError(
"When there is no index column in csv file with predictions"
"the number of row has to be equal to the number of assets"
"or the number of assets + 1 if there is a header"
)
else:
raise ValueError(
"Number of column in predictions should be either "
"the number of category of the number of category + 1 for the external id"
)

if nrows == len(assets):
header = None
elif ncol == len(assets) + 1:
header = 0
else:
raise ValueError(
"Number of rows in predictions should be either "
"the number of assets of the number of assets + 1 for the header"
)

predictions_df = pd.read_csv(predictions_path, index_col=index_col, header=header)

if index_col is None:
predictions = list(predictions_df.to_numpy())
predictions = predictions_df.to_numpy()
else:
predictions = []
for asset in assets:
predictions.append(predictions_df.loc[asset["externalId"]].to_numpy())
predictions = np.array(predictions)

mapper_image_classification = MapperClassification(
api_key=api_key,
Expand Down

0 comments on commit a09b340

Please sign in to comment.