Coverage for models/SIR/ABM.py: 99%
150 statements
coverage.py v7.6.1, created at 2024-12-05 17:26 +0000
1import sys
2from enum import IntEnum
3from os.path import dirname as up
4from typing import Sequence, Union
6import numpy as np
7import torch
8from dantro._import_tools import import_module_from_path
# Make the directory three ``dirname`` levels above this file importable and load the
# project-local ``include`` module from that directory.
sys.path.append(up(up(up(__file__))))
base = import_module_from_path(mod_path=up(up(up(__file__))), mod_str="include")

# Helpers taken from ``include``: the ``Vector`` type used for agent positions and steps, and
# the ``distance`` function used below for contact detection (called with ``space`` and
# ``periodic`` keyword arguments).
Vector = base.Vector
distance = base.distance

# NOTE(review): this string comes after the imports, so Python does not treat it as the
# module docstring (it is not stored in ``__doc__``); it is only a description.
""" The SIR agent-based model of infectious diseases """
class kinds(IntEnum):
    """The epidemiological state of an agent, encoded as a small integer (0, 1, 2)."""

    # Can still be infected by contact with an infected agent
    SUSCEPTIBLE = 0
    # Currently infected and able to infect susceptible agents in range
    INFECTED = 1
    # Has passed through the infected state
    RECOVERED = 2
# --- The agent class --------------------------------------------------------------------------------------------------
class Agent:
    """A single agent with a fixed id, a kind, and a position in a two-dimensional space.

    The initial kind and position are remembered so that ``reset`` can restore them.
    """

    def __init__(self, *, id: int, kind: kinds, position: Vector):
        """
        :param id: the agent id (fixed)
        :param kind: the agent kind
        :param position: the agent position (also remembered as the initial position)
        """
        self.id = id
        self.kind = kind
        self.position = position
        # ``init_position`` starts out as the same object as ``position``. This is safe because
        # every movement method rebinds ``self.position`` to the result of ``position + direction``
        # instead of mutating it in place (assuming ``Vector.__add__`` returns a new Vector).
        self.init_position = position
        self.init_kind = kind

    @staticmethod
    def _bounds(space: Union[Vector, Sequence]):
        """Returns the bounds ``((x_0, x_1), (y_0, y_1))`` of a space.

        :param space: either a Vector, describing the domain [0, space.x] x [0, space.y], or a
            sequence of the form ``((x_0, x_1), (y_0, y_1))``
        """
        if isinstance(space, Vector):
            return (0, space.x), (0, space.y)
        x_0, x_1 = space[0]
        y_0, y_1 = space[1]
        return (x_0, x_1), (y_0, y_1)

    def move(self, direction: Vector):
        """Moves the agent along a direction vector."""
        # Rebind instead of ``+=`` so the (initially aliased) ``init_position`` can never be
        # modified, regardless of how Vector implements in-place addition.
        self.position = self.position + direction

    def repel_from_wall(self, direction: Vector, space: Union[Vector, Sequence]):
        """Moves the agent along ``direction``, reflecting the step off any boundary wall it crosses.

        ``direction`` is modified in place when a reflection is needed.

        :param direction: the step to take
        :param space: the space to move within (a Vector or ``((x_0, x_1), (y_0, y_1))``)
        """
        (x_0, x_1), (y_0, y_1) = self._bounds(space)
        new_pos = self.position + direction

        if not new_pos.within_space(space):
            # Mirror each overshooting coordinate at the wall it crossed: the resulting coordinate
            # is ``2 * wall - new_pos``, expressed as a modified step component.
            # NOTE: a step longer than the domain width may still end up outside after one reflection.
            if new_pos.x < x_0:
                direction.x = -(direction.x + 2 * (self.position.x - x_0))
            elif new_pos.x > x_1:
                direction.x = -(direction.x - 2 * (x_1 - self.position.x))
            if new_pos.y < y_0:
                direction.y = -(direction.y + 2 * (self.position.y - y_0))
            elif new_pos.y > y_1:
                direction.y = -(direction.y - 2 * (y_1 - self.position.y))

        self.move(direction)

    def move_in_periodic_space(self, direction: Vector, space: Union[Vector, Sequence]):
        """Moves the agent along ``direction`` in a space with periodic boundaries.

        A coordinate leaving the domain on one side re-enters it on the opposite side.

        :param direction: the step to take
        :param space: the space to move within (a Vector or ``((x_0, x_1), (y_0, y_1))``)
        """
        (x_0, x_1), (y_0, y_1) = self._bounds(space)
        new_position = self.position + direction

        # Wrap repeatedly, in case the step spans more than one domain length
        while not new_position.within_space(space):
            if new_position.x < x_0:
                new_position.x = x_1 - abs(x_0 - new_position.x)
            elif new_position.x > x_1:
                new_position.x = x_0 + abs(new_position.x - x_1)
            if new_position.y < y_0:
                new_position.y = y_1 - abs(y_0 - new_position.y)
            elif new_position.y > y_1:
                new_position.y = y_0 + abs(new_position.y - y_1)

        self.position = new_position

    def move_randomly_in_space(
        self,
        *,
        space: Union[Vector, Sequence],
        diffusion_radius: float,
        periodic: bool = False,
    ):
        """Moves the agent a distance ``diffusion_radius`` in a uniformly random direction. If the
        boundaries are periodic, the agent moves through them; otherwise it is reflected off the walls.

        :param space: the space within which to move
        :param diffusion_radius: the length of the random step
        :param periodic: whether the boundaries are periodic
        """
        # Draw the direction from a uniformly distributed angle. Normalising a point drawn
        # uniformly from the square [-1, 1]^2 (as done previously) over-weights the diagonals,
        # which makes the diffusion anisotropic.
        angle = 2 * np.pi * np.random.rand()
        direction = Vector(
            diffusion_radius * np.cos(angle), diffusion_radius * np.sin(angle)
        )

        if periodic:
            # Periodic case: move through boundaries
            self.move_in_periodic_space(direction, space)
        else:
            # Non-periodic case: move within space, repelling from walls
            self.repel_from_wall(direction, space)

    def reset(self):
        """Restores the agent's initial position and kind."""
        self.position = self.init_position
        self.kind = self.init_kind

    def __repr__(self):
        return f"Agent {self.id}; kind: {self.kind}; position: {self.position}"
# --- The SIR ABM ------------------------------------------------------------------------------------------------------
class SIR_ABM:
    """Agent-based SIR model: agents move randomly in a 2D space, susceptible agents within the
    infection radius of infected agents may become infected, and infected agents recover after
    ``t_infectious`` iterations."""

    def __init__(
        self,
        *,
        N: int,
        space: tuple,
        sigma_s: float,
        sigma_i: float,
        sigma_r: float,
        r_infectious: float,
        p_infect: float,
        t_infectious: float,
        is_periodic: bool,
        **__,
    ):
        """
        :param N: the number of agents; agent 0 starts infected, so N >= 1 is assumed
        :param space: the extent of the space; ``space[0]`` and ``space[1]`` are the x- and
            y-extents, and agents start uniformly distributed in [0, space[0]] x [0, space[1]]
        :param sigma_s: the random step length (diffusion radius) of susceptible agents
        :param sigma_i: the random step length (diffusion radius) of infected agents
        :param sigma_r: the random step length (diffusion radius) of recovered agents
        :param r_infectious: the radius of contact within which infection occurs
        :param p_infect: the probability of infecting an agent within the infection radius
        :param t_infectious: the time for which an agent is infectious
        :param is_periodic: whether the space has periodic boundaries (otherwise agents are
            reflected off the walls)
        :param __: additional keyword arguments (ignored)
        """

        # Parameters for the dynamics
        self.space = Vector(space[0], space[1])
        self.is_periodic = is_periodic
        self.sigma_s = sigma_s
        self.sigma_i = sigma_i
        self.sigma_r = sigma_r
        self.r_infectious = torch.tensor(r_infectious)
        self.p_infect = torch.tensor(p_infect)
        self.t_infectious = torch.tensor(t_infectious)

        # Set up the agents at random positions in space. Agent 0 is initialised as infected,
        # all other agents as susceptible.
        self.N = N
        self.init_kinds = [kinds.INFECTED] + [kinds.SUSCEPTIBLE] * (self.N - 1)

        # Initialise the agent positions and kinds
        self.agents = {
            i: Agent(
                id=i,
                kind=self.init_kinds[i],
                position=Vector(
                    np.random.rand() * self.space.x, np.random.rand() * self.space.y
                ),
            )
            for i in range(self.N)
        }

        # Track the ids of the susceptible, infected, and recovered agents (dicts used as ordered sets)
        self.kinds = None

        # Track the current kinds, positions, and total kind counts of all the agents
        self.current_kinds = None
        self.current_positions = None
        self.current_counts = None

        # Initial numbers of susceptible, infected, and recovered agents.
        # NOTE: these are only set by ``initialise``; ``run_single`` updates ``current_counts`` instead.
        self.susceptible = None
        self.infected = None
        self.recovered = None

        # Ids of infected agents grouped by infection cohort. Index = time since infection.
        self.times_since_infection = None

        # Initialise all the datasets
        self.initialise()

    def initialise(self):
        """Initialises the ABM data containers, with agent 0 infected and all others susceptible."""
        self.kinds = {
            kinds.SUSCEPTIBLE: {i: None for i in range(1, self.N)},
            kinds.INFECTED: {0: None},
            kinds.RECOVERED: {},
        }
        self.current_kinds = [int(self.agents[i].kind) for i in range(self.N)]
        # Column vector of (susceptible, infected, recovered) counts
        self.current_counts = torch.tensor(
            [[self.N - 1], [1.0], [0.0]], dtype=torch.float
        )
        self.susceptible = torch.tensor(self.N - 1, dtype=torch.float)
        self.infected = torch.tensor(1, dtype=torch.float)
        self.recovered = torch.tensor(0, dtype=torch.float)

        self.current_positions = [
            (self.agents[i].position.x, self.agents[i].position.y)
            for i in range(self.N)
        ]

        # Agent 0 was infected at time 0
        self.times_since_infection = [[0]]

    # --- Update functions ---------------------------------------------------------------------------------------------

    def update_kinds(self, id=None, kind=None):
        """Updates the agent kinds: all of them from the agents if ``id`` is None, else only ``id``."""
        if id is None:
            self.current_kinds = [int(self.agents[i].kind) for i in range(self.N)]
        else:
            self.current_kinds[id] = int(kind)

    def update_counts(self):
        """Recomputes the kind counts from the id containers."""
        self.current_counts = torch.tensor(
            [
                [len(self.kinds[kinds.SUSCEPTIBLE])],
                [len(self.kinds[kinds.INFECTED])],
                [len(self.kinds[kinds.RECOVERED])],
            ]
        ).float()

    def move_agents_randomly(self):
        """Moves the susceptible, infected, and recovered agents (in that order) with their
        respective step lengths."""
        for kind, sigma in (
            (kinds.SUSCEPTIBLE, self.sigma_s),
            (kinds.INFECTED, self.sigma_i),
            (kinds.RECOVERED, self.sigma_r),
        ):
            for agent_id in self.kinds[kind]:
                self.agents[agent_id].move_randomly_in_space(
                    space=self.space,
                    diffusion_radius=sigma,
                    periodic=self.is_periodic,
                )

    def update_positions(self):
        """Updates the recorded agent positions."""
        self.current_positions = [
            (self.agents[i].position.x, self.agents[i].position.y)
            for i in range(self.N)
        ]

    def reset(self):
        """Resets the ABM to the initial state."""
        for i in range(self.N):
            self.agents[i].reset()
        self.initialise()

    # --- Run function -------------------------------------------------------------------------------------------------

    def _contact_counts(self, susceptible_ids: list, infected_ids: list) -> torch.Tensor:
        """Returns a 1D long tensor: for each susceptible agent (in ``susceptible_ids`` order), the
        number of infected agents within the infection radius."""
        return torch.sum(
            torch.vstack(
                [
                    torch.hstack(
                        [
                            # 1 if the distance is below r_infectious, 0 otherwise
                            torch.ceil(
                                torch.relu(
                                    1
                                    - distance(
                                        self.agents[s].position,
                                        self.agents[i].position,
                                        space=self.space,
                                        periodic=self.is_periodic,
                                    )
                                    / self.r_infectious
                                )
                            )
                            for i in infected_ids
                        ]
                    )
                    for s in susceptible_ids
                ]
            ),
            dim=1,
        ).long()

    def _infect(self, p_infect) -> list:
        """Infects susceptible agents in contact with infected agents and returns the new ids.

        An agent with n contacts is infected with probability 1 - (1 - p_infect)^n.
        """
        susceptible_ids = list(self.kinds[kinds.SUSCEPTIBLE])
        infected_ids = list(self.kinds[kinds.INFECTED])
        if not (susceptible_ids and infected_ids):
            return []

        num_contacts = self._contact_counts(susceptible_ids, infected_ids)

        # Indices (into susceptible_ids) of agents with a non-zero number of contacts
        risk_contacts = torch.nonzero(num_contacts).long()
        if len(risk_contacts) == 0:
            return []

        # 1 where a uniform draw falls below the infection probability, 0 otherwise
        infections = torch.flatten(
            torch.ceil(
                torch.relu(
                    (1 - torch.pow((1 - p_infect), num_contacts[risk_contacts]))
                    - torch.rand((len(risk_contacts), 1))
                )
            )
        )

        newly_infected = [
            susceptible_ids[idx]
            for idx in torch.flatten(
                risk_contacts[torch.nonzero(infections != 0.0, as_tuple=True)]
            ).tolist()
        ]

        if newly_infected:
            # Update the counts of susceptible and infected agents accordingly
            self.current_counts[0] -= len(newly_infected)
            self.current_counts[1] += len(newly_infected)

            # Update the agent kind to 'infected'
            for agent_id in newly_infected:
                self.agents[agent_id].kind = kinds.INFECTED
                self.kinds[kinds.SUSCEPTIBLE].pop(agent_id)
                self.kinds[kinds.INFECTED].update({agent_id: None})
                self.update_kinds(agent_id, kinds.INFECTED)

        return newly_infected

    def _recover(self, t_infectious):
        """Changes every infection cohort older than ``t_infectious`` iterations to 'recovered'.

        A loop is used (not a single check) so that all overdue cohorts recover even if
        ``t_infectious`` was lowered between calls via the ``parameters`` override.
        """
        while self.times_since_infection and len(self.times_since_infection) > t_infectious:
            recovered_agents = self.times_since_infection.pop()

            # Update the counts accordingly
            self.current_counts[1] -= len(recovered_agents)
            self.current_counts[2] += len(recovered_agents)

            # Update the agent kinds (including the Agent objects, as is done on infection)
            for agent_id in recovered_agents:
                self.agents[agent_id].kind = kinds.RECOVERED
                self.kinds[kinds.INFECTED].pop(agent_id)
                self.kinds[kinds.RECOVERED].update({agent_id: None})
                self.update_kinds(agent_id, kinds.RECOVERED)

    def run_single(self, *, parameters: Union[torch.Tensor, None] = None):
        """Runs the ABM for a single iteration: infection, recovery, then movement.

        :param parameters: optional ``(p_infect, t_infectious)`` overriding the model's own
            values for this iteration
        """
        p_infect = self.p_infect if parameters is None else parameters[0]
        t_infectious = self.t_infectious if parameters is None else parameters[1]

        infected_agent_ids = self._infect(p_infect)

        # Track the time since infection: one (possibly empty) cohort per iteration, newest first
        self.times_since_infection.insert(0, infected_agent_ids)

        self._recover(t_infectious)

        # Move the susceptible, infected, and recovered agents with their respective diffusivities
        self.move_agents_randomly()
        self.update_positions()