Skip to content

Commit

Permalink
Merge branch 'main' into test_new
Browse files Browse the repository at this point in the history
  • Loading branch information
jcharkow committed Sep 26, 2024
2 parents 07f3e43 + 0a12419 commit 15301b8
Show file tree
Hide file tree
Showing 8 changed files with 462 additions and 34 deletions.
8 changes: 4 additions & 4 deletions docs/gallery_scripts_template/plot_chromatogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@

# # Download test file

url = 'https://raw.githubusercontent.com/Roestlab/massdash/dev/test/test_data/featureMap/ionMobilityTestFeatureDf.tsv'
file_name = 'ionMobilityTestFeatureDf.tsv'
url = 'https://raw.githubusercontent.com/OpenMS/pyopenms_viz/main/test/test_data/ionMobilityTestChromatogramDf.tsv'
file_name = 'chromatogramDf.tsv'

# # Send a GET request to the URL
# # Send a GET request to the URL and handle potential errors
Expand All @@ -29,6 +29,6 @@
print(f"Error writing file: {e}")

# # Code to add annotation to ionMobilityTestFeatureDf data
df = pd.read_csv("./ionMobilityTestFeatureDf.tsv", sep="\t")
df.plot(kind="chromatogram", x="rt", y="int", by="Annotation")
df = pd.read_csv(file_name, sep="\t")
df.plot(kind="chromatogram", x="rt", y="int", by="Annotation", legend=dict(bbox_to_anchor=(1, 0.7)))

13 changes: 6 additions & 7 deletions docs/gallery_scripts_template/plot_mobilogram.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
PeakMap TEMPLATE
Mobilogram TEMPLATE
===================
This example makes a simple plot
Expand All @@ -10,26 +10,25 @@

pd.options.plotting.backend = 'TEMPLATE'

# # Download test file
# Download test file

url = 'https://raw.githubusercontent.com/Roestlab/massdash/dev/test/test_data/featureMap/ionMobilityTestFeatureDf.tsv'
file_name = 'ionMobilityTestFeatureDf.tsv'

# # Send a GET request to the URL
# # Send a GET request to the URL and handle potential errors
# Send a GET request to the URL and handle potential errors
try:
response = requests.get(url)
response.raise_for_status() # Raises an HTTPError for bad responses

# # Save the content of the response to a file
# Save the content of the response to a file
with open(file_name, 'wb') as file:
file.write(response.content)
except requests.RequestException as e:
print(f"Error downloading file: {e}")
except IOError as e:
print(f"Error writing file: {e}")

# # Code to add annotation to ionMobilityTestFeatureDf data
# Code to add annotation to ionMobilityTestFeatureDf data
df = pd.read_csv("./ionMobilityTestFeatureDf.tsv", sep="\t")
df.plot(kind="mobilogram", x="im", y="int", by="Annotation")

df.plot(kind="mobilogram", x="im", y="int", by="Annotation", aggregate_duplicates=True, legend=dict(bbox_to_anchor=(1, 0.7)))
2 changes: 1 addition & 1 deletion docs/gallery_scripts_template/plot_peakmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,5 @@

# # Code to add annotation to ionMobilityTestFeatureDf data
df = pd.read_csv("./ionMobilityTestFeatureDf.tsv", sep="\t")
df.plot(kind="peakmap", x="rt", y="mz", z="int")
df.plot(kind="peakmap", x="rt", y="mz", z="int", aggregate_duplicates=True)

2 changes: 1 addition & 1 deletion docs/gallery_scripts_template/plot_peakmap_marginals.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,5 @@

# # Code to add annotation to ionMobilityTestFeatureDf data
df = pd.read_csv("./ionMobilityTestFeatureDf.tsv", sep="\t")
df.plot(kind="peakmap", x="rt", y="mz", z="int", add_marginals=True)
df.plot(kind="peakmap", x="rt", y="mz", z="int", add_marginals=True, aggregate_duplicates=True)

2 changes: 1 addition & 1 deletion docs/gallery_scripts_template/plot_spectrum_dia.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,6 @@
# # Code to add annotation to ionMobilityTestFeatureDf data
df = pd.read_csv("./ionMobilityTestFeatureDf.tsv", sep="\t")

df.plot(kind="spectrum", x="mz", y="int", custom_annotation='Annotation', annotate_mz=True, bin_method='none', annotate_top_n_peaks=5)
df.plot(kind="spectrum", x="mz", y="int", custom_annotation='Annotation', annotate_mz=True, bin_method='none', annotate_top_n_peaks=5, aggregate_duplicates=True)


329 changes: 329 additions & 0 deletions nbs/manuscript_spyogenes_subplots.ipynb

Large diffs are not rendered by default.

131 changes: 114 additions & 17 deletions pyopenms_viz/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
mz_tolerance_binning,
)
from .constants import IS_SPHINX_BUILD
import warnings


_common_kinds = ("line", "vline", "scatter")
Expand Down Expand Up @@ -140,6 +141,7 @@ def __init__(
line_width: float | None = None,
min_border: int | None = None,
show_plot: bool | None = None,
aggregate_duplicates: bool | None = None,
legend: LegendConfig | Dict | None = None,
feature_config: FeatureConfig | Dict | None = None,
_config: _BasePlotConfig | None = None,
Expand Down Expand Up @@ -178,6 +180,7 @@ def __init__(
self.line_width = line_width
self.min_border = min_border
self.show_plot = show_plot
self.aggregate_duplicates = aggregate_duplicates

self.legend = legend
self.feature_config = feature_config
Expand Down Expand Up @@ -218,6 +221,35 @@ def __init__(
self._load_extension()
self._create_figure()

def _check_and_aggregate_duplicates(self):
    """
    Check if duplicate data is present and aggregate if specified.

    Rows count as duplicates when they agree on every known column except
    the intensity column (``self.z`` for peakmaps, ``self.y`` otherwise).
    When ``self.aggregate_duplicates`` is truthy, the intensities of such
    rows are summed and ``self.data`` is replaced by the aggregated frame,
    restricted to ``self.known_columns``. Otherwise a warning is emitted and
    ``self.data`` is left untouched.

    Modifies self.data
    """
    # the intensity column is typically 'y', however it is 'z' for peakmaps
    intensity_column = self.z if self.kind == "peakmap" else self.y
    key_columns = [col for col in self.known_columns if col != intensity_column]

    if not self.data[key_columns].duplicated().any():
        return

    if self.aggregate_duplicates:
        # dropna=False: `duplicated()` treats missing key values as equal, so
        # rows with a missing key (e.g. no annotation) must stay in their own
        # group instead of being silently discarded by groupby's default.
        self.data = (
            self.data[self.known_columns]
            .groupby(key_columns, dropna=False)
            .sum()
            .reset_index()
        )
    else:
        warnings.warn(
            "Duplicate data detected, data will not be aggregated which may lead to unexpected plots. To enable aggregation set `aggregate_duplicates=True`."
        )

def _verify_column(self, colname: str | int, name: str) -> str:
"""fetch data from column name
Expand Down Expand Up @@ -266,6 +298,15 @@ def _kind(self) -> str:
"""
raise NotImplementedError

@property
def known_columns(self) -> List[str]:
    """
    Columns of the data that this plot knows about.

    Always contains the x and y columns, followed by the ``by`` column when
    one is set. Duplicates outside of these columns are grouped together
    during aggregation, if aggregation is requested.
    """
    columns = [self.x, self.y]
    if self.by is not None:
        columns.append(self.by)
    return columns

@property
def _interactive(self) -> bool:
"""
Expand Down Expand Up @@ -530,6 +571,10 @@ def __init__(
if relative_intensity:
self.data[y] = self.data[y] / self.data[y].max() * 100

self._check_and_aggregate_duplicates()
# sort data by x so in order
self.data.sort_values(by=x, inplace=True)

self.plot(self.data, self.x, self.y, **kwargs)

def plot(self, data, x, y, **kwargs):
Expand Down Expand Up @@ -599,6 +644,44 @@ class SpectrumPlot(BaseMSPlot, ABC):
def _kind(self):
return "spectrum"

@property
def known_columns(self) -> List[str]:
    """
    Columns of the data that this plot knows about.

    Starts from the parent's known columns and adds, when set, the peak
    color, ion annotation, sequence annotation, custom annotation and
    annotation color columns. Duplicates outside of these columns are
    grouped together during aggregation, if aggregation is requested.

    Each column name appears at most once, in first-seen order.
    """
    columns = super().known_columns
    optional_columns = (
        self.peak_color,
        self.ion_annotation,
        self.sequence_annotation,
        self.custom_annotation,
        self.annotation_color,
    )
    columns.extend(col for col in optional_columns if col is not None)
    # The same column may serve several roles (e.g. peak_color and
    # annotation_color). Listing it twice would select duplicate labels and
    # break the groupby that uses this list as keys, so drop repeats.
    return list(dict.fromkeys(columns))

def _check_and_aggregate_duplicates(self):
    """
    Check the spectrum and its reference spectrum for duplicate peaks.

    Runs the parent check on ``self.data`` first. If a reference spectrum is
    set, rows of it that agree on every known column except the intensity
    column (``self.y``) are treated as duplicates. When
    ``self.aggregate_duplicates`` is truthy their intensities are summed and
    ``self.reference_spectrum`` is replaced by the aggregated frame,
    restricted to ``self.known_columns``. Otherwise a warning is emitted.
    """
    super()._check_and_aggregate_duplicates()

    if self.reference_spectrum is None:
        return

    # Intensity must not be part of the key: otherwise peaks at the same
    # position with different intensities are never detected, and there is
    # no column left for `.sum()` to aggregate.
    key_columns = [col for col in self.known_columns if col != self.y]
    if not self.reference_spectrum[key_columns].duplicated().any():
        return

    if self.aggregate_duplicates:
        # dropna=False keeps rows whose key columns contain missing values
        # (e.g. peaks without an annotation) instead of discarding them.
        self.reference_spectrum = (
            self.reference_spectrum[self.known_columns]
            .groupby(key_columns, dropna=False)
            .sum()
            .reset_index()
        )
    else:
        warnings.warn(
            "Duplicate data detected in reference spectrum, data will not be aggregated which may lead to unexpected plots. To enable aggregation set `aggregate_duplicates=True`."
        )

def __init__(
self,
data: DataFrame,
Expand Down Expand Up @@ -656,6 +739,8 @@ def __init__(
self.custom_annotation = custom_annotation
self.annotation_color = annotation_color

self._check_and_aggregate_duplicates()

self.plot(x, y, **kwargs)

def plot(self, x, y, **kwargs):
Expand Down Expand Up @@ -929,6 +1014,15 @@ class PeakMapPlot(BaseMSPlot, ABC):
def _kind(self):
return "peakmap"

@property
def known_columns(self) -> List[str]:
    """
    Columns of the data that this plot knows about.

    Extends the parent's known columns with the z column when one is set.
    Duplicates outside of these columns are grouped together during
    aggregation, if aggregation is requested.
    """
    columns = super().known_columns
    if self.z is not None:
        columns.append(self.z)
    return columns

def __init__(
self,
data,
Expand All @@ -948,8 +1042,6 @@ def __init__(
fill_by_z: bool = True,
**kwargs,
) -> None:
# Copy data since it will be modified
data = data.copy()

# Set default config attributes if not passed as keyword arguments
kwargs["_config"] = _BasePlotConfig(kind=self._kind)
Expand All @@ -968,46 +1060,51 @@ def __init__(
else:
self.annotation_data = None

super().__init__(data, x, y, z=z, **kwargs)
self._check_and_aggregate_duplicates()

# Convert intensity values to relative intensity if required
relative_intensity = kwargs.pop("relative_intensity", False)
if relative_intensity:
data[z] = data[z] / max(data[z]) * 100
self.data[z] = self.data[z] / max(self.data[z]) * 100

# Bin peaks if required
if bin_peaks == True or (
data.shape[0] > num_x_bins * num_y_bins and bin_peaks == "auto"
self.data.shape[0] > num_x_bins * num_y_bins and bin_peaks == "auto"
):
data[x] = cut(data[x], bins=num_x_bins)
data[y] = cut(data[y], bins=num_y_bins)
self.data[x] = cut(self.data[x], bins=num_x_bins)
self.data[y] = cut(self.data[y], bins=num_y_bins)
by = kwargs.pop("by", None)
if by is not None:
# Group by x, y and by columns and calculate the sum intensity within each bin
data = (
data.groupby([x, y, by], observed=True)
self.data = (
self.data.groupby([x, y, by], observed=True)
.agg({z: aggregation_method})
.reset_index()
)
# Add by back to kwargs
kwargs["by"] = by
else:
# Group by x and y bins and calculate the sum intensity within each bin
data = (
data.groupby([x, y], observed=True)
self.data = (
self.data.groupby([x, y], observed=True)
.agg({z: aggregation_method})
.reset_index()
)
data[x] = data[x].apply(lambda interval: interval.mid).astype(float)
data[y] = data[y].apply(lambda interval: interval.mid).astype(float)
data = data.fillna(0)
self.data[x] = (
self.data[x].apply(lambda interval: interval.mid).astype(float)
)
self.data[y] = (
self.data[y].apply(lambda interval: interval.mid).astype(float)
)
self.data = self.data.fillna(0)

# Log intensity scale
if z_log_scale:
data[z] = log1p(data[z])
self.data[z] = log1p(self.data[z])

# Sort values by intensity in ascending order to plot highest intensity peaks last
data = data.sort_values(z)

super().__init__(data, x, y, z=z, **kwargs)
self.data.sort_values(z, inplace=True)

# If we do not want to fill/color based on z value, set to none prior to plotting
if not fill_by_z:
Expand Down
9 changes: 6 additions & 3 deletions pyopenms_viz/_matplotlib/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from matplotlib.patches import Rectangle
from matplotlib.axes import Axes
from matplotlib.figure import Figure

from .._config import LegendConfig

Expand Down Expand Up @@ -222,9 +224,10 @@ def show_default(self):
"""
Show the plot.
"""
#### apply tight layout
fig = self.fig.get_figure()
fig.tight_layout()
if isinstance(self.fig, Axes):
self.fig.get_figure().tight_layout()
else:
self.superFig.tight_layout()
plt.show()


Expand Down

0 comments on commit 15301b8

Please sign in to comment.