Skip to content

Commit

Permalink
fixed the shap h5 save bug that that wrongly saved a shap_scores[0] e…
Browse files Browse the repository at this point in the history
…vent when it is an array instead of an array, thereby wrongly broadcasting the same value for all the rows in the h5
  • Loading branch information
viramalingam committed Jan 8, 2024
1 parent f89e380 commit 4982f15
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions bpnet/cli/shap_scores.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,11 @@ def data_func(model_inputs):

# save the hyp shap scores, one hot sequences & chrom positions
# to a HDF5 file
save_scores(peaks_df, X, counts_shap_scores[0], output_fname)
if isinstance(counts_shap_scores, list):
save_scores(peaks_df, X, counts_shap_scores[0], output_fname)
else:
save_scores(peaks_df, X, counts_shap_scores, output_fname)


logging.info("Generating 'profile' shap scores")
profile_shap_scores = profile_model_profile_explainer.shap_values(
Expand All @@ -342,8 +346,12 @@ def data_func(model_inputs):

# save the profile hyp shap scores, one hot sequences & chrom
# positions to a HDF5 file
save_scores(peaks_df, X, profile_shap_scores[0], output_fname)

if isinstance(counts_shap_scores, list):
save_scores(peaks_df, X, profile_shap_scores[0], output_fname)
else:
save_scores(peaks_df, X, profile_shap_scores, output_fname)

# save the dataframe as a new .bed file
peaks_df.to_csv('{}/peaks_valid_scores.bed'.format(shap_dir),
sep='\t', header=False, index=False)
Expand Down

0 comments on commit 4982f15

Please sign in to comment.