Coverage for models/HarrisWilson/DataGeneration.py: 17% (95 statements)
import logging
import sys
from os.path import dirname as up
from typing import Tuple

import h5py as h5
import numpy as np
import pandas as pd
import torch
from dantro._import_tools import import_module_from_path

sys.path.append(up(up(up(__file__))))

base = import_module_from_path(mod_path=up(up(up(__file__))), mod_str="include")

from .ABM import HarrisWilsonABM
18""" Load a dataset or generate synthetic data on which to train the neural net """
20log = logging.getLogger(__name__)

def load_from_dir(dir) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Loads Harris-Wilson data from a directory.

    :param dir: either a path to a single h5 file, a path to a folder of csv
        files, or a dictionary pointing to individual csv files
    :return: the origin sizes, the training time series, and the network
    """
    log.note(" Loading data ...")

    # Check whether a single h5 file, a folder containing csv files, or a
    # dictionary pointing to specific csv files has been passed.
    if isinstance(dir, str):
        # If the data is in h5 format, read all three datasets from the file
        if dir.lower().endswith(".h5"):
            with h5.File(dir, "r") as f:
                origins = np.array(f["HarrisWilson"]["origin_sizes"])
                training_data = np.array(f["HarrisWilson"]["training_data"])
                nw = np.array(f["network"]["_edge_weights"])

        # If the path is a folder, load the csv files it contains
        else:
            origins = pd.read_csv(
                dir + "/origin_sizes.csv", header=0, index_col=0
            ).to_numpy()
            training_data = pd.read_csv(
                dir + "/training_data.csv", header=0, index_col=0
            ).to_numpy()
            nw = pd.read_csv(dir + "/network.csv", header=0, index_col=0).to_numpy()

    # If a dictionary is passed, load the data from the individual locations
    elif isinstance(dir, dict):
        origins = pd.read_csv(dir["origin_zones"], header=0, index_col=0).to_numpy()
        training_data = pd.read_csv(
            dir["destination_zones"], header=0, index_col=0
        ).to_numpy()
        nw = pd.read_csv(dir["network"], header=0, index_col=0).to_numpy()

    # Convert the data to torch tensors and return them
    origins = torch.from_numpy(origins).float()
    training_data = torch.unsqueeze(torch.from_numpy(training_data).float(), -1)
    nw = torch.reshape(
        torch.from_numpy(nw).float(), (origins.shape[0], training_data.shape[1])
    )

    return origins, training_data, nw
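
# A usage sketch (not part of the original module): load_from_dir accepts any
# of the three input formats handled above. The paths are hypothetical; the
# dictionary keys are the ones the function actually reads.
#
#   or_sizes, training_data, nw = load_from_dir("data/HarrisWilson.h5")
#   or_sizes, training_data, nw = load_from_dir("data/HarrisWilson_csv")
#   or_sizes, training_data, nw = load_from_dir(
#       {
#           "origin_zones": "data/origin_sizes.csv",
#           "destination_zones": "data/training_data.csv",
#           "network": "data/network.csv",
#       }
#   )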

def generate_synthetic_data(
    *, cfg, device: str
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Generates synthetic Harris-Wilson data using a numerical solver.

    :param cfg: the data generation configuration
    :param device: the device on which to generate the tensors
    :return: the origin sizes, the time series, and the network
    """
    log.note(" Generating synthetic data ...")

    # Get run configuration properties
    data_cfg = cfg["synthetic_data"]
    N_origin, N_destination = data_cfg["N_origin"], data_cfg["N_destination"]
    num_steps = data_cfg["num_steps"]

    # Generate the initial origin sizes
    or_sizes = torch.abs(
        base.random_tensor(
            data_cfg.get("origin_sizes"), size=(N_origin, 1), device=device
        )
    )

    # Generate the edge weights
    network = torch.exp(
        -1
        * torch.abs(
            base.random_tensor(
                data_cfg.get("init_weights"),
                size=(N_origin, N_destination),
                device=device,
            )
        )
    )

    # Generate the initial destination zone sizes
    init_dest_sizes = torch.abs(
        base.random_tensor(
            data_cfg.get("init_dest_sizes"), size=(N_destination, 1), device=device
        )
    )

    # Extract the underlying parameters from the config
    true_parameters = {
        "alpha": data_cfg["alpha"],
        "beta": data_cfg["beta"],
        "kappa": data_cfg["kappa"],
        "sigma": data_cfg["sigma"],
    }

    # Initialise the ABM on the same device as the generated tensors
    ABM = HarrisWilsonABM(
        origin_sizes=or_sizes,
        network=network,
        true_parameters=true_parameters,
        M=data_cfg["N_destination"],
        epsilon=data_cfg["epsilon"],
        dt=data_cfg["dt"],
        device=device,
    )

    # Run the ABM for n iterations, generating the entire time series
    dset_sizes_ts = ABM.run(
        init_data=init_dest_sizes,
        input_data=None,
        n_iterations=num_steps,
        generate_time_series=True,
    )

    # Return all three
    return or_sizes, dset_sizes_ts, network
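
# A configuration sketch (assumed, for illustration): the entries that
# generate_synthetic_data reads from cfg["synthetic_data"]. The numeric values
# are illustrative, and the structure of the random-tensor entries depends on
# base.random_tensor from the repo's `include` module.
#
#   cfg = {
#       "synthetic_data": {
#           "N_origin": 10,
#           "N_destination": 5,
#           "num_steps": 500,
#           "origin_sizes": {...},     # passed to base.random_tensor
#           "init_weights": {...},     # passed to base.random_tensor
#           "init_dest_sizes": {...},  # passed to base.random_tensor
#           "alpha": 1.2, "beta": 4.0, "kappa": 2.0, "sigma": 0.0,
#           "epsilon": 1.0, "dt": 0.001,
#       }
#   }
#   or_sizes, time_series, network = generate_synthetic_data(cfg=cfg, device="cpu")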

def get_HW_data(
    cfg, h5file: h5.File, h5group: h5.Group, *, device: str
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Gets the data for the Harris-Wilson model. If no path to a dataset is
    passed, synthetic data is generated using the config settings.

    :param cfg: the data configuration
    :param h5file: the h5 File to use; needed to add a network group
    :param h5group: the h5 Group to write data to
    :param device: the device to which to move the data
    :return: the origin zone sizes, the training data, and the network
    """
    data_dir = cfg.get("load_from_dir", {})

    # Get the origin sizes, time series, and network data
    or_sizes, dest_sizes, network = (
        load_from_dir(data_dir)
        if data_dir
        else generate_synthetic_data(cfg=cfg, device=device)
    )

    N_origin, N_destination = or_sizes.shape[0], dest_sizes.shape[1]

    # Only save individual time frames, as given by the write_start and
    # write_every entries; if the data was loaded from a file, the defaults
    # keep every frame
    synthetic_data = cfg.get("synthetic_data", {})
    write_start = synthetic_data.get("write_start", 0)
    write_every = synthetic_data.get("write_every", 1)
    time_series = dest_sizes[write_start::write_every]

    # If the time series has a single frame, double it to enable visualisation.
    # This does not affect the training data
    training_data_size = cfg.get("training_data_size", None)
    training_data_size = (
        time_series.shape[0] if training_data_size is None else training_data_size
    )

    if time_series.shape[0] == 1:
        time_series = torch.cat((time_series, time_series), dim=0)

    # Extract the training data from the time series data
    training_data = dest_sizes[-training_data_size:]

    # Set up dataset for complete synthetic time series
    dset_time_series = h5group.create_dataset(
        "time_series",
        time_series.shape[:-1],
        maxshape=time_series.shape[:-1],
        chunks=True,
        compression=3,
    )
    dset_time_series.attrs["dim_names"] = ["time", "zone_id"]
    dset_time_series.attrs["coords_mode__time"] = "start_and_step"
    dset_time_series.attrs["coords__time"] = [write_start, write_every]
    dset_time_series.attrs["coords_mode__zone_id"] = "values"
    dset_time_series.attrs["coords__zone_id"] = np.arange(
        N_origin, N_origin + N_destination, 1
    )

    # Write the time series data
    dset_time_series[:, :] = torch.flatten(time_series, start_dim=-2)

    # Save the training time series
    dset_training_data = h5group.create_dataset(
        "training_data",
        training_data.shape[:-1],
        maxshape=training_data.shape[:-1],
        chunks=True,
        compression=3,
    )
    dset_training_data.attrs["dim_names"] = ["time", "zone_id"]
    dset_training_data.attrs["coords_mode__time"] = "trivial"
    dset_training_data.attrs["coords_mode__zone_id"] = "values"
    dset_training_data.attrs["coords__zone_id"] = np.arange(
        N_origin, N_origin + N_destination, 1
    )
    dset_training_data[:, :] = torch.flatten(training_data, start_dim=-2)

    # Set up chunked dataset to store the state data in
    # Origin zone sizes
    dset_origin_sizes = h5group.create_dataset(
        "origin_sizes",
        or_sizes.shape,
        maxshape=or_sizes.shape,
        chunks=True,
        compression=3,
    )
    dset_origin_sizes.attrs["dim_names"] = ["zone_id", "dim_name__0"]
    dset_origin_sizes.attrs["coords_mode__zone_id"] = "trivial"
    dset_origin_sizes[:] = or_sizes

    # Create a network group
    nw_group = h5file.create_group("network")
    nw_group.attrs["content"] = "graph"
    nw_group.attrs["is_directed"] = True
    nw_group.attrs["allows_parallel"] = False

    # Add vertices
    vertices = nw_group.create_dataset(
        "_vertices",
        (N_origin + N_destination,),
        maxshape=(N_origin + N_destination,),
        chunks=True,
        compression=3,
        dtype=int,
    )
    vertices.attrs["dim_names"] = ["vertex_idx"]
    vertices.attrs["coords_mode__vertex_idx"] = "trivial"
    vertices[:] = np.arange(0, N_origin + N_destination, 1)
    vertices.attrs["node_type"] = [0] * N_origin + [1] * N_destination

    # Add edges. The network is a complete bipartite graph
    edges = nw_group.create_dataset(
        "_edges",
        (N_origin * N_destination, 2),
        maxshape=(N_origin * N_destination, 2),
        chunks=True,
        compression=3,
    )
    edges.attrs["dim_names"] = ["edge_idx", "vertex_idx"]
    edges.attrs["coords_mode__edge_idx"] = "trivial"
    edges.attrs["coords_mode__vertex_idx"] = "trivial"
    edges[:, :] = np.reshape(
        [
            [[i, j] for i in range(N_origin)]
            for j in range(N_origin, N_origin + N_destination)
        ],
        (N_origin * N_destination, 2),
    )

    # Edge weights
    edge_weights = nw_group.create_dataset(
        "_edge_weights",
        (N_origin * N_destination,),
        maxshape=(N_origin * N_destination,),
        chunks=True,
        compression=3,
    )
    edge_weights.attrs["dim_names"] = ["edge_idx"]
    edge_weights.attrs["coords_mode__edge_idx"] = "trivial"
    edge_weights[:] = torch.reshape(network, (N_origin * N_destination,))

    return or_sizes.to(device), training_data.to(device), network.to(device)
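
# A usage sketch (assumed driver code, not part of this module): get_HW_data
# writes its datasets into the given h5 group and returns the tensors on the
# requested device. The file name is hypothetical; the group name
# "HarrisWilson" matches the layout that load_from_dir expects when re-loading.
#
#   import h5py as h5
#
#   with h5.File("HarrisWilson_data.h5", "w") as h5file:
#       h5group = h5file.create_group("HarrisWilson")
#       or_sizes, training_data, network = get_HW_data(
#           cfg, h5file, h5group, device="cpu"
#       )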