Coverage for tests/core/test_vector.py: 100%

71 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-12-05 17:26 +0000

1import math 

2import sys 

3from os.path import dirname as up 

4 

5import numpy as np 

6import pytest 

7import torch 

8from dantro._import_tools import import_module_from_path 

9 

# The repository root sits three directory levels above this test file; it is
# both put on sys.path and used as the base path for loading ``include.vector``.
_REPO_ROOT = up(up(up(__file__)))
sys.path.insert(0, _REPO_ROOT)

vec = import_module_from_path(mod_path=_REPO_ROOT, mod_str="include.vector")
Vector = vec.Vector
distance = vec.distance

# ----------------------------------------------------------------------------------------------------------------------
# ----------------------------------------------------------------------------------------------------------------------

# Coordinates of the shared test vectors; the last entry is the origin.
_VECTOR_COORDS = ((2, 3), (-1, -4), (0.4, -0.2), (0.0, 0))

# Two independent lists of Vector objects: ``non_zero_vectors`` holds fresh
# instances (not the same objects as in ``vectors``) for everything but the origin.
vectors = [Vector(x, y) for x, y in _VECTOR_COORDS]
non_zero_vectors = [Vector(x, y) for x, y in _VECTOR_COORDS[:-1]]
zero_vector = Vector(0, 0)

# NOTE: despite its name this is (1, 1), whose norm is sqrt(2), not 1.
unit_vector = Vector(1, 1)

23 

24 

def test_vector():
    """Tests the Vector class.

    Covers equality, vector addition/subtraction, the ``*`` product between two
    vectors (expected to yield a scalar whose square root is the norm), negation,
    ``abs`` as the norm, the in-place ``scalar_mul`` and ``normalise`` methods,
    and ``within_space`` with list, numpy, torch and Vector bounds.
    """
    import copy

    for v in vectors:
        assert v == v
        assert v != Vector(-1, -1)
        assert v + zero_vector == v

        # v * v is expected to be the squared norm. sqrt() and abs() may take
        # different floating-point paths, so compare with a tolerance.
        assert math.sqrt(v * v) == pytest.approx(abs(v))
        assert v * zero_vector == 0

        assert -(-v) == v
        assert abs(-v) == abs(v)

    # unit_vector is (1, 1), so its norm is sqrt(2) rather than 1.
    assert abs(unit_vector) == pytest.approx(math.sqrt(2))
    assert abs(zero_vector) == 0

    assert vectors[0] + vectors[1] == Vector(1, -1)
    assert vectors[0] - vectors[1] == Vector(3, 7)

    # scalar_mul modifies the vector in place.
    q = Vector(0.0, 0.0)
    q.scalar_mul(2.3)
    assert q == zero_vector

    q = Vector(1.0, 1.0)
    q.scalar_mul(-math.pi)
    assert q == Vector(-math.pi, -math.pi)

    for original in non_zero_vectors:
        # normalise() also works in place: operate on a copy so the shared
        # module-level fixture is not altered for other tests.
        v = copy.deepcopy(original)

        v.normalise()
        assert abs(v) == pytest.approx(1, rel=1e-8)

        v.normalise(norm=2.5)
        assert abs(v) == pytest.approx(2.5, rel=1e-8)

    # Bounds given as [[x_min, x_max], [y_min, y_max]] in several container types
    space_list = [[-10, 10], [-10, 10]]
    space_np = np.array(space_list)
    space_torch = torch.from_numpy(space_np)
    space_slice = [[-20, -10], [-10, -10]]
    space_small = [[-20, -10], [-10, -9]]
    for v in vectors:
        assert v.within_space(space_list)
        assert v.within_space(space_np)
        assert v.within_space(space_torch)
        assert not v.within_space(space_slice)
        assert not v.within_space(space_small)

    # A Vector is also accepted as the space argument.
    space = Vector(3, 5)
    assert zero_vector.within_space(space)

74 

75 

def test_distances():
    """Tests the distance function.

    Checks the non-periodic distance against the norm, including the
    ``as_tensor=True`` variant. It also checks periodic distances: in a space
    large enough to make wrapping irrelevant, on corners of a small space that
    all identify with each other, and for a specific pair of points whose
    shortest path wraps around the torus.
    """
    # Non-periodic case
    for v in vectors:
        assert distance(v, zero_vector) == pytest.approx(abs(v))
        assert distance(zero_vector, v) == pytest.approx(abs(v))

        d_tensor = distance(v, zero_vector, as_tensor=True)
        assert torch.is_tensor(d_tensor)
        # Compare as a Python float with a tolerance: the tensor's dtype is not
        # fixed by this test (it may be lower precision than the float norm).
        assert d_tensor.item() == pytest.approx(abs(v))

    # Periodic case: the space is large enough that no shortest path wraps
    large_space = [[-100, 100], [-100, 100]]
    for v in vectors:
        assert distance(
            v, zero_vector, periodic=True, space=large_space
        ) == pytest.approx(abs(v))
        assert distance(
            zero_vector, v, periodic=True, space=large_space
        ) == pytest.approx(abs(v))

    # Boundary points all 0 distance from each other
    small_space = [[-2, 2], [-2, 2]]
    q, p, r, t = Vector(2, 2), Vector(-2, -2), Vector(-2, 2), Vector(2, -2)
    corners = [p, q, r, t]
    for v in corners:
        for w in corners:
            assert distance(v, w, periodic=True, space=small_space) == 0
        # The origin is already the closest image of each corner, so the
        # periodic and non-periodic distances agree.
        assert distance(
            v, zero_vector, periodic=True, space=small_space
        ) == pytest.approx(distance(v, zero_vector))

    # Test specific value on torus: each coordinate differs by 3, which wraps
    # to 4 - 3 = 1 in a space of width 4, giving sqrt(1 + 1).
    v, w = Vector(-1.5, -1.5), Vector(1.5, 1.5)
    assert distance(v, w, periodic=True, space=small_space) == pytest.approx(
        math.sqrt(2)
    )