Coverage for include/vector.py: 98%

43 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-12-05 17:26 +0000

1import math 

2from typing import Sequence, Union 

3 

4import numpy as np 

5import torch 

6 

7 

8# --- The vector class ----------------------------------------------------------------------------------------------- 

class Vector:
    """A mutable 2D vector with ``x`` and ``y`` coordinates.

    ``+``, ``-``, ``%`` and unary ``-`` act component-wise and return new
    vectors; ``*`` between two vectors is the dot product. ``scalar_mul`` and
    ``normalise`` modify the vector in place and return ``None``.
    """

    # Instances are mutable and compare by value, so they must not be hashable.
    # (Defining ``__eq__`` already implies this; stated explicitly for readers.)
    __hash__ = None

    def __init__(self, x: float, y: float):
        """
        :param x: the x coordinate
        :param y: the y coordinate
        """

        self.x = x
        self.y = y

    # Magic methods ....................................................................................................
    def __abs__(self) -> float:
        """Return the Euclidean length of the vector."""
        return math.sqrt(pow(self.x, 2) + pow(self.y, 2))

    def __add__(self, other: "Vector") -> "Vector":
        """Return the component-wise sum as a new vector."""
        return Vector(self.x + other.x, self.y + other.y)

    def __eq__(self, other):
        """Return True if both coordinates are equal.

        Any object exposing ``x`` and ``y`` is compared by value. For objects
        without them, ``NotImplemented`` is returned so Python falls back to its
        default handling (``v == None`` is False, ``v in mixed_list`` works)
        instead of raising AttributeError.
        """
        try:
            return self.x == other.x and self.y == other.y
        except AttributeError:
            return NotImplemented

    def __mul__(self, other: "Vector"):
        """Return the dot product ``self.x * other.x + self.y * other.y``."""
        return self.x * other.x + self.y * other.y

    def __mod__(self, other: "Vector") -> "Vector":
        """Return the component-wise remainder ``(self.x % other.x, self.y % other.y)``."""
        return Vector(self.x % other.x, self.y % other.y)

    def __neg__(self) -> "Vector":
        """Return a new vector with both coordinates negated."""
        return Vector(-self.x, -self.y)

    def __repr__(self) -> str:
        return f"({self.x}, {self.y})"

    def __sub__(self, other: "Vector") -> "Vector":
        """Return the component-wise difference as a new vector."""
        return Vector(self.x - other.x, self.y - other.y)

    def scalar_mul(self, l: float) -> None:
        """Multiply both coordinates by ``l`` in place."""
        self.x *= l
        self.y *= l

    def normalise(self, *, norm: float = 1) -> None:
        """Rescale the vector in place so that its length equals ``norm``.

        :param norm: the target length (default 1)
        :raises ZeroDivisionError: if the vector has zero length, since its
            direction is undefined
        """
        length = abs(self)
        if length == 0:
            raise ZeroDivisionError("cannot normalise a zero-length vector")
        self.scalar_mul(norm / length)

    def within_space(self, space: Union[Sequence, "Vector"]) -> bool:
        """Return whether the vector lies inside a rectangular domain (bounds inclusive).

        :param space: either a Vector, meaning the domain ``[0, space.x] x [0, space.y]``,
            or a sequence ``[[x_min, x_max], [y_min, y_max]]`` of per-axis bounds
        """
        if isinstance(space, Vector):
            return (0 <= self.x <= space.x) and (0 <= self.y <= space.y)
        else:
            return (space[0][0] <= self.x <= space[0][1]) and (
                space[1][0] <= self.y <= space[1][1]
            )

60 

61 

62def distance( 

63 v: Vector, 

64 w: Vector, 

65 *, 

66 periodic: bool = False, 

67 space: Union[Vector, Sequence] = None, 

68 as_tensor: bool = True, 

69): 

70 """Returns the distance between two vectors v and w. If the space is periodic, the distance is 

71 calculated accordingly.""" 

72 

73 if not periodic: 

74 return ( 

75 abs(v - w) if not as_tensor else torch.tensor(abs(v - w), dtype=torch.float) 

76 ) 

77 else: 

78 d = v - w 

79 

80 if isinstance(space, Vector): 

81 L_x, L_y = abs(space.x), abs(space.y) 

82 else: 

83 L_x, L_y = abs(np.diff(space[0])), abs(np.diff(space[1])) 

84 

85 dist = math.sqrt( 

86 pow(min(abs(d.x), L_x - abs(d.x)), 2) 

87 + pow(min(abs(d.y), L_y - abs(d.y)), 2) 

88 ) 

89 

90 return dist if not as_tensor else torch.tensor(dist, dtype=torch.float)