'''
===========================
Simulation Output Utilities
===========================
'''
import os
from typing import Any, Dict, Optional
import random
import numpy as np
import matplotlib.pyplot as plt
from vivarium.core.emitter import path_timeseries_from_embedded_timeseries
from vivarium.library.dict_utils import get_value_from_path
def set_axes(
        ax,
        show_xaxis=False,
        sci_notation=False,
        y_offset=0.0):
    '''Set up plot axes.

    Hides the right and top spines and their ticks, configures the
    y-axis tick label format, and optionally hides the x axis.

    Args:
        ax: The axes to set up.
        show_xaxis: Whether to show the x axis. When ``False`` the
            bottom spine, bottom ticks, and bottom tick labels are
            hidden.
        sci_notation: Either a falsy value (``False``, ``0``) to not use
            scientific notation, ``True`` to use the default limit of
            4, or an integer :math:`x` such that scientific notation
            will be used outside the range :math:`[10^{-x}, 10^x]`.
        y_offset: Horizontal distance between axis offset text
            (typically for scientific notation) and the y-axis. The
            offset text's x position is set to ``-y_offset``.
    '''
    if sci_notation:
        scilimits = 4
        # bool is a subclass of int, so exclude it explicitly: a plain
        # ``True`` means "use the default limit", not "use a limit of 1".
        if (isinstance(sci_notation, int)
                and not isinstance(sci_notation, bool)):
            scilimits = sci_notation
        ax.ticklabel_format(
            style='sci',
            axis='y',
            scilimits=(-scilimits, scilimits),
            useOffset=True)
    else:
        ax.ticklabel_format(
            style='plain',
            axis='y')

    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.tick_params(right=False, top=False)

    # move offset axis text (typically scientific notation)
    offset_text = ax.yaxis.get_offset_text()
    offset_text.set_x(-y_offset)

    if not show_xaxis:
        ax.spines['bottom'].set_visible(False)
        ax.tick_params(bottom=False, labelbottom=False)
def _save_fig_to_dir(
        fig,
        filename,
        out_dir='out/',
):
    '''Save a figure into a directory, creating the directory if needed.

    Args:
        fig: Object with a ``savefig`` method (e.g. a matplotlib figure).
        filename: Name of the file to write inside ``out_dir``.
        out_dir: Directory to write to. Created (including parents) if
            it does not exist yet.
    '''
    # exist_ok=True makes repeated saves into the same directory safe.
    os.makedirs(out_dir, exist_ok=True)
    destination = os.path.join(out_dir, filename)
    # Announce the path before writing.
    print(f"Writing {destination}")
    fig.savefig(destination, bbox_inches='tight')
def plot_simulation_output(
        timeseries_raw,
        settings: Optional[Dict[str, Any]] = None,
        out_dir=None,
        filename='simulation',
):
    '''Plot simulation output, with rows organized into separate columns.

    Every state that survives filtering gets its own subplot. States
    are grouped by their top-level port, and each port fills one or
    more figure columns of at most ``max_rows`` subplots.

    Args:
        timeseries_raw (dict): Embedded timeseries. This can be obtained
            from simulation output with convert_to_timeseries(). Once
            flattened into a path timeseries it must provide a ``time``
            entry, which is used as the x values of every subplot.
        settings (dict): Accepts the following keys:

            * **column_width** (:py:class:`int`): the width (inches) of
              each column in the figure. Defaults to 3.
            * **max_rows** (:py:class:`int`): ports with more states
              than this number of states get wrapped into a new column.
              Defaults to 25.
            * **remove_zeros** (:py:class:`bool`): if True, timeseries
              with all zeros get removed. Defaults to True. Only
              consulted when ``remove_flat`` is falsy.
            * **remove_flat** (:py:class:`bool`): if True, timeseries
              with all the same value get removed. Defaults to False.
            * **remove_first_timestep** (:py:class:`bool`): if True,
              skips the first timestep. Defaults to False.
            * **skip_ports** (:py:class:`list`): entire ports that won't
              be plotted. Defaults to an empty list.
            * **show_state** (:py:class:`list`): with
              ``[('port_id', 'state_id')]`` for all states that will be
              highlighted, even if they are otherwise to be removed.
              NOTE: this key is not read by the current implementation.
              TODO: Obsolete?
        out_dir: If truthy, the subplot spacing is adjusted and the
            figure is saved into this directory.
        filename: Name of the saved file. Only used when ``out_dir`` is
            given.

    Returns:
        The figure.

    Raises:
        ValueError: If no timeseries is left to plot after filtering.
    '''
    int_or_float = (int, np.int32, np.int64, float, np.float32, np.float64)
    settings = settings or {}

    plot_fontsize = 8
    plt.rc('font', size=plot_fontsize)
    plt.rc('axes', titlesize=plot_fontsize)

    # get settings
    column_width = settings.get('column_width', 3)
    max_rows = settings.get('max_rows', 25)
    remove_zeros = settings.get('remove_zeros', True)
    remove_flat = settings.get('remove_flat', False)
    skip_ports = settings.get('skip_ports', [])
    remove_first_timestep = settings.get('remove_first_timestep', False)

    # make a flat 'path' timeseries, with keys being path
    top_level = list(timeseries_raw.keys())
    timeseries = path_timeseries_from_embedded_timeseries(timeseries_raw)
    time_vec = timeseries.pop('time')
    if remove_first_timestep:
        time_vec = time_vec[1:]

    # remove select states from timeseries; collect first and delete
    # afterwards so the dict is not mutated while iterating over it
    removed_states = set()
    for path, series in timeseries.items():
        if path[0] in skip_ports:
            removed_states.add(path)
        elif remove_flat:
            if series.count(series[0]) == len(series):
                removed_states.add(path)
        elif remove_zeros:
            if all(v == 0 for v in series):
                removed_states.add(path)
    for path in removed_states:
        del timeseries[path]

    # get figure columns
    # get length (number of remaining states) of each top-level port
    port_lengths = {}
    for path in timeseries:
        port = path[0]
        if port in top_level:
            port_lengths[port] = port_lengths.get(port, 0) + 1

    # each port fills as many full columns of max_rows as needed, plus
    # one shorter column for whatever is left over
    columns = []
    for n_states in port_lengths.values():
        full_columns, remainder = divmod(n_states, max_rows)
        columns.extend([max_rows] * full_columns)
        if remainder:
            columns.append(remainder)

    if not columns:
        # same exception type that max() on an empty list would raise,
        # but with a message that explains what went wrong
        raise ValueError(
            'no timeseries left to plot after applying the settings')

    # make figure and plot
    n_cols = len(columns)
    n_rows = max(columns)
    fig = plt.figure(
        figsize=(n_cols * column_width, n_rows * column_width / 3))
    grid = plt.GridSpec(n_rows, n_cols)
    row_idx = 0
    col_idx = 0
    for port in port_lengths:
        # get this port's timeseries, keyed by the path below the port
        port_timeseries = {}
        for path, ts in timeseries.items():
            # compare by equality: identical strings are not guaranteed
            # to be the same object
            if path[0] != port:
                continue
            # tuple path elements are reduced to their first item
            next_path = tuple(
                item[0] if isinstance(item, tuple) else item
                for item in path[1:])
            port_timeseries[next_path] = ts

        for state_id, series in sorted(port_timeseries.items()):
            if remove_first_timestep:
                series = series[1:]
            # not enough data points -- this state likely did not exist
            # throughout the entire simulation
            # NOTE: skipped states were still counted when the columns
            # were sized, so they leave unused grid cells behind.
            if len(series) != len(time_vec):
                continue

            # grid is (row, column)
            ax = fig.add_subplot(grid[row_idx, col_idx])

            if not all(isinstance(state, int_or_float) for state in series):
                # series is not a list of ints or floats: label only
                ax.title.set_text(
                    str(port) + ': ' + str(state_id) + ' (non numeric)')
            else:
                # plot line at zero if series touches or crosses the
                # zero line
                if any(x == 0.0 for x in series) or (
                        any(x < 0.0 for x in series)
                        and any(x > 0.0 for x in series)):
                    zero_line = [0] * len(time_vec)
                    ax.plot(time_vec, zero_line, 'k--')

                # plot the series
                ax.plot(time_vec, series)
                if isinstance(state_id, tuple):
                    # new line for each store
                    state_id = '\n'.join(state_id)
                ax.title.set_text(str(port) + '\n' + str(state_id))

            if row_idx == columns[col_idx] - 1:
                # last row of this column: show the x axis and move on
                # to the top of the next column
                set_axes(ax, True)
                ax.set_xlabel('time (s)')
                row_idx = 0
                col_idx += 1
            else:
                set_axes(ax)
                row_idx += 1
            ax.set_xlim([time_vec[0], time_vec[-1]])

    if out_dir:
        fig.subplots_adjust(wspace=column_width / 3, hspace=column_width / 3)
        _save_fig_to_dir(fig, filename, out_dir)
    return fig
def get_variable_title(path):
    '''Get figure title from a variable path.

    Args:
        path: Path to the variable, as a sequence of path elements. The
            last element is the variable itself; it may be a
            ``(name, units)`` tuple. A bare string is treated as a
            single variable name.

    Returns:
        String representation of the variable suitable for a figure
        title: the leading path elements joined by ``>``, followed by
        ``: `` and the variable (with units in parentheses if given).
        If there are no leading path elements, only the variable part
        is returned.
    '''
    # A bare string is one variable name; without this it would be
    # indexed and sliced character by character below.
    if isinstance(path, str):
        path = (path,)

    var = path[-1]
    separator = '>'
    # str() so that non-string path elements (e.g. integer ids) do not
    # break the join
    connect_path = separator.join(str(element) for element in path[:-1])

    if isinstance(var, tuple):  # if units are included in variable
        label = f'{var[0]} ({var[1]})'
    else:
        label = f'{var}'

    # avoid a dangling ': ' prefix for variables without a parent path
    if not connect_path:
        return label
    return f'{connect_path}: {label}'
# simple plotting function
def plot_variables(
        output,
        variables,
        column_width=8,
        row_height=1.2,
        row_padding=0.8,
        linewidth=3.0,
        sci_notation=False,
        default_color='tab:blue',
        out_dir=None,
        filename='variables'
):
    '''Create a simple figure with a timeseries for every variable.

    Args:
        output: Simulation output as a map from variable names or paths
            to timeseries data. Should contain a ``time`` key whose
            value is a list of time points.
        variables: The variables to plot, as a list. Each entry may be
            a variable name or path (if simulation output keys are just
            variable names, a plain string is enough) or a dictionary
            with keys ``variable`` (for the variable path), ``color``
            (for the color to use for the plot), and ``display`` (the
            variable name to display). If ``display`` is not provided,
            the result of calling :py:func:`get_variable_title` on the
            variable path is used. If ``color`` is not provided,
            ``default_color`` is used.
        column_width: Figure width.
        row_height: Height of each row. Each variable gets one row.
        row_padding: Space between rows.
        linewidth: Width of timeseries lines.
        sci_notation: Either ``False`` for no scientific notation or an
            integer :math:`x` such that scientific notation will be used
            for values outside the range :math:`[10^{-x}, 10^x]`.
        default_color: Default timeseries color.
        out_dir: Output directory. The figure is only saved when this
            is truthy.
        filename: Output filename.

    Returns:
        The figure.
    '''
    n_rows = len(variables)
    fig = plt.figure(figsize=(column_width, n_rows * row_height))
    grid = plt.GridSpec(n_rows, 1)
    time_vec = output['time']

    for row_idx, variable_definition in enumerate(variables):
        if isinstance(variable_definition, dict):
            path = variable_definition['variable']
            var_color = variable_definition.get('color', default_color)
        else:
            path = variable_definition
            var_color = default_color

        # A plain variable name is a one-element path. Without this, the
        # string would be handed to the path helpers as-is and treated
        # as a sequence of one-character path elements.
        if isinstance(path, str):
            path = (path,)

        if isinstance(variable_definition, dict):
            variable_title = variable_definition.get(
                'display', get_variable_title(path))
        else:
            variable_title = get_variable_title(path)

        # get the output timeseries
        series = get_value_from_path(output, path)

        # make a new subplot
        ax = fig.add_subplot(grid[row_idx, 0])
        ax.plot(time_vec, series, linewidth=linewidth, color=var_color)
        ax.set_title(variable_title)

        # x-axis only at bottom row
        if row_idx == n_rows - 1:
            set_axes(ax, show_xaxis=True, sci_notation=sci_notation)
            ax.set_xlabel('time (s)')
            # shift the bottom spine below the plot area
            ax.spines['bottom'].set_position(('axes', -0.2))
        else:
            set_axes(ax, sci_notation=sci_notation)

    fig.subplots_adjust(hspace=row_padding)
    if out_dir:
        _save_fig_to_dir(fig, filename, out_dir)
    return fig
if __name__ == '__main__':
    # Demo: plot two random series against an integer time axis and
    # save the figure under ``out``.
    n_steps = 500
    # All 'x' samples are drawn before any 'y' sample, so the sequence
    # of calls to the random generator is fixed.
    x_series = [random.uniform(0, 10) for _ in range(n_steps)]
    y_series = [np.exp(random.uniform(0, 10)) for _ in range(n_steps)]
    demo_output = {
        'x': x_series,
        'y': y_series,
        'time': list(range(n_steps)),
    }
    plot_variables(
        output=demo_output,
        variables=['x', 'y'],
        out_dir='out')