Coverage for model_plots/Kuramoto/time_and_loss.py: 30%

20 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-12-05 17:26 +0000

1import logging 

2 

3import xarray as xr 

4 

5from utopya.eval import PlotHelper, is_plot_func 

6 

# Module-level logger, named after this module via ``__name__``.
log = logging.getLogger(name=__name__)

8 

9 

def _plot_with_errorband(ax, ds: xr.Dataset, *, color, alpha: float):
    """Draw ``ds["y"]`` on ``ax`` plus a shaded band spanning ``y ± yerr``.

    ``ds`` must provide the data variables ``y`` and ``yerr`` and a
    coordinate ``N``, which is used as the x-values of the band.
    """
    ds["y"].plot(ax=ax, color=color)
    ax.fill_between(
        ds.coords["N"],
        ds["y"] + ds["yerr"],
        ds["y"] - ds["yerr"],
        alpha=alpha,
        color=color,
        lw=0,  # no outline: only the translucent band is drawn
    )


@is_plot_func(
    use_dag=True,
    use_helper=True,
    required_dag_tags=("loss_data", "neural_time", "MCMC_time"),
)
def time_and_loss(
    hlpr: PlotHelper,
    *,
    data: dict,
    loss_color: str,
    neural_color: str = None,
    MCMC_color: str = None,
    loss_ylim: tuple = (0.1, 0.5),
    loss_ylabel: str = r"avg. $L^1$ error after 10 epochs",
):
    """Plots a comparison of time and loss values onto two axes.

    The loss curve is drawn on a twin y-axis (``twinx``) of the helper's
    current axis; the neural and MCMC compute times are drawn on the
    original axis. Each curve gets a shaded ``y ± yerr`` band.

    Args:
        hlpr: The plot helper; its current axis is used for the compute
            times and is re-selected before returning, so helper-configured
            post-processing applies to the time axis.
        data: DAG results. The entries ``loss_data``, ``neural_time`` and
            ``MCMC_time`` must each hold variables ``y`` and ``yerr`` and a
            coordinate ``N``.
        loss_color: Color of the loss curve and its band.
        neural_color: Color of the neural compute-time curve and band
            (``None`` leaves the choice to matplotlib).
        MCMC_color: Color of the MCMC compute-time curve and band
            (``None`` leaves the choice to matplotlib).
        loss_ylim: ``(bottom, top)`` limits of the loss axis. Pass ``None``
            to keep matplotlib's automatic limits.
        loss_ylabel: Label of the loss axis.
    """
    loss_data: xr.Dataset = data["loss_data"]
    neural_time: xr.Dataset = data["neural_time"]
    MCMC_time: xr.Dataset = data["MCMC_time"]
    ax1 = hlpr.ax

    # Plot the loss data on a twin axis sharing the x-axis
    ax2 = ax1.twinx()
    _plot_with_errorband(ax2, loss_data, color=loss_color, alpha=0.2)
    ax2.set_ylabel(loss_ylabel)
    if loss_ylim is not None:
        ax2.set_ylim(loss_ylim)

    # twinx() makes the twin the current axes; switch the helper back to the
    # primary axis and target it explicitly rather than via implicit state.
    hlpr.select_axis(ax=ax1)

    # Plot the compute times
    _plot_with_errorband(ax1, neural_time, color=neural_color, alpha=0.5)
    _plot_with_errorband(ax1, MCMC_time, color=MCMC_color, alpha=0.5)