abstochkin.graphing

Graphing for AbStochKin simulations.

  1"""  Graphing for AbStochKin simulations. """
  2
  3#  Copyright (c) 2024-2025, Alex Plakantonakis.
  4#
  5#  This program is free software: you can redistribute it and/or modify
  6#  it under the terms of the GNU General Public License as published by
  7#  the Free Software Foundation, either version 3 of the License, or
  8#  (at your option) any later version.
  9#
 10#  This program is distributed in the hope that it will be useful,
 11#  but WITHOUT ANY WARRANTY; without even the implied warranty of
 12#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 13#  GNU General Public License for more details.
 14#
 15#  You should have received a copy of the GNU General Public License
 16#  along with this program.  If not, see <http://www.gnu.org/licenses/>.
 17
 18from pathlib import Path
 19from typing import Literal, Self
 20
 21import numpy as np
 22import matplotlib.pyplot as plt
 23import matplotlib.ticker as ticker
 24import plotly.graph_objects as go
 25
 26
 27class Graph:
 28    """
 29    Graphing class for displaying the results of AbStochKin simulations.
 30
 31    Notes
 32    -----
 33    To successfully use the LaTeX engine for rendering text on Linux,
 34    run the following command in a terminal: `sudo apt install cm-super`.
 35    """
 36
 37    # First, set some global matplotlib settings
 38    plt.rcParams['figure.autolayout'] = True  # tight layout
 39    plt.rcParams["figure.facecolor"] = 'lightgray'
 40    plt.rcParams['text.usetex'] = True
 41    plt.rcParams['axes.titlesize'] = 9
 42    plt.rcParams['axes.labelsize'] = 7
 43    plt.rcParams['xtick.labelsize'] = 6
 44    plt.rcParams['ytick.labelsize'] = 6
 45    plt.rcParams["legend.fontsize"] = 7
 46    plt.rcParams["legend.framealpha"] = 0.65
 47
 48    def __init__(self,
 49                 /,
 50                 nrows=1,
 51                 ncols=1,
 52                 figsize=(5, 5),
 53                 dpi=300,
 54                 *,
 55                 backend: Literal['matplotlib', 'plotly'] = 'matplotlib',
 56                 **kwargs):
 57        self.backend = backend
 58
 59        if self.backend == 'matplotlib':
 60            self.fig, self.ax = plt.subplots(nrows=nrows, ncols=ncols,
 61                                             figsize=figsize, dpi=dpi,
 62                                             **kwargs)
 63        elif self.backend == 'plotly':
 64            self.fig = go.Figure()
 65            self.fig.update_layout(
 66                width=figsize[0] * dpi * 0.6,  # Convert figsize to pixels
 67                height=figsize[1] * dpi * 0.6,  # Convert figsize to pixels
 68            )
 69            self.fig2 = None  # optional second figure
 70        else:
 71            raise ValueError(f"Unknown backend: {self.backend}. "
 72                             f"Please choose from 'matplotlib' (default), 'plotly'.")
 73
 74    def setup_spines_ticks(self, ax_loc):
 75        """
 76        Set up the spines and ticks in a `matplotlib` graph.
 77        Make only the left and bottom spines/axes visible on the graph
 78        and place major ticks on them. Also set the minor ticks.
 79        """
 80        if self.backend == 'matplotlib':
 81            axs = self.ax if len(ax_loc) == 0 else self.ax[ax_loc]
 82            axs.spines[['left']].set_position('zero')
 83            axs.spines[['top', 'right']].set_visible(False)
 84            axs.xaxis.set_ticks_position('bottom')
 85            axs.yaxis.set_ticks_position('left')
 86            axs.xaxis.set_minor_locator(ticker.AutoMinorLocator())
 87            axs.yaxis.set_minor_locator(ticker.AutoMinorLocator())
 88
 89        else:  # self.backend == 'plotly'
 90
 91            self.fig.update_layout(
 92                xaxis=dict(zeroline=True, showline=True, linewidth=2,
 93                           zerolinecolor="black", showticklabels=True,
 94                           tickmode='auto', ticks='outside'),
 95                yaxis=dict(zeroline=True, showline=True, linewidth=2,
 96                           zerolinecolor="black", showticklabels=True,
 97                           tickmode='auto', ticks='outside')
 98            )
 99
100    def plot_ODEs(self,
101                  de_data,
102                  *,
103                  num_pts: int = 1000,
104                  species: list[str] | tuple[str] = (),
105                  show_plot: bool = True,
106                  ax_loc: tuple = ()
107                  ) -> Self:
108        """
109        Plot the deterministic trajectories of all species obtained
110        by obtaining the solution to a system of ODEs.
111
112        Parameters
113        ----------
114        de_data : DEcalcs object
115                 Data structure containing all the data related to
116                 solving the system of ODEs.
117
118        num_pts : int, default: 1000, optional
119                 Number of points used to calculate DE curves at.
120                 Used to approximate a smooth/continuous curve.
121
122        species : sequence of strings, default: (), optional
123                 An iterable sequence of strings specifying the species
124                 names to plot. If no species are specified (the default),
125                 then all species trajectories are plotted.
126
127        show_plot : bool, default: True, optional
128                 If True, show the plot.
129
130        ax_loc : tuple, optional
131                If the figure is made up of subplots, specify the location
132                of the axis to draw the data at.
133                Ex: for two subplots, the possible values of `ax_loc`
134                are (0, ) and (1, ). That's because the `self.ax` object is
135                a 1-D array. For figures with multiple rows and columns of
136                subplots, a 2-D tuple is needed.
137        """
138        species = list(de_data.odes.keys()) if len(species) == 0 else species
139        # t, y = ode_sol.t, ode_sol.y.T  # values at precomputed time pts
140        t = np.linspace(de_data.odes_sol.t[0], de_data.odes_sol.t[-1],
141                        num_pts)  # time points for obtaining...
142        y = de_data.odes_sol.sol(t).T  # an approximately continuous solution
143
144        self.setup_spines_ticks(ax_loc)
145
146        if self.backend == 'matplotlib':
147            axs = self.ax if len(ax_loc) == 0 else self.ax[ax_loc]
148            axs.set(xlim=(0, de_data.odes_sol.t[-1]))
149
150            for i, sp in enumerate(list(de_data.odes.keys())):
151                if sp in species:
152                    axs.plot(t, y[:, i], label=f"${sp}_{{DE}}$",
153                             linestyle='--', linewidth=0.75, alpha=0.75)
154
155            # axs.set(title="Deterministic trajectories")
156            axs.set(xlabel=f"$t$ ({de_data.time_unit})", ylabel="$N$")
157            axs.legend(loc='upper right')
158            self.fig.tight_layout()
159
160            if show_plot:
161                plt.show()
162
163        else:  # self.backend == 'plotly'
164
165            for i, sp in enumerate(list(de_data.odes.keys())):
166                if sp in species:
167                    self.fig.add_trace(
168                        go.Scatter(
169                            x=t.tolist(),
170                            y=y[:, i].tolist(),
171                            mode='lines',
172                            name=f"${sp}_{{DE}}$",
173                            line=dict(dash='dash', width=1)
174                        )
175                    )
176
177            self.fig.update_layout(
178                xaxis=dict(title=f"$t \\, ({de_data.time_unit})$",
179                           range=[-0.01 * de_data.odes_sol.t[-1], de_data.odes_sol.t[-1]]),
180                yaxis=dict(title="$N$"),
181            )
182
183            if show_plot:
184                self.fig.show()
185
186        return self
187
188    def plot_trajectories(self,
189                          time,
190                          data,
191                          *,
192                          species: list[str] | tuple[str] = (),
193                          show_plot: bool = True,
194                          ax_loc: tuple = ()
195                          ) -> Self:
196        """ Graph simulation time trajectories. """
197        self.setup_spines_ticks(ax_loc)
198        species = list(data.keys()) if len(species) == 0 else species
199
200        if self.backend == 'matplotlib':
201            axs = self.ax if len(ax_loc) == 0 else self.ax[ax_loc]
202            axs.set(xlim=(0, time[-1]))
203
204            for sp, sp_data in data.items():
205                if sp in species:
206                    trajs = sp_data['N'].T
207                    for traj in trajs:
208                        axs.plot(time, traj, linewidth=0.25)
209
210            axs.set(title="ABK trajectories")
211            axs.set(xlabel=f"$t$ (sec)", ylabel="$N$")
212            # axs.legend(loc='best')
213
214            if show_plot:
215                plt.show()
216
217        else:  # self.backend == 'plotly'
218
219            for sp, sp_data in data.items():
220                if sp in species:
221                    trajs = sp_data['N'].T
222                    for i, traj in enumerate(trajs):
223                        self.fig.add_trace(
224                            go.Scatter(
225                                x=time.tolist(),
226                                y=traj.tolist(),
227                                mode='lines',
228                                name=f"${sp} \\; Run \\, {i}$",
229                                line=dict(width=0.5)
230                            )
231                        )
232            self.fig.update_layout(
233                xaxis=dict(title=f"$t \\; (sec)$",
234                           range=[-0.01 * time[-1], time[-1]]),
235                yaxis=dict(title="$N$")
236            )
237
238            if show_plot:
239                self.fig.show()
240
241        return self
242
243    def plot_avg_std(self,
244                     time,
245                     data,
246                     *,
247                     species: list[str] | tuple[str] = (),
248                     show_plot: bool = True,
249                     ax_loc: tuple = ()
250                     ) -> Self:
251        """
252        Graph simulation average trajectories and
253        1-standard-deviation envelopes.
254        """
255        self.setup_spines_ticks(ax_loc)
256        species = list(data.keys()) if len(species) == 0 else species
257
258        if self.backend == 'matplotlib':
259            axs = self.ax if len(ax_loc) == 0 else self.ax[ax_loc]
260            axs.set(xlim=(0, time[-1]))
261
262            for sp, sp_data in data.items():
263                if sp in species:
264                    axs.plot(time, sp_data['N_avg'],
265                             linewidth=1.5, label=f"$<{sp}>$", alpha=0.5)
266                    axs.fill_between(time,
267                                     sp_data['N_avg'] - sp_data['N_std'],
268                                     sp_data['N_avg'] + sp_data['N_std'],
269                                     alpha=0.5, linewidth=0)
270
271            axs.set(xlabel="$t$ (sec)", ylabel="$N$")
272            axs.legend(loc='upper right')
273            self.fig.tight_layout()
274
275            if show_plot:
276                plt.show()
277
278        else:  # self.backend == 'plotly'
279
280            for sp, sp_data in data.items():
281                if sp in species:
282                    self.fig.add_trace(
283                        go.Scatter(
284                            x=time.tolist(),
285                            y=sp_data['N_avg'].tolist(),
286                            mode='lines',
287                            name=f"$<{sp}>$",
288                            line=dict(width=2)
289                        )
290                    )
291
292                    self.fig.add_trace(
293                        go.Scatter(
294                            x=time.tolist(),
295                            y=(sp_data['N_avg'] + sp_data['N_std']).tolist(),
296                            mode='lines',
297                            name=f"$<{sp}> + \\sigma$",
298                            line=dict(width=0),
299                            showlegend=False
300                        )
301                    )
302
303                    self.fig.add_trace(
304                        go.Scatter(
305                            x=time.tolist(),
306                            y=(sp_data['N_avg'] - sp_data['N_std']).tolist(),
307                            mode='lines',
308                            line=dict(width=0),
309                            name=f"$<{sp}> - \\sigma$",
310                            fill='tonexty',
311                            showlegend=False
312                        )
313                    )
314
315            self.fig.update_layout(
316                xaxis=dict(title=f"$t \\; (sec)$",
317                           range=[-0.01 * time[-1], time[-1]]),
318                yaxis=dict(title="$N$")
319            )
320
321            if show_plot:
322                self.fig.show()
323
324        return self
325
326    def plot_eta(self,
327                 time,
328                 data,
329                 *,
330                 species: list[str] | tuple[str] = (),
331                 show_plot: bool = True,
332                 ax_loc: tuple = ()
333                 ) -> Self:
334        """ Graph the coefficient of variation. """
335        self.setup_spines_ticks(ax_loc)
336        species = list(data.keys()) if len(species) == 0 else species
337
338        if self.backend == "matplotlib":
339            axs = self.ax if len(ax_loc) == 0 else self.ax[ax_loc]
340            axs.set(xlim=(0, time[-1]))
341
342            for sp, sp_data in data.items():
343                if sp in species:
344                    axs.plot(time, sp_data['eta'], linewidth=1.5, label=f"${sp}$")
345                    axs.plot(time, sp_data['eta_p'], linewidth=1, linestyle='--',
346                             label=f"${sp}_{{Poisson}}$", color=(0.5, 0.5, 0.5))
347
348            axs.set(title="Coefficient of Variation, $\\eta$")
349            axs.set(xlabel=f"$t$ (sec)", ylabel="$\\eta$")
350            axs.legend(loc='upper right')
351
352            if show_plot:
353                plt.show()
354
355        else:  # self.backend == 'plotly'
356
357            for sp, sp_data in data.items():
358                if sp in species:
359                    self.fig.add_trace(
360                        go.Scatter(
361                            x=time.tolist(),
362                            y=sp_data['eta'].tolist(),
363                            mode='lines',
364                            name=f"${sp}$",
365                            line=dict(width=2)
366                        ))
367
368                    self.fig.add_trace(
369                        go.Scatter(
370                            x=time.tolist(),
371                            y=sp_data['eta_p'].tolist(),
372                            mode='lines',
373                            name=f"${sp}_{{Poisson}}$",
374                            line=dict(width=2, dash="dash")
375                        ))
376
377            self.fig.update_layout(
378                title="Coefficient of Variation",
379                xaxis=dict(title=f"$t \\; (sec)$",
380                           range=[-0.01 * time[-1], time[-1]]),
381                yaxis=dict(title="$\\eta$")
382            )
383
384            if show_plot:
385                self.fig.show()
386
387        return self
388
389    def plot_het_metrics(self,
390                         time,
391                         proc_str: tuple[str, str],
392                         proc_data: dict,
393                         *,
394                         het_attr='k',
395                         show_plot: bool = True,
396                         ax_loc: tuple = ()
397                         ) -> Self:
398        """
399        Graph species- and process-specific metrics of population heterogeneity.
400        """
401        self.setup_spines_ticks(ax_loc)
402        title = f"${proc_str[0].split(';')[0].replace(' ,', chr(92) + 'hspace{10pt} ,').replace('->', chr(92) + 'rightarrow')}$"
403
404        if self.backend == 'matplotlib':
405            axs = self.ax if len(ax_loc) == 0 else self.ax[ax_loc]
406            axs.set(xlim=(0, time[-1]))
407            axs.set(xlim=(0, time[-1]),
408                    ylim=(0, 1.5 * np.max(proc_data[f"<{het_attr}_avg>"] + proc_data[
409                        f"<{het_attr}_std>"])))
410
411            axs.plot(time, proc_data[f'<{het_attr}_avg>'],
412                     linewidth=1.5, label=f"$<{het_attr}>$", alpha=0.5)
413            axs.fill_between(time,
414                             proc_data[f"<{het_attr}_avg>"] - proc_data[f"<{het_attr}_std>"],
415                             proc_data[f"<{het_attr}_avg>"] + proc_data[f"<{het_attr}_std>"],
416                             alpha=0.5, linewidth=0)
417            axs.tick_params(axis='y', labelcolor='blue')
418
419            axs.set(title=title + (f"$, {proc_str[1]}$" if proc_str[1] != "" else ""),
420                    xlabel=f"$t$ (sec)")
421
422            if het_attr == 'Km':
423                axs.set_ylabel("$K_m$", color='blue')
424            elif het_attr == 'K50':
425                axs.set_ylabel("$K_{50}$", color='blue')
426            else:
427                axs.set_ylabel(f"${het_attr}$", color='blue')
428
429            # Second y-axis
430            axs2 = axs.twinx()
431            axs2.spines[['bottom']].set_position('zero')  # x axis
432            axs2.spines[['right']].set_position(('axes', 1))  # y axis
433            axs2.spines[['top', 'left', 'bottom']].set_visible(False)
434            # axs2.spines['right'].set_color('red')
435            axs2.set(ylim=(0, 1))
436            axs2.tick_params(axis='y', labelcolor='red')
437            axs2.yaxis.set_ticks_position('right')
438            axs2.set_yticks([i for i in np.arange(0, 1.1, 0.1)])
439            axs2.yaxis.set_minor_locator(ticker.AutoMinorLocator())
440            axs2.grid(which='major', axis='y', color='r',
441                      linestyle='--', linewidth=0.25, alpha=0.25)
442
443            axs2.plot(time, proc_data['psi_avg'],
444                      linewidth=1.5, label='$<\\psi>$', color='red', alpha=0.5)
445            axs2.fill_between(time,
446                              proc_data['psi_avg'] - proc_data['psi_std'],
447                              proc_data['psi_avg'] + proc_data['psi_std'],
448                              color='red', alpha=0.5, linewidth=0)
449            axs2.set_ylabel("$\\psi$", color='red')
450
451            if show_plot:
452                plt.show()
453
454            return self
455
456        else:  # self.backend == 'plotly'
457
458            self.fig.add_trace(
459                go.Scatter(
460                    x=time.tolist(),
461                    y=proc_data[f'<{het_attr}_avg>'].tolist(),
462                    mode='lines',
463                    name=f"$<{het_attr}>$",
464                    line=dict(width=2, color='blue'),
465                )
466            )
467
468            self.fig.add_trace(
469                go.Scatter(
470                    x=time.tolist(),
471                    y=(proc_data[f"<{het_attr}_avg>"] + proc_data[f"<{het_attr}_std>"]).tolist(),
472                    mode='lines',
473                    name=f"$<{het_attr}> + \\sigma$",
474                    line=dict(width=0),
475                    showlegend=False
476                )
477            )
478
479            self.fig.add_trace(
480                go.Scatter(
481                    x=time.tolist(),
482                    y=(proc_data[f"<{het_attr}_avg>"] - proc_data[f"<{het_attr}_std>"]).tolist(),
483                    mode='lines',
484                    name=f"$<{het_attr}> - \\sigma$",
485                    line=dict(width=0),
486                    fill='tonexty',
487                    showlegend=False
488                )
489            )
490
491            self.fig.update_layout(
492                title=title + (f"$, {proc_str[1]}$" if proc_str[1] != "" else ""),
493                xaxis=dict(title=f"$t \\; (sec)$",
494                           range=[-0.01 * time[-1], time[-1]]),
495                yaxis=dict(title="$K_m$" if het_attr == "Km" else "$K_{50}$" if het_attr == "K50" else f"${het_attr}$",
496                           color="blue",
497                           range=[0, 1.5 * np.max(proc_data[f"<{het_attr}_avg>"] + proc_data[f"<{het_attr}_std>"])])
498            )
499
500            if show_plot:
501                self.fig.show()
502
503            # Plot psi on a separate figure
504            self.fig2 = go.Figure()
505            self.fig2.update_layout(
506                width=self.fig.layout.width,
507                height=self.fig.layout.height,
508                xaxis=dict(zeroline=True, showline=True, linewidth=2,
509                           zerolinecolor="black", showticklabels=True,
510                           tickmode='auto', ticks='outside'),
511                yaxis=dict(zeroline=True, showline=True, linewidth=2,
512                           zerolinecolor="black", showticklabels=True,
513                           tickmode='auto', ticks='outside')
514            )
515
516            self.fig2.add_trace(
517                go.Scatter(
518                    x=time.tolist(),
519                    y=proc_data['psi_avg'],
520                    mode='lines',
521                    name=f"$<\\psi>$",
522                    line=dict(width=2, color='red'),
523                )
524            )
525
526            self.fig2.add_trace(
527                go.Scatter(
528                    x=time.tolist(),
529                    y=proc_data['psi_avg'] + proc_data['psi_std'],
530                    mode='lines',
531                    name=f"$<\\psi> + \\sigma$",
532                    line=dict(width=0),
533                    showlegend=False
534                )
535            )
536
537            self.fig2.add_trace(
538                go.Scatter(
539                    x=time.tolist(),
540                    y=proc_data['psi_avg'] - proc_data['psi_std'],
541                    mode='lines',
542                    name=f"$<\\psi> + \\sigma$",
543                    line=dict(width=0),
544                    fill='tonexty',
545                    showlegend=False
546                )
547            )
548
549            self.fig2.update_layout(
550                title=title + (f"$, {proc_str[1]}$" if proc_str[1] != "" else ""),
551                xaxis=dict(title=f"$t \\; (sec)$",
552                           range=[-0.01 * time[-1], time[-1]]),
553                yaxis=dict(title="$\\psi$",
554                           color="red",
555                           range=[0, 1],
556                           tick0=0,
557                           dtick=0.1)
558            )
559
560            if show_plot:
561                self.fig2.show()
562
563            return self
564
565    def savefig(self, filename: str, image_format: str = 'svg', **kwargs):
566        """ Save the figure as a file. """
567        graph_path = Path('.') / 'plots_output'
568        graph_path.mkdir(exist_ok=True)
569        graph_path_svg = graph_path / f"{filename}.{image_format}"
570        if self.backend == 'matplotlib':
571            self.fig.savefig(graph_path_svg, format=image_format, **kwargs)
572            plt.close(self.fig)
573        else:  # self.backend == 'plotly'
574            self.fig.write_image(graph_path_svg, format=image_format, **kwargs)
class Graph:
 28class Graph:
 29    """
 30    Graphing class for displaying the results of AbStochKin simulations.
 31
 32    Notes
 33    -----
 34    To successfully use the LaTeX engine for rendering text on Linux,
 35    run the following command in a terminal: `sudo apt install cm-super`.
 36    """
 37
 38    # First, set some global matplotlib settings
 39    plt.rcParams['figure.autolayout'] = True  # tight layout
 40    plt.rcParams["figure.facecolor"] = 'lightgray'
 41    plt.rcParams['text.usetex'] = True
 42    plt.rcParams['axes.titlesize'] = 9
 43    plt.rcParams['axes.labelsize'] = 7
 44    plt.rcParams['xtick.labelsize'] = 6
 45    plt.rcParams['ytick.labelsize'] = 6
 46    plt.rcParams["legend.fontsize"] = 7
 47    plt.rcParams["legend.framealpha"] = 0.65
 48
 49    def __init__(self,
 50                 /,
 51                 nrows=1,
 52                 ncols=1,
 53                 figsize=(5, 5),
 54                 dpi=300,
 55                 *,
 56                 backend: Literal['matplotlib', 'plotly'] = 'matplotlib',
 57                 **kwargs):
 58        self.backend = backend
 59
 60        if self.backend == 'matplotlib':
 61            self.fig, self.ax = plt.subplots(nrows=nrows, ncols=ncols,
 62                                             figsize=figsize, dpi=dpi,
 63                                             **kwargs)
 64        elif self.backend == 'plotly':
 65            self.fig = go.Figure()
 66            self.fig.update_layout(
 67                width=figsize[0] * dpi * 0.6,  # Convert figsize to pixels
 68                height=figsize[1] * dpi * 0.6,  # Convert figsize to pixels
 69            )
 70            self.fig2 = None  # optional second figure
 71        else:
 72            raise ValueError(f"Unknown backend: {self.backend}. "
 73                             f"Please choose from 'matplotlib' (default), 'plotly'.")
 74
 75    def setup_spines_ticks(self, ax_loc):
 76        """
 77        Set up the spines and ticks in a `matplotlib` graph.
 78        Make only the left and bottom spines/axes visible on the graph
 79        and place major ticks on them. Also set the minor ticks.
 80        """
 81        if self.backend == 'matplotlib':
 82            axs = self.ax if len(ax_loc) == 0 else self.ax[ax_loc]
 83            axs.spines[['left']].set_position('zero')
 84            axs.spines[['top', 'right']].set_visible(False)
 85            axs.xaxis.set_ticks_position('bottom')
 86            axs.yaxis.set_ticks_position('left')
 87            axs.xaxis.set_minor_locator(ticker.AutoMinorLocator())
 88            axs.yaxis.set_minor_locator(ticker.AutoMinorLocator())
 89
 90        else:  # self.backend == 'plotly'
 91
 92            self.fig.update_layout(
 93                xaxis=dict(zeroline=True, showline=True, linewidth=2,
 94                           zerolinecolor="black", showticklabels=True,
 95                           tickmode='auto', ticks='outside'),
 96                yaxis=dict(zeroline=True, showline=True, linewidth=2,
 97                           zerolinecolor="black", showticklabels=True,
 98                           tickmode='auto', ticks='outside')
 99            )
100
101    def plot_ODEs(self,
102                  de_data,
103                  *,
104                  num_pts: int = 1000,
105                  species: list[str] | tuple[str] = (),
106                  show_plot: bool = True,
107                  ax_loc: tuple = ()
108                  ) -> Self:
109        """
110        Plot the deterministic trajectories of all species obtained
111        by obtaining the solution to a system of ODEs.
112
113        Parameters
114        ----------
115        de_data : DEcalcs object
116                 Data structure containing all the data related to
117                 solving the system of ODEs.
118
119        num_pts : int, default: 1000, optional
120                 Number of points used to calculate DE curves at.
121                 Used to approximate a smooth/continuous curve.
122
123        species : sequence of strings, default: (), optional
124                 An iterable sequence of strings specifying the species
125                 names to plot. If no species are specified (the default),
126                 then all species trajectories are plotted.
127
128        show_plot : bool, default: True, optional
129                 If True, show the plot.
130
131        ax_loc : tuple, optional
132                If the figure is made up of subplots, specify the location
133                of the axis to draw the data at.
134                Ex: for two subplots, the possible values of `ax_loc`
135                are (0, ) and (1, ). That's because the `self.ax` object is
136                a 1-D array. For figures with multiple rows and columns of
137                subplots, a 2-D tuple is needed.
138        """
139        species = list(de_data.odes.keys()) if len(species) == 0 else species
140        # t, y = ode_sol.t, ode_sol.y.T  # values at precomputed time pts
141        t = np.linspace(de_data.odes_sol.t[0], de_data.odes_sol.t[-1],
142                        num_pts)  # time points for obtaining...
143        y = de_data.odes_sol.sol(t).T  # an approximately continuous solution
144
145        self.setup_spines_ticks(ax_loc)
146
147        if self.backend == 'matplotlib':
148            axs = self.ax if len(ax_loc) == 0 else self.ax[ax_loc]
149            axs.set(xlim=(0, de_data.odes_sol.t[-1]))
150
151            for i, sp in enumerate(list(de_data.odes.keys())):
152                if sp in species:
153                    axs.plot(t, y[:, i], label=f"${sp}_{{DE}}$",
154                             linestyle='--', linewidth=0.75, alpha=0.75)
155
156            # axs.set(title="Deterministic trajectories")
157            axs.set(xlabel=f"$t$ ({de_data.time_unit})", ylabel="$N$")
158            axs.legend(loc='upper right')
159            self.fig.tight_layout()
160
161            if show_plot:
162                plt.show()
163
164        else:  # self.backend == 'plotly'
165
166            for i, sp in enumerate(list(de_data.odes.keys())):
167                if sp in species:
168                    self.fig.add_trace(
169                        go.Scatter(
170                            x=t.tolist(),
171                            y=y[:, i].tolist(),
172                            mode='lines',
173                            name=f"${sp}_{{DE}}$",
174                            line=dict(dash='dash', width=1)
175                        )
176                    )
177
178            self.fig.update_layout(
179                xaxis=dict(title=f"$t \\, ({de_data.time_unit})$",
180                           range=[-0.01 * de_data.odes_sol.t[-1], de_data.odes_sol.t[-1]]),
181                yaxis=dict(title="$N$"),
182            )
183
184            if show_plot:
185                self.fig.show()
186
187        return self
188
189    def plot_trajectories(self,
190                          time,
191                          data,
192                          *,
193                          species: list[str] | tuple[str] = (),
194                          show_plot: bool = True,
195                          ax_loc: tuple = ()
196                          ) -> Self:
197        """ Graph simulation time trajectories. """
198        self.setup_spines_ticks(ax_loc)
199        species = list(data.keys()) if len(species) == 0 else species
200
201        if self.backend == 'matplotlib':
202            axs = self.ax if len(ax_loc) == 0 else self.ax[ax_loc]
203            axs.set(xlim=(0, time[-1]))
204
205            for sp, sp_data in data.items():
206                if sp in species:
207                    trajs = sp_data['N'].T
208                    for traj in trajs:
209                        axs.plot(time, traj, linewidth=0.25)
210
211            axs.set(title="ABK trajectories")
212            axs.set(xlabel=f"$t$ (sec)", ylabel="$N$")
213            # axs.legend(loc='best')
214
215            if show_plot:
216                plt.show()
217
218        else:  # self.backend == 'plotly'
219
220            for sp, sp_data in data.items():
221                if sp in species:
222                    trajs = sp_data['N'].T
223                    for i, traj in enumerate(trajs):
224                        self.fig.add_trace(
225                            go.Scatter(
226                                x=time.tolist(),
227                                y=traj.tolist(),
228                                mode='lines',
229                                name=f"${sp} \\; Run \\, {i}$",
230                                line=dict(width=0.5)
231                            )
232                        )
233            self.fig.update_layout(
234                xaxis=dict(title=f"$t \\; (sec)$",
235                           range=[-0.01 * time[-1], time[-1]]),
236                yaxis=dict(title="$N$")
237            )
238
239            if show_plot:
240                self.fig.show()
241
242        return self
243
244    def plot_avg_std(self,
245                     time,
246                     data,
247                     *,
248                     species: list[str] | tuple[str] = (),
249                     show_plot: bool = True,
250                     ax_loc: tuple = ()
251                     ) -> Self:
252        """
253        Graph simulation average trajectories and
254        1-standard-deviation envelopes.
255        """
256        self.setup_spines_ticks(ax_loc)
257        species = list(data.keys()) if len(species) == 0 else species
258
259        if self.backend == 'matplotlib':
260            axs = self.ax if len(ax_loc) == 0 else self.ax[ax_loc]
261            axs.set(xlim=(0, time[-1]))
262
263            for sp, sp_data in data.items():
264                if sp in species:
265                    axs.plot(time, sp_data['N_avg'],
266                             linewidth=1.5, label=f"$<{sp}>$", alpha=0.5)
267                    axs.fill_between(time,
268                                     sp_data['N_avg'] - sp_data['N_std'],
269                                     sp_data['N_avg'] + sp_data['N_std'],
270                                     alpha=0.5, linewidth=0)
271
272            axs.set(xlabel="$t$ (sec)", ylabel="$N$")
273            axs.legend(loc='upper right')
274            self.fig.tight_layout()
275
276            if show_plot:
277                plt.show()
278
279        else:  # self.backend == 'plotly'
280
281            for sp, sp_data in data.items():
282                if sp in species:
283                    self.fig.add_trace(
284                        go.Scatter(
285                            x=time.tolist(),
286                            y=sp_data['N_avg'].tolist(),
287                            mode='lines',
288                            name=f"$<{sp}>$",
289                            line=dict(width=2)
290                        )
291                    )
292
293                    self.fig.add_trace(
294                        go.Scatter(
295                            x=time.tolist(),
296                            y=(sp_data['N_avg'] + sp_data['N_std']).tolist(),
297                            mode='lines',
298                            name=f"$<{sp}> + \\sigma$",
299                            line=dict(width=0),
300                            showlegend=False
301                        )
302                    )
303
304                    self.fig.add_trace(
305                        go.Scatter(
306                            x=time.tolist(),
307                            y=(sp_data['N_avg'] - sp_data['N_std']).tolist(),
308                            mode='lines',
309                            line=dict(width=0),
310                            name=f"$<{sp}> - \\sigma$",
311                            fill='tonexty',
312                            showlegend=False
313                        )
314                    )
315
316            self.fig.update_layout(
317                xaxis=dict(title=f"$t \\; (sec)$",
318                           range=[-0.01 * time[-1], time[-1]]),
319                yaxis=dict(title="$N$")
320            )
321
322            if show_plot:
323                self.fig.show()
324
325        return self
326
327    def plot_eta(self,
328                 time,
329                 data,
330                 *,
331                 species: list[str] | tuple[str] = (),
332                 show_plot: bool = True,
333                 ax_loc: tuple = ()
334                 ) -> Self:
335        """ Graph the coefficient of variation. """
336        self.setup_spines_ticks(ax_loc)
337        species = list(data.keys()) if len(species) == 0 else species
338
339        if self.backend == "matplotlib":
340            axs = self.ax if len(ax_loc) == 0 else self.ax[ax_loc]
341            axs.set(xlim=(0, time[-1]))
342
343            for sp, sp_data in data.items():
344                if sp in species:
345                    axs.plot(time, sp_data['eta'], linewidth=1.5, label=f"${sp}$")
346                    axs.plot(time, sp_data['eta_p'], linewidth=1, linestyle='--',
347                             label=f"${sp}_{{Poisson}}$", color=(0.5, 0.5, 0.5))
348
349            axs.set(title="Coefficient of Variation, $\\eta$")
350            axs.set(xlabel=f"$t$ (sec)", ylabel="$\\eta$")
351            axs.legend(loc='upper right')
352
353            if show_plot:
354                plt.show()
355
356        else:  # self.backend == 'plotly'
357
358            for sp, sp_data in data.items():
359                if sp in species:
360                    self.fig.add_trace(
361                        go.Scatter(
362                            x=time.tolist(),
363                            y=sp_data['eta'].tolist(),
364                            mode='lines',
365                            name=f"${sp}$",
366                            line=dict(width=2)
367                        ))
368
369                    self.fig.add_trace(
370                        go.Scatter(
371                            x=time.tolist(),
372                            y=sp_data['eta_p'].tolist(),
373                            mode='lines',
374                            name=f"${sp}_{{Poisson}}$",
375                            line=dict(width=2, dash="dash")
376                        ))
377
378            self.fig.update_layout(
379                title="Coefficient of Variation",
380                xaxis=dict(title=f"$t \\; (sec)$",
381                           range=[-0.01 * time[-1], time[-1]]),
382                yaxis=dict(title="$\\eta$")
383            )
384
385            if show_plot:
386                self.fig.show()
387
388        return self
389
def plot_het_metrics(self,
                     time,
                     proc_str: tuple[str, str],
                     proc_data: dict,
                     *,
                     het_attr='k',
                     show_plot: bool = True,
                     ax_loc: tuple = ()
                     ) -> Self:
    """
    Graph species- and process-specific metrics of population heterogeneity.

    The mean of the attribute `het_attr` and its 1-standard-deviation
    envelope are plotted together with the mean and envelope of psi.
    With the matplotlib backend, psi goes on a secondary y-axis (range
    0 to 1) of the same axes. With the plotly backend, the attribute is
    drawn on ``self.fig`` and psi on a new figure stored in ``self.fig2``.

    Parameters
    ----------
    time : numpy array
        Time points; the last value sets the extent of the time axis.
    proc_str : tuple of two strings
        Process description used for the title: the part of the first
        string before any ';' is rendered in math mode (``->`` becomes an
        arrow); the second string, if non-empty, is appended to the title.
    proc_data : dict
        Must provide ``f"<{het_attr}_avg>"``, ``f"<{het_attr}_std>"``,
        ``'psi_avg'`` and ``'psi_std'``, each aligned with `time`.
    het_attr : str, default: 'k', optional
        Name of the heterogeneous attribute; ``'Km'`` and ``'K50'`` get
        subscripted axis labels.
    show_plot : bool, default: True, optional
        If True, show the plot(s).
    ax_loc : tuple, optional
        Location of the subplot axis to draw on when the figure has
        several subplots (matplotlib backend only).

    Returns
    -------
    Self
        This `Graph` object, so that calls can be chained.
    """
    self.setup_spines_ticks(ax_loc)
    title = f"${proc_str[0].split(';')[0].replace(' ,', chr(92) + 'hspace{10pt} ,').replace('->', chr(92) + 'rightarrow')}$"
    full_title = title + (f"$, {proc_str[1]}$" if proc_str[1] != "" else "")

    # Look the data up once; both backends use the same arrays.
    attr_avg = np.asarray(proc_data[f"<{het_attr}_avg>"])
    attr_std = np.asarray(proc_data[f"<{het_attr}_std>"])
    psi_avg = np.asarray(proc_data['psi_avg'])
    psi_std = np.asarray(proc_data['psi_std'])

    # Leave 50% headroom above the top of the attribute's envelope.
    attr_ymax = 1.5 * np.max(attr_avg + attr_std)
    attr_label = {'Km': "$K_m$", 'K50': "$K_{50}$"}.get(het_attr, f"${het_attr}$")

    if self.backend == 'matplotlib':
        axs = self.ax if len(ax_loc) == 0 else self.ax[ax_loc]
        axs.set(xlim=(0, time[-1]), ylim=(0, attr_ymax))

        attr_line = axs.plot(time, attr_avg, linewidth=1.5,
                             label=f"$<{het_attr}>$", alpha=0.5)[0]
        # Give the envelope its mean curve's color; without an explicit
        # color, fill_between takes the next color of the axes' patch cycle.
        axs.fill_between(time, attr_avg - attr_std, attr_avg + attr_std,
                         color=attr_line.get_color(), alpha=0.5, linewidth=0)
        axs.tick_params(axis='y', labelcolor='blue')
        axs.set(title=full_title, xlabel="$t$ (sec)")
        axs.set_ylabel(attr_label, color='blue')

        # Secondary y-axis (right side, fixed range 0-1) for psi.
        axs2 = axs.twinx()
        axs2.spines[['bottom']].set_position('zero')  # x axis
        axs2.spines[['right']].set_position(('axes', 1))  # y axis
        axs2.spines[['top', 'left', 'bottom']].set_visible(False)
        axs2.set(ylim=(0, 1))
        axs2.tick_params(axis='y', labelcolor='red')
        axs2.yaxis.set_ticks_position('right')
        axs2.set_yticks(list(np.arange(0, 1.1, 0.1)))
        axs2.yaxis.set_minor_locator(ticker.AutoMinorLocator())
        axs2.grid(which='major', axis='y', color='r',
                  linestyle='--', linewidth=0.25, alpha=0.25)

        axs2.plot(time, psi_avg,
                  linewidth=1.5, label='$<\\psi>$', color='red', alpha=0.5)
        axs2.fill_between(time, psi_avg - psi_std, psi_avg + psi_std,
                          color='red', alpha=0.5, linewidth=0)
        axs2.set_ylabel("$\\psi$", color='red')

        if show_plot:
            plt.show()

        return self

    # plotly backend: the attribute goes on self.fig, psi on self.fig2.
    time_list = time.tolist()
    axis_style = dict(zeroline=True, showline=True, linewidth=2,
                      zerolinecolor="black", showticklabels=True,
                      tickmode='auto', ticks='outside')

    def time_axis():
        """Return a fresh x-axis layout spec spanning the simulated times."""
        return dict(title="$t \\; (sec)$", range=[-0.01 * time[-1], time[-1]])

    def add_mean_and_band(fig, avg, std, symbol, line_color, fill_color):
        """Add the mean of `symbol` and its filled +/- 1 std band to `fig`."""
        fig.add_trace(
            go.Scatter(
                x=time_list,
                y=avg.tolist(),
                mode='lines',
                name=f"$<{symbol}>$",
                line=dict(width=2, color=line_color),
            )
        )
        # Invisible upper bound; the lower bound below fills up to it.
        fig.add_trace(
            go.Scatter(
                x=time_list,
                y=(avg + std).tolist(),
                mode='lines',
                name=f"$<{symbol}> + \\sigma$",
                line=dict(width=0),
                showlegend=False,
            )
        )
        # Explicit fillcolor so the band matches its mean curve instead of
        # taking the auto-assigned color of this (invisible) trace.
        fig.add_trace(
            go.Scatter(
                x=time_list,
                y=(avg - std).tolist(),
                mode='lines',
                name=f"$<{symbol}> - \\sigma$",
                line=dict(width=0),
                fill='tonexty',
                fillcolor=fill_color,
                showlegend=False,
            )
        )

    add_mean_and_band(self.fig, attr_avg, attr_std, het_attr,
                      'blue', 'rgba(0, 0, 255, 0.5)')
    self.fig.update_layout(
        title=full_title,
        xaxis=time_axis(),
        yaxis=dict(title=attr_label, color="blue", range=[0, attr_ymax]),
    )

    if show_plot:
        self.fig.show()

    # Plot psi on a separate figure with the same size and axis styling.
    self.fig2 = go.Figure()
    self.fig2.update_layout(
        width=self.fig.layout.width,
        height=self.fig.layout.height,
        xaxis=dict(axis_style),
        yaxis=dict(axis_style),
    )
    add_mean_and_band(self.fig2, psi_avg, psi_std, '\\psi',
                      'red', 'rgba(255, 0, 0, 0.5)')
    self.fig2.update_layout(
        title=full_title,
        xaxis=time_axis(),
        yaxis=dict(title="$\\psi$", color="red", range=[0, 1],
                   tick0=0, dtick=0.1),
    )

    if show_plot:
        self.fig2.show()

    return self
565
def savefig(self, filename: str, image_format: str = 'svg', **kwargs):
    """
    Save the figure as a file in the ``plots_output`` directory.

    The directory is created in the current working directory if it does
    not exist, and the file is named ``f"{filename}.{image_format}"``.
    With the matplotlib backend, the figure is closed after saving. With
    the plotly backend, the separate psi figure stored in ``self.fig2``
    (if any) is saved as well, as ``f"{filename}_psi.{image_format}"``.

    Parameters
    ----------
    filename : str
        Base name of the output file, without extension.
    image_format : str, default: 'svg', optional
        Image format; also used as the file extension.
    **kwargs
        Passed on to ``self.fig.savefig`` (matplotlib) or
        ``write_image`` (plotly).
    """
    graph_dir = Path('.') / 'plots_output'
    graph_dir.mkdir(exist_ok=True)
    graph_file = graph_dir / f"{filename}.{image_format}"
    if self.backend == 'matplotlib':
        self.fig.savefig(graph_file, format=image_format, **kwargs)
        plt.close(self.fig)
    else:  # self.backend == 'plotly'
        self.fig.write_image(graph_file, format=image_format, **kwargs)
        # With plotly, psi is drawn on its own figure (self.fig2) by
        # plot_het_metrics; save it too instead of silently dropping it.
        psi_fig = getattr(self, 'fig2', None)
        if psi_fig is not None:
            psi_fig.write_image(graph_dir / f"{filename}_psi.{image_format}",
                                format=image_format, **kwargs)

Graphing class for displaying the results of AbStochKin simulations.

Notes

The class turns on matplotlib's LaTeX text rendering (`text.usetex`). To use the LaTeX engine successfully on Linux, run the following command in a terminal: `sudo apt install cm-super`.

Graph( nrows=1, ncols=1, figsize=(5, 5), dpi=300, *, backend: Literal['matplotlib', 'plotly'] = 'matplotlib', **kwargs)
49    def __init__(self,
50                 /,
51                 nrows=1,
52                 ncols=1,
53                 figsize=(5, 5),
54                 dpi=300,
55                 *,
56                 backend: Literal['matplotlib', 'plotly'] = 'matplotlib',
57                 **kwargs):
58        self.backend = backend
59
60        if self.backend == 'matplotlib':
61            self.fig, self.ax = plt.subplots(nrows=nrows, ncols=ncols,
62                                             figsize=figsize, dpi=dpi,
63                                             **kwargs)
64        elif self.backend == 'plotly':
65            self.fig = go.Figure()
66            self.fig.update_layout(
67                width=figsize[0] * dpi * 0.6,  # Convert figsize to pixels
68                height=figsize[1] * dpi * 0.6,  # Convert figsize to pixels
69            )
70            self.fig2 = None  # optional second figure
71        else:
72            raise ValueError(f"Unknown backend: {self.backend}. "
73                             f"Please choose from 'matplotlib' (default), 'plotly'.")
backend
def setup_spines_ticks(self, ax_loc):
def setup_spines_ticks(self, ax_loc):
    """
    Style the axes of the graph.

    With the `matplotlib` backend, only the left and bottom spines of the
    selected axis remain visible (the left spine is placed at x = 0),
    major ticks are put on them and automatic minor ticks are added.
    With the `plotly` backend, both axes of ``self.fig`` get zero lines,
    axis lines and outward ticks.

    Parameters
    ----------
    ax_loc : tuple
        Location of the subplot axis to style; an empty tuple selects
        ``self.ax`` itself. Ignored by the plotly backend.
    """
    if self.backend != 'matplotlib':
        # plotly: the same styling is applied to both axes.
        axis_style = {
            'zeroline': True,
            'showline': True,
            'linewidth': 2,
            'zerolinecolor': "black",
            'showticklabels': True,
            'tickmode': 'auto',
            'ticks': 'outside',
        }
        self.fig.update_layout(xaxis=dict(axis_style), yaxis=dict(axis_style))
        return

    target = self.ax[ax_loc] if len(ax_loc) != 0 else self.ax
    target.spines[['left']].set_position('zero')
    target.spines[['top', 'right']].set_visible(False)
    target.xaxis.set_ticks_position('bottom')
    target.yaxis.set_ticks_position('left')
    for axis in (target.xaxis, target.yaxis):
        axis.set_minor_locator(ticker.AutoMinorLocator())

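Set up the spines and ticks of the graph. With the `matplotlib` backend, place the left spine at x = 0 and hide the top and right spines, so that only the left and bottom spines/axes are visible. Put the major ticks on them and add automatic minor ticks. With the `plotly` backend, show the zero lines, axis lines, and outside ticks on both axes instead.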

def plot_ODEs( self, de_data, *, num_pts: int = 1000, species: list[str] | tuple[str] = (), show_plot: bool = True, ax_loc: tuple = ()) -> Self:
101    def plot_ODEs(self,
102                  de_data,
103                  *,
104                  num_pts: int = 1000,
105                  species: list[str] | tuple[str] = (),
106                  show_plot: bool = True,
107                  ax_loc: tuple = ()
108                  ) -> Self:
109        """
110        Plot the deterministic trajectories of all species obtained
111        by obtaining the solution to a system of ODEs.
112
113        Parameters
114        ----------
115        de_data : DEcalcs object
116                 Data structure containing all the data related to
117                 solving the system of ODEs.
118
119        num_pts : int, default: 1000, optional
120                 Number of points used to calculate DE curves at.
121                 Used to approximate a smooth/continuous curve.
122
123        species : sequence of strings, default: (), optional
124                 An iterable sequence of strings specifying the species
125                 names to plot. If no species are specified (the default),
126                 then all species trajectories are plotted.
127
128        show_plot : bool, default: True, optional
129                 If True, show the plot.
130
131        ax_loc : tuple, optional
132                If the figure is made up of subplots, specify the location
133                of the axis to draw the data at.
134                Ex: for two subplots, the possible values of `ax_loc`
135                are (0, ) and (1, ). That's because the `self.ax` object is
136                a 1-D array. For figures with multiple rows and columns of
137                subplots, a 2-D tuple is needed.
138        """
139        species = list(de_data.odes.keys()) if len(species) == 0 else species
140        # t, y = ode_sol.t, ode_sol.y.T  # values at precomputed time pts
141        t = np.linspace(de_data.odes_sol.t[0], de_data.odes_sol.t[-1],
142                        num_pts)  # time points for obtaining...
143        y = de_data.odes_sol.sol(t).T  # an approximately continuous solution
144
145        self.setup_spines_ticks(ax_loc)
146
147        if self.backend == 'matplotlib':
148            axs = self.ax if len(ax_loc) == 0 else self.ax[ax_loc]
149            axs.set(xlim=(0, de_data.odes_sol.t[-1]))
150
151            for i, sp in enumerate(list(de_data.odes.keys())):
152                if sp in species:
153                    axs.plot(t, y[:, i], label=f"${sp}_{{DE}}$",
154                             linestyle='--', linewidth=0.75, alpha=0.75)
155
156            # axs.set(title="Deterministic trajectories")
157            axs.set(xlabel=f"$t$ ({de_data.time_unit})", ylabel="$N$")
158            axs.legend(loc='upper right')
159            self.fig.tight_layout()
160
161            if show_plot:
162                plt.show()
163
164        else:  # self.backend == 'plotly'
165
166            for i, sp in enumerate(list(de_data.odes.keys())):
167                if sp in species:
168                    self.fig.add_trace(
169                        go.Scatter(
170                            x=t.tolist(),
171                            y=y[:, i].tolist(),
172                            mode='lines',
173                            name=f"${sp}_{{DE}}$",
174                            line=dict(dash='dash', width=1)
175                        )
176                    )
177
178            self.fig.update_layout(
179                xaxis=dict(title=f"$t \\, ({de_data.time_unit})$",
180                           range=[-0.01 * de_data.odes_sol.t[-1], de_data.odes_sol.t[-1]]),
181                yaxis=dict(title="$N$"),
182            )
183
184            if show_plot:
185                self.fig.show()
186
187        return self

Plot the deterministic trajectories of the species, obtained by solving a system of ODEs.

Parameters
  • de_data (DEcalcs object): Data structure containing all the data related to solving the system of ODEs.
  • num_pts (int, default: 1000, optional): Number of time points at which the ODE solution is evaluated; used to approximate a smooth, continuous curve.
  • species (sequence of strings, default: (), optional): An iterable sequence of strings specifying the names of the species to plot. If no species are specified (the default), then the trajectories of all species are plotted.
  • show_plot (bool, default: True, optional): If True, show the plot.
  • ax_loc (tuple, optional): If the figure is made up of subplots, specify the location of the axis to draw the data at. For example, for two subplots the possible values of `ax_loc` are `(0,)` and `(1,)`, because the `self.ax` object is a 1-D array. For figures with multiple rows and columns of subplots, a two-element tuple, such as `(0, 1)`, is needed.
def plot_trajectories( self, time, data, *, species: list[str] | tuple[str] = (), show_plot: bool = True, ax_loc: tuple = ()) -> Self:
189    def plot_trajectories(self,
190                          time,
191                          data,
192                          *,
193                          species: list[str] | tuple[str] = (),
194                          show_plot: bool = True,
195                          ax_loc: tuple = ()
196                          ) -> Self:
197        """ Graph simulation time trajectories. """
198        self.setup_spines_ticks(ax_loc)
199        species = list(data.keys()) if len(species) == 0 else species
200
201        if self.backend == 'matplotlib':
202            axs = self.ax if len(ax_loc) == 0 else self.ax[ax_loc]
203            axs.set(xlim=(0, time[-1]))
204
205            for sp, sp_data in data.items():
206                if sp in species:
207                    trajs = sp_data['N'].T
208                    for traj in trajs:
209                        axs.plot(time, traj, linewidth=0.25)
210
211            axs.set(title="ABK trajectories")
212            axs.set(xlabel=f"$t$ (sec)", ylabel="$N$")
213            # axs.legend(loc='best')
214
215            if show_plot:
216                plt.show()
217
218        else:  # self.backend == 'plotly'
219
220            for sp, sp_data in data.items():
221                if sp in species:
222                    trajs = sp_data['N'].T
223                    for i, traj in enumerate(trajs):
224                        self.fig.add_trace(
225                            go.Scatter(
226                                x=time.tolist(),
227                                y=traj.tolist(),
228                                mode='lines',
229                                name=f"${sp} \\; Run \\, {i}$",
230                                line=dict(width=0.5)
231                            )
232                        )
233            self.fig.update_layout(
234                xaxis=dict(title=f"$t \\; (sec)$",
235                           range=[-0.01 * time[-1], time[-1]]),
236                yaxis=dict(title="$N$")
237            )
238
239            if show_plot:
240                self.fig.show()
241
242        return self

Graph the simulation time trajectories, drawing one curve per run for each selected species.

def plot_avg_std( self, time, data, *, species: list[str] | tuple[str] = (), show_plot: bool = True, ax_loc: tuple = ()) -> Self:
244    def plot_avg_std(self,
245                     time,
246                     data,
247                     *,
248                     species: list[str] | tuple[str] = (),
249                     show_plot: bool = True,
250                     ax_loc: tuple = ()
251                     ) -> Self:
252        """
253        Graph simulation average trajectories and
254        1-standard-deviation envelopes.
255        """
256        self.setup_spines_ticks(ax_loc)
257        species = list(data.keys()) if len(species) == 0 else species
258
259        if self.backend == 'matplotlib':
260            axs = self.ax if len(ax_loc) == 0 else self.ax[ax_loc]
261            axs.set(xlim=(0, time[-1]))
262
263            for sp, sp_data in data.items():
264                if sp in species:
265                    axs.plot(time, sp_data['N_avg'],
266                             linewidth=1.5, label=f"$<{sp}>$", alpha=0.5)
267                    axs.fill_between(time,
268                                     sp_data['N_avg'] - sp_data['N_std'],
269                                     sp_data['N_avg'] + sp_data['N_std'],
270                                     alpha=0.5, linewidth=0)
271
272            axs.set(xlabel="$t$ (sec)", ylabel="$N$")
273            axs.legend(loc='upper right')
274            self.fig.tight_layout()
275
276            if show_plot:
277                plt.show()
278
279        else:  # self.backend == 'plotly'
280
281            for sp, sp_data in data.items():
282                if sp in species:
283                    self.fig.add_trace(
284                        go.Scatter(
285                            x=time.tolist(),
286                            y=sp_data['N_avg'].tolist(),
287                            mode='lines',
288                            name=f"$<{sp}>$",
289                            line=dict(width=2)
290                        )
291                    )
292
293                    self.fig.add_trace(
294                        go.Scatter(
295                            x=time.tolist(),
296                            y=(sp_data['N_avg'] + sp_data['N_std']).tolist(),
297                            mode='lines',
298                            name=f"$<{sp}> + \\sigma$",
299                            line=dict(width=0),
300                            showlegend=False
301                        )
302                    )
303
304                    self.fig.add_trace(
305                        go.Scatter(
306                            x=time.tolist(),
307                            y=(sp_data['N_avg'] - sp_data['N_std']).tolist(),
308                            mode='lines',
309                            line=dict(width=0),
310                            name=f"$<{sp}> - \\sigma$",
311                            fill='tonexty',
312                            showlegend=False
313                        )
314                    )
315
316            self.fig.update_layout(
317                xaxis=dict(title=f"$t \\; (sec)$",
318                           range=[-0.01 * time[-1], time[-1]]),
319                yaxis=dict(title="$N$")
320            )
321
322            if show_plot:
323                self.fig.show()
324
325        return self

Graph simulation average trajectories and 1-standard-deviation envelopes.

def plot_eta( self, time, data, *, species: list[str] | tuple[str] = (), show_plot: bool = True, ax_loc: tuple = ()) -> Self:
327    def plot_eta(self,
328                 time,
329                 data,
330                 *,
331                 species: list[str] | tuple[str] = (),
332                 show_plot: bool = True,
333                 ax_loc: tuple = ()
334                 ) -> Self:
335        """ Graph the coefficient of variation. """
336        self.setup_spines_ticks(ax_loc)
337        species = list(data.keys()) if len(species) == 0 else species
338
339        if self.backend == "matplotlib":
340            axs = self.ax if len(ax_loc) == 0 else self.ax[ax_loc]
341            axs.set(xlim=(0, time[-1]))
342
343            for sp, sp_data in data.items():
344                if sp in species:
345                    axs.plot(time, sp_data['eta'], linewidth=1.5, label=f"${sp}$")
346                    axs.plot(time, sp_data['eta_p'], linewidth=1, linestyle='--',
347                             label=f"${sp}_{{Poisson}}$", color=(0.5, 0.5, 0.5))
348
349            axs.set(title="Coefficient of Variation, $\\eta$")
350            axs.set(xlabel=f"$t$ (sec)", ylabel="$\\eta$")
351            axs.legend(loc='upper right')
352
353            if show_plot:
354                plt.show()
355
356        else:  # self.backend == 'plotly'
357
358            for sp, sp_data in data.items():
359                if sp in species:
360                    self.fig.add_trace(
361                        go.Scatter(
362                            x=time.tolist(),
363                            y=sp_data['eta'].tolist(),
364                            mode='lines',
365                            name=f"${sp}$",
366                            line=dict(width=2)
367                        ))
368
369                    self.fig.add_trace(
370                        go.Scatter(
371                            x=time.tolist(),
372                            y=sp_data['eta_p'].tolist(),
373                            mode='lines',
374                            name=f"${sp}_{{Poisson}}$",
375                            line=dict(width=2, dash="dash")
376                        ))
377
378            self.fig.update_layout(
379                title="Coefficient of Variation",
380                xaxis=dict(title=f"$t \\; (sec)$",
381                           range=[-0.01 * time[-1], time[-1]]),
382                yaxis=dict(title="$\\eta$")
383            )
384
385            if show_plot:
386                self.fig.show()
387
388        return self

Graph the coefficient of variation (eta) of each selected species, together with the corresponding Poisson curve (`eta_p`).

def plot_het_metrics( self, time, proc_str: tuple[str, str], proc_data: dict, *, het_attr='k', show_plot: bool = True, ax_loc: tuple = ()) -> Self:
390    def plot_het_metrics(self,
391                         time,
392                         proc_str: tuple[str, str],
393                         proc_data: dict,
394                         *,
395                         het_attr='k',
396                         show_plot: bool = True,
397                         ax_loc: tuple = ()
398                         ) -> Self:
399        """
400        Graph species- and process-specific metrics of population heterogeneity.
401        """
402        self.setup_spines_ticks(ax_loc)
403        title = f"${proc_str[0].split(';')[0].replace(' ,', chr(92) + 'hspace{10pt} ,').replace('->', chr(92) + 'rightarrow')}$"
404
405        if self.backend == 'matplotlib':
406            axs = self.ax if len(ax_loc) == 0 else self.ax[ax_loc]
407            axs.set(xlim=(0, time[-1]))
408            axs.set(xlim=(0, time[-1]),
409                    ylim=(0, 1.5 * np.max(proc_data[f"<{het_attr}_avg>"] + proc_data[
410                        f"<{het_attr}_std>"])))
411
412            axs.plot(time, proc_data[f'<{het_attr}_avg>'],
413                     linewidth=1.5, label=f"$<{het_attr}>$", alpha=0.5)
414            axs.fill_between(time,
415                             proc_data[f"<{het_attr}_avg>"] - proc_data[f"<{het_attr}_std>"],
416                             proc_data[f"<{het_attr}_avg>"] + proc_data[f"<{het_attr}_std>"],
417                             alpha=0.5, linewidth=0)
418            axs.tick_params(axis='y', labelcolor='blue')
419
420            axs.set(title=title + (f"$, {proc_str[1]}$" if proc_str[1] != "" else ""),
421                    xlabel=f"$t$ (sec)")
422
423            if het_attr == 'Km':
424                axs.set_ylabel("$K_m$", color='blue')
425            elif het_attr == 'K50':
426                axs.set_ylabel("$K_{50}$", color='blue')
427            else:
428                axs.set_ylabel(f"${het_attr}$", color='blue')
429
430            # Second y-axis
431            axs2 = axs.twinx()
432            axs2.spines[['bottom']].set_position('zero')  # x axis
433            axs2.spines[['right']].set_position(('axes', 1))  # y axis
434            axs2.spines[['top', 'left', 'bottom']].set_visible(False)
435            # axs2.spines['right'].set_color('red')
436            axs2.set(ylim=(0, 1))
437            axs2.tick_params(axis='y', labelcolor='red')
438            axs2.yaxis.set_ticks_position('right')
439            axs2.set_yticks([i for i in np.arange(0, 1.1, 0.1)])
440            axs2.yaxis.set_minor_locator(ticker.AutoMinorLocator())
441            axs2.grid(which='major', axis='y', color='r',
442                      linestyle='--', linewidth=0.25, alpha=0.25)
443
444            axs2.plot(time, proc_data['psi_avg'],
445                      linewidth=1.5, label='$<\\psi>$', color='red', alpha=0.5)
446            axs2.fill_between(time,
447                              proc_data['psi_avg'] - proc_data['psi_std'],
448                              proc_data['psi_avg'] + proc_data['psi_std'],
449                              color='red', alpha=0.5, linewidth=0)
450            axs2.set_ylabel("$\\psi$", color='red')
451
452            if show_plot:
453                plt.show()
454
455            return self
456
457        else:  # self.backend == 'plotly'
458
459            self.fig.add_trace(
460                go.Scatter(
461                    x=time.tolist(),
462                    y=proc_data[f'<{het_attr}_avg>'].tolist(),
463                    mode='lines',
464                    name=f"$<{het_attr}>$",
465                    line=dict(width=2, color='blue'),
466                )
467            )
468
469            self.fig.add_trace(
470                go.Scatter(
471                    x=time.tolist(),
472                    y=(proc_data[f"<{het_attr}_avg>"] + proc_data[f"<{het_attr}_std>"]).tolist(),
473                    mode='lines',
474                    name=f"$<{het_attr}> + \\sigma$",
475                    line=dict(width=0),
476                    showlegend=False
477                )
478            )
479
480            self.fig.add_trace(
481                go.Scatter(
482                    x=time.tolist(),
483                    y=(proc_data[f"<{het_attr}_avg>"] - proc_data[f"<{het_attr}_std>"]).tolist(),
484                    mode='lines',
485                    name=f"$<{het_attr}> - \\sigma$",
486                    line=dict(width=0),
487                    fill='tonexty',
488                    showlegend=False
489                )
490            )
491
492            self.fig.update_layout(
493                title=title + (f"$, {proc_str[1]}$" if proc_str[1] != "" else ""),
494                xaxis=dict(title=f"$t \\; (sec)$",
495                           range=[-0.01 * time[-1], time[-1]]),
496                yaxis=dict(title="$K_m$" if het_attr == "Km" else "$K_{50}$" if het_attr == "K50" else f"${het_attr}$",
497                           color="blue",
498                           range=[0, 1.5 * np.max(proc_data[f"<{het_attr}_avg>"] + proc_data[f"<{het_attr}_std>"])])
499            )
500
501            if show_plot:
502                self.fig.show()
503
504            # Plot psi on a separate figure
505            self.fig2 = go.Figure()
506            self.fig2.update_layout(
507                width=self.fig.layout.width,
508                height=self.fig.layout.height,
509                xaxis=dict(zeroline=True, showline=True, linewidth=2,
510                           zerolinecolor="black", showticklabels=True,
511                           tickmode='auto', ticks='outside'),
512                yaxis=dict(zeroline=True, showline=True, linewidth=2,
513                           zerolinecolor="black", showticklabels=True,
514                           tickmode='auto', ticks='outside')
515            )
516
517            self.fig2.add_trace(
518                go.Scatter(
519                    x=time.tolist(),
520                    y=proc_data['psi_avg'],
521                    mode='lines',
522                    name=f"$<\\psi>$",
523                    line=dict(width=2, color='red'),
524                )
525            )
526
527            self.fig2.add_trace(
528                go.Scatter(
529                    x=time.tolist(),
530                    y=proc_data['psi_avg'] + proc_data['psi_std'],
531                    mode='lines',
532                    name=f"$<\\psi> + \\sigma$",
533                    line=dict(width=0),
534                    showlegend=False
535                )
536            )
537
538            self.fig2.add_trace(
539                go.Scatter(
540                    x=time.tolist(),
541                    y=proc_data['psi_avg'] - proc_data['psi_std'],
542                    mode='lines',
543                    name=f"$<\\psi> + \\sigma$",
544                    line=dict(width=0),
545                    fill='tonexty',
546                    showlegend=False
547                )
548            )
549
550            self.fig2.update_layout(
551                title=title + (f"$, {proc_str[1]}$" if proc_str[1] != "" else ""),
552                xaxis=dict(title=f"$t \\; (sec)$",
553                           range=[-0.01 * time[-1], time[-1]]),
554                yaxis=dict(title="$\\psi$",
555                           color="red",
556                           range=[0, 1],
557                           tick0=0,
558                           dtick=0.1)
559            )
560
561            if show_plot:
562                self.fig2.show()
563
564            return self

Graph species- and process-specific metrics of population heterogeneity.

def savefig(self, filename: str, image_format: str = 'svg', **kwargs):
566    def savefig(self, filename: str, image_format: str = 'svg', **kwargs):
567        """ Save the figure as a file. """
568        graph_path = Path('.') / 'plots_output'
569        graph_path.mkdir(exist_ok=True)
570        graph_path_svg = graph_path / f"{filename}.{image_format}"
571        if self.backend == 'matplotlib':
572            self.fig.savefig(graph_path_svg, format=image_format, **kwargs)
573            plt.close(self.fig)
574        else:  # self.backend == 'plotly'
575            self.fig.write_image(graph_path_svg, format=image_format, **kwargs)

Save the figure as a file.