Skip to content

Commit

Permalink
updates to ats_xdmf and plot_column_data to plot lines
Browse files Browse the repository at this point in the history
  • Loading branch information
ecoon committed Mar 4, 2025
1 parent 74b152a commit 72c7b7e
Show file tree
Hide file tree
Showing 2 changed files with 157 additions and 73 deletions.
115 changes: 97 additions & 18 deletions tools/utils/ats_xdmf.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,12 @@
import sys,os
import numpy as np
import h5py
import math
import matplotlib.collections
from matplotlib import pyplot as plt
import plot_lines

from numpy import s_ as s

def valid_data_filename(domain, format=None):
"""The filename for an HDF5 data filename formatter"""
Expand Down Expand Up @@ -198,11 +203,11 @@ def variable(self, vname):
if self.domain and '-' not in vname:
vname = self.domain + '-' + vname
return vname

def _get(self, vname, cycle):
"""Private get: assumes vname is fully resolved, and does not deal with maps."""
return self.d[vname][cycle][:,0]

def get(self, vname, cycle):
"""Access a data member.
Expand All @@ -225,7 +230,7 @@ def get(self, vname, cycle):
else:
return reorder(val, self.map)
return

def getArray(self, vname):
"""Access an array of all cycle values.
Expand All @@ -246,7 +251,6 @@ def getArray(self, vname):
else:
return reorder(val, self.map)


def loadMesh(self, cycle=None, order=None, shape=None, columnar=False, round=5):
"""Load and reorder centroids and volumes of mesh.
Expand All @@ -270,12 +274,13 @@ def loadMesh(self, cycle=None, order=None, shape=None, columnar=False, round=5):
if order is None and shape is None and not columnar:
self.map = None
self.centroids = centroids
self.ordering = None

else:
self.centroids, self.map = structuredOrdering(centroids, order, shape, columnar)
self.ordering, self.centroids, self.map = structuredOrdering(centroids, order, shape, columnar)

self.volume = self.get('cell_volume', cycle)

def loadMeshPolygons(self, cycle=None):
"""Load a mesh into 2D polygons."""
if cycle is None:
Expand All @@ -290,6 +295,67 @@ def getMeshPolygons(self, edgecolor='k', cmap='jet', linewidth=1):
polygons = matplotlib.collections.PolyCollection(self.polygon_coordinates, edgecolor=edgecolor, cmap=cmap, linewidths=linewidth)
return polygons

def plotLinesInTime(self, varname, spatial_slice=None, coordinate=None, time_slice=None, transpose=None, ax=None, colorbar_label=None, **kwargs):
"""Plot multiple lines, one for each slice in time, as a function of coordinate.
Parameters
----------
varname : str
The variable to plot
"""
# make sure time_slice is a slice
if time_slice is None:
time_slice = s[:]
elif isinstance(time_slice, int):
time_slice = s[::time_slice]
else:
time_slice = s[time_slice]

# slice centroids to get coordinate
if spatial_slice is None:
spatial_slice = [s[:],]

if coordinate is None:
coordinate = next(self.ordering[i] for i in range(len(spatial_slice)) if spatial_slice[i] == s[:])
if isinstance(coordinate, str):
if coordinate == 'x': coordinate = 0
elif coordinate == 'y': coordinate = 1
elif coordinate == 'z': coordinate = 2
elif coordinate == 'xy':
raise ValuerError("Cannot infer coordinate 'xy' -- likely this dataset was loaded with inconsistent ordering or you provided an invalid coordinate.")
coordinate_slice = spatial_slice + [s[coordinate],]
coords = self.centroids[*coordinate_slice]

# default transpose is True for z, False for others
if transpose is None:
if coordinate == 2: transpose = True
else: transpose = False

# slice data to get values
vals = self.getArray(varname)
vals_slicer = [time_slice,] + spatial_slice
vals = vals[*vals_slicer]

X = np.tile(coords, (vals.shape[0], 1))
Y = vals

if transpose:
X,Y = Y,X

if colorbar_label is None:
colorbar_label = f'{varname} in time [{self.time_unit}]'
ax, axcb = plot_lines.plotLines(X, Y, self.times[time_slice], ax=ax,
t_min=self.times[0], t_max=self.times[-1],
colorbar_label=colorbar_label, **kwargs)

# label x and y axes
xy_labels = (varname, ['x','y','z'][coordinate]+' [m]') if transpose else (['x','y','z'][coordinate]+' [m]', varname)
ax.set_xlabel(xy_labels[0])
ax.set_ylabel(xy_labels[1])

return ax, axcb


elem_type = {5:'QUAD',
8:'PRISM',
Expand Down Expand Up @@ -427,6 +493,8 @@ def structuredOrdering(coordinates, order=None, shape=None, columnar=False):
Returns
-------
ordering : List[str]
Order used to sort, e.g. ['x', 'y']
ordered_coordinates : np.ndarray
The re-ordered coordinates, shape (n_coordinates, dimension).
map : np.ndarray(int)
Expand All @@ -443,33 +511,34 @@ def structuredOrdering(coordinates, order=None, shape=None, columnar=False):
Sort a column of 100 unordered cells into a 1D sorted array. The
input and output are both of shape (100,3).
> ordered_centroids, map = structuredOrdering(centroids, list())
> order, ordered_centroids, map = structuredOrdering(centroids, list())
Sort a logically structured transect of size NX x NY x NZ =
(100,1,20), where x is structured and z may vary as a function of
x. Both input and output are of shape (2000, 3), but the output
is sorted with each column appearing sequentially and the
z-dimension fastest-varying. map is of shape (2000,).
z-dimension fastest-varying. map is of shape (2000,). The
returned order is ['z',].
> ordered_centroids, map = structuredOrdering(centroids, ['z',])
> order, ordered_centroids, map = structuredOrdering(centroids, ['z',])
Do the same, but this time reshape into a 2D array. Now the
ordered_centroids are of shape (100, 20, 3), and the map is of
shape (100, 20).
shape (100, 20). The returned order is ['z', 'xy'].
> ordered_centroids, map = structuredOrdering(centroids, ['z',], [20,])
> order, ordered_centroids, map = structuredOrdering(centroids, ['z',], [20,])
Do the same as above, but detect the shape. This works only
because the mesh is columnar.
because the mesh is columnar. The returned order is ['z', 'xy'].
> ordered_centroids, map = structuredOrdering(centroids, columnar=True)
> order, ordered_centroids, map = structuredOrdering(centroids, columnar=True)
Sort a 3D map-view "structured-in-z" mesh into arbitrarily-ordered
x and y columns. Assume there are 1000 map-view triangles, each
extruded 20 cells deep. The input is is of shape (20000, 3) and
the output is of shape (1000, 20, 3).
the output is of shape (1000, 20, 3). The returned order is ['z', 'xy'].
> ordered_centroids, map = structuredOrdering(centroids, columnar=True)
> order, ordered_centroids, map = structuredOrdering(centroids, columnar=True)
Note that map can be used with the reorder() function to place
data in this ordering.
Expand All @@ -478,7 +547,6 @@ def structuredOrdering(coordinates, order=None, shape=None, columnar=False):
if columnar:
order = ['x', 'y', 'z',]


# Surely there is a cleaner way to do this in numpy?
# The current approach packs, sorts, and unpacks.
if (coordinates.shape[1] == 3):
Expand All @@ -496,6 +564,8 @@ def structuredOrdering(coordinates, order=None, shape=None, columnar=False):
else:
ordered_coordinates = np.array([coords_a['x'], coords_a['y']]).transpose()

out_order = order

if columnar:
# try to guess the shape based on new-found contiguity
n_cells_in_column = 0
Expand All @@ -504,19 +574,28 @@ def structuredOrdering(coordinates, order=None, shape=None, columnar=False):
np.allclose(xy, ordered_coordinates[n_cells_in_column,0:2], 0., 1.e-5):
n_cells_in_column += 1
shape = [n_cells_in_column,]

out_order = ['xy', 'z']

if shape is not None:
new_shape = (-1,) + tuple(shape)
coord_shape = new_shape+(3,)
ordered_coordinates = np.reshape(ordered_coordinates, coord_shape)
map = np.reshape(map, new_shape)

if len(new_shape) == 3:
out_order = ['x', 'y', 'z']
elif len(new_shape) == 2:
if coordinates.shape[1] == 3:
out_order = ['xy', 'z']
else:
out_order = ['x', 'y']

if map.shape[0] == 1:
map = map[0]
ordered_coordinates = ordered_coordinates[0]
out_order = out_order[1:]

return ordered_coordinates, map
return out_order, ordered_coordinates, map


def reorder(data, map):
Expand Down
115 changes: 60 additions & 55 deletions tools/utils/plot_column_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from matplotlib import pyplot as plt
import matplotlib.cm
import colors
import plot_lines

import ats_xdmf, ats_units

Expand Down Expand Up @@ -150,27 +151,10 @@ def annotate(layout, axs, time_unit):
def plot_subsurface(vis, col, ax, label, color=None, cmap=None):
if cmap is None:
cmap = colors.alpha_cmap(color)

z = vis.centroids[:,2]

if len(vis.times) == 1:
cm = colors.cm_mapper(vis.times[0]-1, vis.times[0], cmap)
else:
cm = colors.cm_mapper(vis.times[0], vis.times[-1], cmap)

formats = ['-', '--', '-.']

for varname, form in zip(col, formats):
data = vis.getArray(varname)
assert(len(data.shape) == 2)
assert(data.shape[1] == len(vis.centroids))

for i,time in enumerate(vis.times):
mylabel = None
if i == len(vis.times)-1:
mylabel = label

ax.plot(data[i,:], z, form, color=cm(time), label=mylabel)
for i, (varname, form) in enumerate(zip(col, formats)):
vis.plotLinesInTime(varname, ax=ax, colorbar_ticks=(i == 0), cmap=cmap, linestyle=form, colorbar_label=f'{varname} of {label} in time')

def animate_subsurface(vis, varnames, ax, label, colors=None):
if type(colors) is str:
Expand Down Expand Up @@ -255,7 +239,60 @@ def s_to_i(s):
else:
return int(sl)

def plotColumnData(directories,
layout,
color_mode='runs',
color_sample='enumerated',
color_map='jet',
subsurface_time_slice=None,
surface_time_slice=None,
data_filename_format='ats_vis_{}_data.h5',
mesh_filename_format='ats_vis_{}_mesh.h5',
time_unit='d',
figsize=[5,3]):
"""Plots the column data given directories and a layout."""

if isinstance(layout, str):
layout = valid_layout(layout)

if color_mode == 'runs':
if color_sample == 'enumerated':
color_list = colors.enumerated_colors(len(directories))
else:
color_list = colors.sampled_colors(len(directories),
getattr(matplotlib.cm, color_map))
elif color_mode == 'time':
color_list = [color_map,]*len(directories)


fig = plt.figure(figsize=figsize)
axs = fig.subplots(len(layout), len(layout[0]), squeeze=False)

domains = set([domain_var(v)[0] for v in layout_flattener(layout)])
for dirname, color in zip(directories, color_list):
vis_objs = dict()
for domain in domains:
vis = ats_xdmf.VisFile(dirname, domain,
ats_xdmf.valid_data_filename(domain, data_filename_format),
ats_xdmf.valid_mesh_filename(domain, mesh_filename_format),
time_unit=time_unit)
if domain == '':
vis.loadMesh(columnar=True)
if subsurface_time_slice is not None:
vis.filterIndices(subsurface_time_slice)
else:
vis.loadMesh()
if surface_time_slice is not None:
vis.filterIndices(surface_time_slice)

vis_objs[domain] = vis

plot(vis_objs, layout, axs, dirname, color, color_mode)

annotate(layout, axs, time_unit)
return axs


if __name__ == '__main__':
import argparse
import colors
Expand Down Expand Up @@ -293,42 +330,10 @@ def s_to_i(s):


args = parser.parse_args()
if args.color_mode == 'runs':
if args.color_sample == 'enumerated':
color_list = colors.enumerated_colors(len(args.directories))
else:
color_list = colors.sampled_colors(len(args.directories),
getattr(matplotlib.cm, args.color_map))
elif args.color_mode == 'time':
color_list = [args.color_map,]*len(args.directories)



fig = plt.figure(figsize=args.figsize)
axs = fig.subplots(len(args.layout), len(args.layout[0]), squeeze=False)

domains = set([domain_var(v)[0] for v in layout_flattener(args.layout)])
for dirname, color in zip(args.directories, color_list):
vis_objs = dict()
for domain in domains:
vis = ats_xdmf.VisFile(dirname, domain,
ats_xdmf.valid_data_filename(domain, args.data_filename_format),
ats_xdmf.valid_mesh_filename(domain, args.mesh_filename_format),
time_unit=args.time_unit)
if domain == '':
vis.loadMesh(columnar=True)
if args.subsurface_time_slice is not None:
vis.filterIndices(args.subsurface_time_slice)
else:
vis.loadMesh()
if args.surface_time_slice is not None:
vis.filterIndices(args.surface_time_slice)

vis_objs[domain] = vis

plot(vis_objs, args.layout, axs, dirname, color, args.color_mode)

annotate(args.layout, axs, args.time_unit)
plotColumnData(args.directories, args.layout, args.color_mode, args.color_sample, args.color_map,
args.subsurface_time_slice, args.surface_time_slice,
args.data_filename_format, args.mesh_filename_format,
args.time_unit, args.figsize)
plt.show()
sys.exit(0)

Expand Down

0 comments on commit 72c7b7e

Please sign in to comment.