Coverage for include/vector.py: 98%
43 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-12-05 17:26 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-12-05 17:26 +0000
1import math
2from typing import Sequence, Union
4import numpy as np
5import torch
# --- The vector class -----------------------------------------------------------------------------------------------
class Vector:
    """A mutable two-dimensional vector with components ``x`` and ``y``.

    ``+``, binary ``-``, unary ``-`` and ``%`` act component-wise and return new vectors;
    ``*`` between two vectors is the dot product; ``abs(v)`` is the Euclidean norm.
    ``scalar_mul`` and ``normalise`` modify the vector in place.
    """

    def __init__(self, x: float, y: float):
        """
        :param x: the x coordinate
        :param y: the y coordinate
        """

        self.x = x
        self.y = y

    # Magic methods ....................................................................................................
    def __abs__(self):
        """Return the Euclidean norm of the vector."""
        # math.hypot is more accurate than sqrt(x**2 + y**2) and avoids overflow/underflow
        # in the intermediate squares.
        return math.hypot(self.x, self.y)

    def __add__(self, other):
        """Component-wise sum."""
        return Vector(self.x + other.x, self.y + other.y)

    def __eq__(self, other):
        """Component-wise equality with any object exposing ``x`` and ``y``.

        Objects without those attributes (``None``, strings, ...) yield ``NotImplemented``,
        so ``==`` falls back to ``False`` instead of raising AttributeError.
        Note: defining ``__eq__`` without ``__hash__`` leaves Vector unhashable, which suits
        a mutable type.
        """
        try:
            return self.x == other.x and self.y == other.y
        except AttributeError:
            return NotImplemented

    def __mul__(self, other):
        """Dot product with another vector (for scaling by a number use ``scalar_mul``)."""
        return self.x * other.x + self.y * other.y

    def __mod__(self, other):
        """Component-wise modulo, e.g. wrapping a position into ``[0, L_x) x [0, L_y)``."""
        return Vector(self.x % other.x, self.y % other.y)

    def __neg__(self):
        """Component-wise negation."""
        return Vector(-self.x, -self.y)

    def __repr__(self) -> str:
        return f"({self.x}, {self.y})"

    def __sub__(self, other):
        """Component-wise difference."""
        return Vector(self.x - other.x, self.y - other.y)

    def scalar_mul(self, l: float):
        """Scale both components in place by the factor ``l``."""
        self.x *= l
        self.y *= l

    def normalise(self, *, norm: float = 1):
        """Rescale the vector in place so that its length becomes ``norm``.

        :param norm: the target length (default 1)
        :raises ZeroDivisionError: if the vector has zero length (its direction is undefined);
            the vector is left unchanged in that case
        """
        length = abs(self)
        if length == 0:
            raise ZeroDivisionError("cannot normalise a zero-length Vector")
        self.scalar_mul(norm / length)

    def within_space(self, space: Union[Sequence, "Vector"]) -> bool:
        """Check whether the vector lies in an axis-aligned rectangular domain (boundaries included).

        :param space: either a Vector ``(L_x, L_y)``, meaning the box ``[0, L_x] x [0, L_y]``,
            or a sequence ``[[x_min, x_max], [y_min, y_max]]``
        :return: True if both coordinates lie within their respective bounds
        """
        if isinstance(space, Vector):
            return (0 <= self.x <= space.x) and (0 <= self.y <= space.y)
        else:
            return (space[0][0] <= self.x <= space[0][1]) and (
                space[1][0] <= self.y <= space[1][1]
            )
def _periodic_offset(delta, length):
    """Return the shortest separation along one periodic axis of extent ``length``.

    ``|delta|`` is first wrapped into ``[0, length)`` so displacements longer than one period
    (e.g. from points lying outside the domain) are still handled; for in-domain points this
    is a no-op. A non-positive ``length`` disables periodicity along this axis.
    """
    delta = abs(delta)
    if length <= 0:
        return delta
    delta %= length
    return min(delta, length - delta)


def distance(
    v: Vector,
    w: Vector,
    *,
    periodic: bool = False,
    space: Union[Vector, Sequence, None] = None,
    as_tensor: bool = True,
) -> Union[float, torch.Tensor]:
    """Return the Euclidean distance between two vectors v and w.

    If the space is periodic, the distance is measured across the domain boundaries where
    that is shorter (minimum-image convention).

    :param v: first vector
    :param w: second vector
    :param periodic: if True, treat the domain as periodic in both directions
    :param space: extent of the domain; required when ``periodic`` is True. Either a Vector
        ``(L_x, L_y)`` (absolute values are used as the periods) or a sequence
        ``[[x_min, x_max], [y_min, y_max]]``
    :param as_tensor: if True (default), return ``torch.tensor(dist, dtype=torch.float)``
        instead of a plain number
    :return: the (possibly periodic) distance between v and w
    :raises TypeError: if ``periodic`` is True but no ``space`` is given
    """
    if not periodic:
        dist = abs(v - w)
    else:
        if space is None:
            raise TypeError("distance: 'space' must be given when periodic=True")

        if isinstance(space, Vector):
            L_x, L_y = abs(space.x), abs(space.y)
        else:
            # Same [[x_min, x_max], [y_min, y_max]] layout Vector.within_space reads. Plain
            # subtraction yields scalars; np.diff would give length-1 arrays that later had to
            # be implicitly converted to scalars (deprecated in NumPy >= 1.25).
            L_x = abs(space[0][1] - space[0][0])
            L_y = abs(space[1][1] - space[1][0])

        d = v - w
        dist = math.hypot(_periodic_offset(d.x, L_x), _periodic_offset(d.y, L_y))

    return torch.tensor(dist, dtype=torch.float) if as_tensor else dist