Coverage for model_plots/Kuramoto/time_and_loss.py: 30%
20 statements
Generated by coverage.py v7.6.1 at 2024-12-05 17:26 +0000.
1import logging
3import xarray as xr
5from utopya.eval import PlotHelper, is_plot_func
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
def _plot_with_band(ax, ds: xr.Dataset, *, color, alpha: float):
    """Draw ``ds["y"]`` on ``ax`` together with a shaded ``y ± yerr`` band.

    The band's x-values are taken from the ``N`` coordinate of ``ds``. The
    band is filled in the same colour as the line; when ``color`` is None the
    line's colour (chosen by matplotlib's property cycle) is read back and
    reused, so line and band always match.

    Args:
        ax: The matplotlib axes to draw on.
        ds: Dataset with data variables ``y`` and ``yerr`` and an ``N``
            coordinate.
        color: Colour of the line and band; None lets matplotlib choose.
        alpha: Opacity of the error band.
    """
    lines = ds["y"].plot(ax=ax, color=color)
    if color is None:
        # fill_between(color=None) does NOT draw from the colour cycle; it
        # falls back to the default patch colour. Reuse the line's colour
        # instead. (xarray's 1D plot returns the Line2D list from ax.plot.)
        color = lines[0].get_color()

    ax.fill_between(
        ds.coords["N"],
        ds["y"] + ds["yerr"],
        ds["y"] - ds["yerr"],
        alpha=alpha,
        color=color,
        lw=0,
    )


@is_plot_func(
    use_dag=True,
    use_helper=True,
    required_dag_tags=("loss_data", "neural_time", "MCMC_time"),
)
def time_and_loss(
    hlpr: PlotHelper,
    *,
    data: dict,
    loss_color: str,
    neural_color: str = None,
    MCMC_color: str = None,
    loss_ylabel: str = r"avg. $L^1$ error after 10 epochs",
    loss_ylim: tuple = (0.1, 0.5),
):
    """Plots a comparison of compute times and loss values onto two y-axes.

    The helper's current axes receive the neural and MCMC compute-time
    curves. A twin axes created with ``twinx`` receives the loss curve. Each
    curve plots the ``y`` variable of its dataset with a shaded
    ``y ± yerr`` band whose x-values are the ``N`` coordinate.

    Args:
        hlpr: The utopya plot helper.
        data: DAG results; must provide ``loss_data``, ``neural_time`` and
            ``MCMC_time``, each with ``y`` and ``yerr`` variables and an
            ``N`` coordinate.
        loss_color: Colour of the loss curve and its band.
        neural_color: Colour of the neural compute-time curve and band;
            None lets matplotlib pick from its colour cycle.
        MCMC_color: Colour of the MCMC compute-time curve and band;
            None lets matplotlib pick from its colour cycle.
        loss_ylabel: Label of the loss y-axis.
        loss_ylim: ``(bottom, top)`` limits of the loss y-axis; None keeps
            matplotlib's automatic limits.
    """
    loss_data: xr.Dataset = data["loss_data"]
    neural_time: xr.Dataset = data["neural_time"]
    MCMC_time: xr.Dataset = data["MCMC_time"]
    time_ax = hlpr.ax

    # Plot the loss data on a secondary y-axis sharing the x-axis
    loss_ax = time_ax.twinx()
    _plot_with_band(loss_ax, loss_data, color=loss_color, alpha=0.2)
    loss_ax.set_ylabel(loss_ylabel)
    if loss_ylim is not None:
        loss_ax.set_ylim(loss_ylim)

    # Plot the compute times on the primary axes. twinx() makes the twin the
    # current figure axes, so re-select the primary axes on the helper
    # (presumably so helper-driven post-processing targets it — confirm).
    hlpr.select_axis(ax=time_ax)
    _plot_with_band(time_ax, neural_time, color=neural_color, alpha=0.5)
    _plot_with_band(time_ax, MCMC_time, color=MCMC_color, alpha=0.5)