Coverage for include/solvers.py: 28%

29 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 16:26 +0000

1import math 

2import torch 

3from functools import wraps 

4from typing import Callable, Optional, Tuple, Dict 

5from torchdiffeq import odeint, odeint_adjoint 

6 

# Shorthand for torch.Tensor, used in the type annotations of this module.
Tensor = torch.Tensor

8 

def build_time_grid(t: Optional[Tensor] = None, t_span: Optional[tuple] = None,
                    dt: Optional[float] = None, *,
                    device=None, dtype=None) -> Tensor:
    """Return the time grid to integrate over.

    Either pass an explicit grid ``t`` (converted with ``torch.as_tensor``), or
    ``t_span=(t0, t1)`` together with a step ``dt``. In the latter case the grid
    is ``t0, t0 + dt, t0 + 2*dt, ...`` and its last point is exactly ``t1``:
    when ``(t1 - t0) / dt`` is not an integer the final step is shorter than
    ``dt``, so the grid never runs past ``t1``. A negative ``dt`` together with
    ``t1 < t0`` yields a decreasing grid.

    Args:
        t: Explicit time points; takes precedence over ``t_span``/``dt``.
        t_span: ``(t0, t1)`` start and end times.
        dt: Step size; its sign must match the direction of ``t1 - t0``.
        device: Forwarded to the tensor constructors.
        dtype: Forwarded to the tensor constructors.

    Returns:
        A tensor of time points.

    Raises:
        ValueError: If neither ``t`` nor both ``t_span`` and ``dt`` are given,
            if ``dt`` is zero, or if ``dt`` points away from ``t1``.
    """
    if t is not None:
        return torch.as_tensor(t, device=device, dtype=dtype)

    if t_span is None or dt is None:
        raise ValueError("Provide either `t` or (`t_span` and `dt`).")
    if dt == 0:
        raise ValueError("`dt` must be non-zero.")

    t0, t1 = t_span
    ratio = float((t1 - t0) / dt)
    if ratio < 0:
        raise ValueError("`dt` must have the same sign as `t_span[1] - t_span[0]`.")

    # Floating-point division can land just above an integer
    # (e.g. 1.1 / 0.1 == 11.000000000000002); a bare ceil would then add a
    # spurious extra step, so snap near-integer ratios first.
    nearest = round(ratio)
    if math.isclose(ratio, nearest, rel_tol=1e-9, abs_tol=1e-12):
        n_steps = nearest
    else:
        n_steps = math.ceil(ratio)

    grid = torch.arange(n_steps + 1, device=device, dtype=dtype) * dt + t0
    # Pin the endpoint to t1: removes accumulated round-off on exact grids and
    # shortens the final step on non-divisible spans instead of overshooting.
    grid[-1] = t1
    return grid

21 

22# --------------------------------------------------------------------- 

23# Decorator factory over torchdiffeq 

24# --------------------------------------------------------------------- 

def _collect_adjoint_params(rhs: Callable, args: tuple, kwargs: dict) -> Tuple[Tensor, ...]:
    """Gather the tensors the adjoint method should differentiate with respect to.

    ``odeint_adjoint`` cannot discover the parameters of a plain closure, so they
    are passed explicitly. The collected tensors are:

    - the parameters of ``rhs`` if it is an ``nn.Module``;
    - the parameters of any ``nn.Module`` found among ``args``/``kwargs``;
    - any tensor argument that requires grad.

    Only tensors with ``requires_grad`` are kept. Duplicates are dropped and
    first-seen order is preserved.
    """
    found = []
    seen = set()

    def add(p):
        if p.requires_grad and id(p) not in seen:
            seen.add(id(p))
            found.append(p)

    if isinstance(rhs, torch.nn.Module):
        for p in rhs.parameters():
            add(p)
    for value in (*args, *kwargs.values()):
        if isinstance(value, torch.nn.Module):
            for p in value.parameters():
                add(p)
        elif isinstance(value, Tensor):
            add(value)
    return tuple(found)


def torchdiffeq_solver(method: str = "dopri5",
                       adjoint: bool = False,
                       rtol: float = 1e-6,
                       atol: float = 1e-9,
                       options: Optional[Dict] = None):
    """Turn an RHS ``f(t, y, *args, **kwargs)`` into a solver backed by torchdiffeq.

    The decorated function has the signature::

        solve(*args, y0, t=None, t_span=None, dt=None, device=None, dtype=None, **kwargs)

    It returns ``(t, y)``: the time grid actually used, and the solution as
    returned by the integrator (shape ``(len(t), *y0.shape)``). Provide either
    ``t`` OR (``t_span`` and ``dt``); the grid is built by ``build_time_grid``.

    Args:
        method: torchdiffeq solver name (e.g. ``"dopri5"``, ``"rk4"``).
        adjoint: Use ``odeint_adjoint`` instead of ``odeint``. The tensors to
            differentiate are collected from ``rhs`` and the extra
            ``args``/``kwargs`` and passed as ``adjoint_params``.
        rtol: Relative tolerance forwarded to the integrator.
        atol: Absolute tolerance forwarded to the integrator.
        options: Solver options forwarded to the integrator. The dict is copied
            at decoration time, so later mutation by the caller has no effect.
    """
    integrator = odeint_adjoint if adjoint else odeint
    options = {} if options is None else dict(options)

    def decorator(rhs: Callable):
        @wraps(rhs)
        def solve(*args,
                  y0,
                  t: Optional[Tensor] = None,
                  t_span: Optional[Tuple[float, float]] = None,
                  dt: Optional[float] = None,
                  device: Optional[torch.device] = None,
                  dtype: Optional[torch.dtype] = None,
                  **kwargs):
            # Prepare the initial condition.
            y0 = torch.as_tensor(y0, device=device, dtype=dtype)
            # torchdiffeq rejects non-floating states; promote e.g. integer lists
            # when the caller did not request a dtype explicitly.
            if dtype is None and not (y0.is_floating_point() or y0.is_complex()):
                y0 = y0.to(torch.get_default_dtype())

            # Time must be real even when the state is complex.
            t_dtype = y0.real.dtype if y0.is_complex() else y0.dtype
            t = build_time_grid(t, t_span, dt, device=y0.device, dtype=t_dtype)

            # Closure that binds the extra positional/keyword arguments of rhs.
            def fun(ti, yi):
                return rhs(ti, yi, *args, **kwargs)

            extra = {}
            if adjoint:
                # A plain closure is not an nn.Module, so odeint_adjoint needs
                # the parameters listed explicitly (it raises otherwise).
                extra["adjoint_params"] = _collect_adjoint_params(rhs, args, kwargs)

            y = integrator(fun, y0, t, rtol=rtol, atol=atol, method=method,
                           options=options, **extra)
            return t, y
        return solve
    return decorator