Skip to content

Commit

Permalink
Hot fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
agdiaz committed Jan 3, 2025
1 parent c07bf89 commit eb175fc
Show file tree
Hide file tree
Showing 3 changed files with 156 additions and 100 deletions.
15 changes: 10 additions & 5 deletions parrot/brnn_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@

import matplotlib as mpl

mpl.use('Agg')

mpl.use("Agg")
import mpl.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sn
Expand All @@ -40,8 +40,6 @@
mpl.rcParams["font.size"] = 12
mpl.rcParams["lines.linewidth"] = 2

import matplotlib.pyplot as plt


def training_loss(train_loss, val_loss, output_file_prefix=""):
"""Plot training and validation loss per epoch
Expand Down Expand Up @@ -112,6 +110,7 @@ def training_loss(train_loss, val_loss, output_file_prefix=""):
pad_inches=0.1,
)
plt.clf()
plt.close()


def sequence_regression_scatterplot(true, predicted, output_file_prefix=""):
Expand Down Expand Up @@ -178,6 +177,7 @@ def sequence_regression_scatterplot(true, predicted, output_file_prefix=""):
pad_inches=0.1,
)
plt.clf()
plt.close()


def residue_regression_scatterplot(true, predicted, output_file_prefix=""):
Expand Down Expand Up @@ -255,6 +255,7 @@ def residue_regression_scatterplot(true, predicted, output_file_prefix=""):
pad_inches=0.1,
)
plt.clf()
plt.close()


def plot_roc_curve(
Expand Down Expand Up @@ -366,6 +367,7 @@ def plot_roc_curve(
pad_inches=0.1,
)
plt.clf()
plt.close()


def plot_precision_recall_curve(
Expand Down Expand Up @@ -412,7 +414,7 @@ def plot_precision_recall_curve(
)

# Plot
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 8))
_fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 8))
ax.plot(
recall["micro"],
precision["micro"],
Expand Down Expand Up @@ -468,6 +470,7 @@ def plot_precision_recall_curve(
pad_inches=0.1,
)
plt.clf()
plt.close()


def confusion_matrix(
Expand Down Expand Up @@ -525,6 +528,7 @@ class label for a particular sequence
pad_inches=0.1,
)
plt.clf()
plt.close()


def res_confusion_matrix(
Expand Down Expand Up @@ -595,6 +599,7 @@ def res_confusion_matrix(
pad_inches=0.1,
)
plt.clf()
plt.close()


def write_performance_metrics(
Expand Down
Loading

0 comments on commit eb175fc

Please sign in to comment.