Coverage for models/SIR/ABM.py: 99%

150 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-12-05 17:26 +0000

1import sys 

2from enum import IntEnum 

3from os.path import dirname as up 

4from typing import Sequence, Union 

5 

6import numpy as np 

7import torch 

8from dantro._import_tools import import_module_from_path 

9 

10sys.path.append(up(up(up(__file__)))) 

11base = import_module_from_path(mod_path=up(up(up(__file__))), mod_str="include") 

12 

13Vector = base.Vector 

14distance = base.distance 

15 

16""" The SIR agent-based model of infectious diseases """ 

17 

18 

19class kinds(IntEnum): 

20 SUSCEPTIBLE = 0 

21 INFECTED = 1 

22 RECOVERED = 2 

23 

24 

25# --- The agent class -------------------------------------------------------------------------------------------------- 

26class Agent: 

27 def __init__(self, *, id: int, kind: kinds, position: Vector): 

28 """ 

29 

30 :param id: the agent id (fixed) 

31 :param kind: the agent kind 

32 :param pos: the agent position 

33 """ 

34 self.id = id 

35 self.kind = kind 

36 self.position = position 

37 self.init_position = position 

38 self.init_kind = kind 

39 

40 # Move an agent along a direction vector 

41 def move(self, direction: Vector): 

42 self.position += direction 

43 

44 # Moves an agent within a square domain, repelling it from the boundary walls 

45 def repel_from_wall(self, direction: Vector, space: Union[Vector, Sequence]): 

46 if isinstance(space, Vector): 

47 x_0, x_1 = 0, space.x 

48 y_0, y_1 = 0, space.y 

49 else: 

50 x_0, x_1 = space[0] 

51 y_0, y_1 = space[1] 

52 

53 if not (self.position + direction).within_space(space): 

54 new_pos = self.position + direction 

55 if new_pos.x < x_0: 

56 direction.x = -(direction.x + 2 * (self.position.x - x_0)) 

57 elif new_pos.x > x_1: 

58 direction.x = -(direction.x - 2 * (x_1 - self.position.x)) 

59 if new_pos.y < y_0: 

60 direction.y = -(direction.y + 2 * (self.position.y - y_0)) 

61 elif new_pos.y > y_1: 

62 direction.y = -(direction.y - 2 * (y_1 - self.position.y)) 

63 self.move(direction) 

64 

65 def move_in_periodic_space(self, direction: Vector, space: Union[Vector, Sequence]): 

66 if isinstance(space, Vector): 

67 x_0, x_1 = 0, space.x 

68 y_0, y_1 = 0, space.y 

69 else: 

70 ( 

71 x_0, 

72 x_1, 

73 ) = space[0] 

74 y_0, y_1 = space[1] 

75 

76 new_position = self.position + direction 

77 

78 in_space = new_position.within_space(space) 

79 while not in_space: 

80 if new_position.x < x_0: 

81 new_position.x = x_1 - abs(x_0 - new_position.x) 

82 elif new_position.x > x_1: 

83 new_position.x = x_0 + abs(new_position.x - x_1) 

84 if new_position.y < y_0: 

85 new_position.y = y_1 - abs(y_0 - new_position.y) 

86 elif new_position.y > y_1: 

87 new_position.y = y_0 + abs(new_position.y - y_1) 

88 in_space = new_position.within_space(space) 

89 

90 self.position = new_position 

91 

92 def move_randomly_in_space( 

93 self, 

94 *, 

95 space: Union[Vector, Sequence], 

96 diffusion_radius: float, 

97 periodic: bool = False, 

98 ): 

99 """Move an agent randomly within a space with a given diffusivity. If the boundaries are periodic, 

100 the agent moves through the boundaries 

101 

102 :param space: the space within which to move 

103 :param diffusion: the diffusivity 

104 :param periodic: whether the boundaries are periodic 

105 """ 

106 

107 # Get a random direction in the sphere with radius diffusion_radius 

108 direction = Vector((2 * np.random.rand() - 1), (2 * np.random.rand() - 1)) 

109 direction.normalise(norm=diffusion_radius) 

110 

111 # Non-periodic case: move within space, repelling from walls 

112 if not periodic: 

113 self.repel_from_wall(direction, space) 

114 

115 # Periodic case: move through boundaries 

116 else: 

117 self.move_in_periodic_space(direction, space) 

118 

119 def reset(self): 

120 self.position = self.init_position 

121 self.kind = self.init_kind 

122 

123 def __repr__(self): 

124 return f"Agent {self.id}; " f"kind: {self.kind}; " f"position: {self.position}" 

125 

126 

127# --- The SIR ABM ------------------------------------------------------------------------------------------------------ 

128class SIR_ABM: 

129 def __init__( 

130 self, 

131 *, 

132 N: int, 

133 space: tuple, 

134 sigma_s: float, 

135 sigma_i: float, 

136 sigma_r: float, 

137 r_infectious: float, 

138 p_infect: float, 

139 t_infectious: float, 

140 is_periodic: bool, 

141 **__, 

142 ): 

143 """ 

144 

145 :param r_infectious: the radius of contact within which infection occurs 

146 :param p_infect: the probability of infecting an agent within the infection radius 

147 :param t_infectious: the time for which an agent is infectious 

148 """ 

149 

150 # Parameters for the dynamics 

151 self.space = Vector(space[0], space[1]) 

152 self.is_periodic = is_periodic 

153 self.sigma_s = sigma_s 

154 self.sigma_i = sigma_i 

155 self.sigma_r = sigma_r 

156 self.r_infectious = torch.tensor(r_infectious) 

157 self.p_infect = torch.tensor(p_infect) 

158 self.t_infectious = torch.tensor(t_infectious) 

159 

160 # Set up the cells and initialise their location at a random position in space. 

161 # All cells are initialised as susceptible 

162 self.N = N 

163 self.init_kinds = [kinds.INFECTED] + [kinds.SUSCEPTIBLE] * (self.N - 1) 

164 

165 # Initialise the agent positions and kinds 

166 self.agents = { 

167 i: Agent( 

168 id=i, 

169 kind=self.init_kinds[i], 

170 position=Vector( 

171 np.random.rand() * self.space.x, np.random.rand() * self.space.y 

172 ), 

173 ) 

174 for i in range(self.N) 

175 } 

176 

177 # Track the ids of the susceptible, infected, and recovered cells 

178 self.kinds = None 

179 

180 # Track the current kinds, positions, and total kind counts of all the agents 

181 self.current_kinds = None 

182 self.current_positions = None 

183 self.current_counts = None 

184 

185 # Count the number of susceptible, infected, and recovered agents. 

186 self.susceptible = None 

187 self.infected = None 

188 self.recovered = None 

189 

190 # Track the times since infection occurred for each agent. Index = time since infection 

191 self.times_since_infection = None 

192 

193 # Initialise all the datasets 

194 self.initialise() 

195 

196 # Initialises the ABM data containers 

197 def initialise(self): 

198 # Initialise the ABM with one infected agent. 

199 self.kinds = { 

200 kinds.SUSCEPTIBLE: {i: None for i in range(1, self.N)}, 

201 kinds.INFECTED: {0: None}, 

202 kinds.RECOVERED: {}, 

203 } 

204 self.current_kinds = [int(self.agents[i].kind) for i in range(self.N)] 

205 self.current_counts = torch.tensor( 

206 [[self.N - 1], [1.0], [0.0]], dtype=torch.float 

207 ) 

208 self.susceptible = torch.tensor(self.N - 1, dtype=torch.float) 

209 self.infected = torch.tensor(1, dtype=torch.float) 

210 self.recovered = torch.tensor(0, dtype=torch.float) 

211 

212 self.current_positions = [ 

213 (self.agents[i].position.x, self.agents[i].position.y) 

214 for i in range(self.N) 

215 ] 

216 

217 self.times_since_infection = [[0]] 

218 

219 # --- Update functions --------------------------------------------------------------------------------------------- 

220 

221 # Updates the agent kinds 

222 def update_kinds(self, id=None, kind=None): 

223 if id is None: 

224 self.current_kinds = [int(self.agents[i].kind) for i in range(self.N)] 

225 else: 

226 self.current_kinds[id] = int(kind) 

227 

228 # Updates the kind counts 

229 def update_counts(self): 

230 self.current_counts = torch.tensor( 

231 [ 

232 [len(self.kinds[kinds.SUSCEPTIBLE])], 

233 [len(self.kinds[kinds.INFECTED])], 

234 [len(self.kinds[kinds.RECOVERED])], 

235 ] 

236 ).float() 

237 

238 # Moves the agents randomly in space 

239 def move_agents_randomly(self): 

240 for agent_id in self.kinds[kinds.SUSCEPTIBLE].keys(): 

241 self.agents[agent_id].move_randomly_in_space( 

242 space=self.space, 

243 diffusion_radius=self.sigma_s, 

244 periodic=self.is_periodic, 

245 ) 

246 

247 for agent_id in self.kinds[kinds.INFECTED].keys(): 

248 self.agents[agent_id].move_randomly_in_space( 

249 space=self.space, 

250 diffusion_radius=self.sigma_i, 

251 periodic=self.is_periodic, 

252 ) 

253 

254 for agent_id in self.kinds[kinds.RECOVERED].keys(): 

255 self.agents[agent_id].move_randomly_in_space( 

256 space=self.space, 

257 diffusion_radius=self.sigma_r, 

258 periodic=self.is_periodic, 

259 ) 

260 

261 # Updates the agent positions 

262 def update_positions(self): 

263 self.current_positions = [ 

264 (self.agents[i].position.x, self.agents[i].position.y) 

265 for i in range(self.N) 

266 ] 

267 

268 # Resets the ABM to the initial state 

269 def reset(self): 

270 for i in range(self.N): 

271 self.agents[i].reset() 

272 self.initialise() 

273 

274 # --- Run function ------------------------------------------------------------------------------------------------- 

275 

276 # Runs the ABM for a single iteration 

277 def run_single(self, *, parameters: torch.tensor = None): 

278 p_infect = self.p_infect if parameters is None else parameters[0] 

279 t_infectious = self.t_infectious if parameters is None else parameters[1] 

280 

281 # Collect the ids of the infected agents 

282 infected_agent_ids = [] 

283 

284 if self.kinds[kinds.SUSCEPTIBLE] and self.kinds[kinds.INFECTED]: 

285 # For each susceptible agent, calculate the number of contacts to an infected agent. 

286 # A contact occurs when the susceptible agent is within the infection radius of an infected agent. 

287 num_contacts = torch.sum( 

288 torch.vstack( 

289 [ 

290 torch.hstack( 

291 [ 

292 torch.ceil( 

293 torch.relu( 

294 1 

295 - distance( 

296 self.agents[s].position, 

297 self.agents[i].position, 

298 space=self.space, 

299 periodic=self.is_periodic, 

300 ) 

301 / self.r_infectious 

302 ) 

303 ) 

304 for i in self.kinds[kinds.INFECTED].keys() 

305 ] 

306 ) 

307 for s in self.kinds[kinds.SUSCEPTIBLE].keys() 

308 ] 

309 ), 

310 dim=1, 

311 ).long() 

312 

313 # Get the ids of susceptible agents that had a non-zero number of contacts with infected agents 

314 risk_contacts = torch.nonzero(num_contacts).long() 

315 

316 if len(risk_contacts) != 0: 

317 # Infect all susceptible agents that were in contact with an infected agent with probability 

318 # 1 - (1- p_infect)^n, where n is the number of contacts. 

319 infections = torch.flatten( 

320 torch.ceil( 

321 torch.relu( 

322 (1 - torch.pow((1 - p_infect), num_contacts[risk_contacts])) 

323 - torch.rand((len(risk_contacts), 1)) 

324 ) 

325 ) 

326 ) 

327 

328 # Get the ids of the newly infected agents 

329 infected_agent_ids = [ 

330 list(self.kinds[kinds.SUSCEPTIBLE].keys())[_] 

331 for _ in torch.flatten( 

332 risk_contacts[torch.nonzero(infections != 0.0, as_tuple=True)] 

333 ) 

334 ] 

335 

336 if infected_agent_ids: 

337 # Update the counts of susceptible and infected agents accordingly 

338 self.current_counts[0] -= len(infected_agent_ids) 

339 self.current_counts[1] += len(infected_agent_ids) 

340 

341 # Update the agent kind to 'infected' 

342 for agent_id in infected_agent_ids: 

343 self.agents[agent_id].kind = kinds.INFECTED 

344 self.kinds[kinds.SUSCEPTIBLE].pop(agent_id) 

345 self.kinds[kinds.INFECTED].update({agent_id: None}) 

346 self.update_kinds(agent_id, kinds.INFECTED) 

347 

348 # Track the time since infection of the newly infected agents 

349 self.times_since_infection.insert(0, infected_agent_ids) 

350 

351 # Change any 'infected' agents that have surpassed the maximum infected time to 'recovered'. 

352 if len(self.times_since_infection) > t_infectious: 

353 # The agents that have been infectious for the maximum amount of time have recovered 

354 recovered_agents = self.times_since_infection.pop() 

355 

356 # Update the counts accordingly 

357 self.current_counts[1] -= len(recovered_agents) 

358 self.current_counts[2] += len(recovered_agents) 

359 

360 # Update the agent kinds 

361 for agent_id in recovered_agents: 

362 self.kinds[kinds.INFECTED].pop(agent_id) 

363 self.kinds[kinds.RECOVERED].update({agent_id: None}) 

364 self.update_kinds(agent_id, kinds.RECOVERED) 

365 

366 # Move the susceptible, infected, and recovered agents with their respective diffusivities 

367 self.move_agents_randomly() 

368 self.update_positions()