Coverage for include / solvers.py: 28%
29 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 16:26 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 16:26 +0000
1import math
2import torch
3from functools import wraps
4from typing import Callable, Optional, Tuple, Dict
5from torchdiffeq import odeint, odeint_adjoint
# Short alias used in the type annotations throughout this module.
Tensor = torch.Tensor


def build_time_grid(t: Optional[Tensor] = None,
                    t_span: Optional[Tuple[float, float]] = None,
                    dt: Optional[float] = None, *,
                    device=None, dtype=None) -> Tensor:
    """
    Return the integration time grid as a tensor.

    If ``t`` is given it is converted with ``torch.as_tensor`` (``t_span`` and
    ``dt`` are ignored). Otherwise a uniform grid ``t0, t0 + dt, t0 + 2*dt, ...``
    is built from ``t_span = (t0, t1)`` with ``ceil((t1 - t0) / dt)`` steps, so
    the grid always reaches ``t1``; when the span is not a whole number of
    steps, the last point lies beyond ``t1``. A negative ``dt`` with
    ``t1 < t0`` yields a decreasing grid.

    Args:
        t: explicit time points; takes precedence over ``t_span``/``dt``.
        t_span: ``(t0, t1)`` start and end of the integration interval.
        dt: step between consecutive grid points (non-zero).
        device, dtype: forwarded to ``torch.as_tensor`` / ``torch.arange``.

    Raises:
        ValueError: if neither ``t`` nor both ``t_span`` and ``dt`` are given,
            if ``dt`` is zero, if the step count is not finite, or if the sign
            of ``dt`` does not match the direction ``t1 - t0``.
    """
    if t is not None:
        return torch.as_tensor(t, device=device, dtype=dtype)
    if t_span is None or dt is None:
        raise ValueError("Provide either `t` or (`t_span` and `dt`).")

    t0, t1 = t_span
    if dt == 0:
        raise ValueError("`dt` must be non-zero.")
    # float() also accepts 0-dim tensors / numpy scalars in t_span or dt.
    ratio = float((t1 - t0) / dt)
    if not math.isfinite(ratio):
        raise ValueError("`t_span` and `dt` must give a finite number of steps.")
    if ratio < 0:
        raise ValueError("`dt` must have the same sign as `t_span[1] - t_span[0]`.")

    # Float round-off can push a ratio that is mathematically an integer just
    # above it (1.1 / 0.1 == 11.000000000000002); snap such ratios so ceil()
    # does not add a spurious extra step past t1.
    nearest = round(ratio)
    if math.isclose(ratio, nearest, rel_tol=1e-12, abs_tol=1e-12):
        n_steps = nearest
    else:
        n_steps = math.ceil(ratio)
    return torch.arange(n_steps + 1, device=device, dtype=dtype) * dt + t0
# ---------------------------------------------------------------------
# Decorator factory over torchdiffeq
# ---------------------------------------------------------------------
def _collect_adjoint_params(args: tuple, kwargs: dict) -> Tuple[torch.Tensor, ...]:
    """Gather trainable tensors passed directly as RHS arguments.

    Each positional and keyword value is inspected: a tensor is kept if it
    ``requires_grad``; an ``nn.Module`` contributes those of its parameters
    that ``requires_grad``. Duplicates (by identity) are dropped. Tensors
    nested inside containers (lists, dicts, ...) are NOT inspected.
    """
    found = []
    seen = set()
    for value in (*args, *kwargs.values()):
        # Module check first: nn.Parameter is itself a Tensor subclass.
        if isinstance(value, torch.nn.Module):
            candidates = value.parameters()
        elif isinstance(value, torch.Tensor):
            candidates = (value,)
        else:
            continue
        for p in candidates:
            if p.requires_grad and id(p) not in seen:
                seen.add(id(p))
                found.append(p)
    return tuple(found)


def torchdiffeq_solver(method: str = "dopri5",
                       adjoint: bool = False,
                       rtol: float = 1e-6,
                       atol: float = 1e-9,
                       options: Optional[Dict] = None,
                       adjoint_params: Optional[Tuple[torch.Tensor, ...]] = None):
    """
    Turn an RHS ``f(t, y, *args, **kwargs)`` into a solver that calls torchdiffeq.

    The decorated function has the signature::

        solve(*args, y0, t=None, t_span=None, dt=None, device=None, dtype=None, **kwargs)

    and returns ``(t, y)``: the time grid actually used and the integrator
    output (for a tensor ``y0``, torchdiffeq returns shape
    ``(len(t), *y0.shape)``). You must provide either ``t`` OR
    (``t_span`` and ``dt``); the grid is built by ``build_time_grid``.

    Args:
        method: torchdiffeq method name passed through to the integrator.
        adjoint: if True, integrate with ``odeint_adjoint`` instead of ``odeint``.
        rtol, atol: tolerances passed through to the integrator.
        options: extra solver options. Copied once here, so later mutation of
            the caller's dict does not affect this solver.
        adjoint_params: only used when ``adjoint`` is True; the tensors that
            gradients are computed for. If None, they are collected on every
            call from the ``*args``/``**kwargs`` forwarded to the RHS (tensors
            with ``requires_grad`` and trainable parameters of ``nn.Module``
            arguments). Pass them explicitly when the RHS reaches trainable
            tensors another way (globals, nested containers), otherwise their
            gradients are not computed by the adjoint pass.

    Notes:
        If ``dtype`` is not given and ``y0`` is not floating/complex (e.g. a
        list of ints), ``y0`` is promoted to ``torch.get_default_dtype()``.
        The time grid uses ``y0``'s device and dtype (the real counterpart
        for complex ``y0``).
    """
    integrator = odeint_adjoint if adjoint else odeint
    # Private copy so the caller cannot change this solver after creation.
    options = {} if options is None else dict(options)
    fixed_adjoint_params = None if adjoint_params is None else tuple(adjoint_params)

    def decorator(rhs: Callable):
        @wraps(rhs)
        def solve(*args,
                  y0,
                  t: Optional[Tensor] = None,
                  t_span: Optional[Tuple[float, float]] = None,
                  dt: Optional[float] = None,
                  device: Optional[torch.device] = None,
                  dtype: Optional[torch.dtype] = None,
                  **kwargs):
            # Prepare initial condition.
            y0 = torch.as_tensor(y0, device=device, dtype=dtype)
            if dtype is None and not (y0.is_floating_point() or y0.is_complex()):
                # An integer state would also make the time grid integer
                # (truncating a user-supplied `t`); promote instead.
                y0 = y0.to(torch.get_default_dtype())

            # Time is real-valued even when the state is complex.
            t_dtype = y0.real.dtype if y0.is_complex() else y0.dtype
            t = build_time_grid(t, t_span, dt, device=y0.device, dtype=t_dtype)

            # Closure that forwards the extra args/kwargs to the RHS.
            def fun(ti, yi):
                return rhs(ti, yi, *args, **kwargs)

            extra = {}
            if adjoint:
                # odeint_adjoint needs to know which tensors to differentiate
                # w.r.t.; a plain closure is not an nn.Module it can inspect.
                extra["adjoint_params"] = (
                    fixed_adjoint_params
                    if fixed_adjoint_params is not None
                    else _collect_adjoint_params(args, kwargs)
                )

            y = integrator(fun, y0, t, rtol=rtol, atol=atol, method=method,
                           options=options, **extra)
            return t, y
        return solve
    return decorator