Coverage for tests/core/test_vector.py: 100% (71 statements)
Created by coverage.py v7.6.1 at 2024-12-05 17:26 +0000
import copy
import math
import sys
from os.path import dirname as up

import numpy as np
import pytest
import torch
from dantro._import_tools import import_module_from_path
# The repository root sits three directory levels above this file; make it importable and load
# the module under test (include/vector) from it.
_REPO_ROOT = up(up(up(__file__)))
sys.path.insert(0, _REPO_ROOT)

vec = import_module_from_path(mod_path=_REPO_ROOT, mod_str="include.vector")
Vector, distance = vec.Vector, vec.distance

# ----------------------------------------------------------------------------------------------------------------------
# Shared fixtures
# ----------------------------------------------------------------------------------------------------------------------

vectors = [Vector(2, 3), Vector(-1, -4), Vector(0.4, -0.2), Vector(0.0, 0)]

# Deliberately distinct instances rather than a slice of ``vectors``: tests that modify these
# in place must not affect the objects in ``vectors``.
non_zero_vectors = [Vector(2, 3), Vector(-1, -4), Vector(0.4, -0.2)]

zero_vector = Vector(0, 0)

# NB: despite the name this is not of unit norm -- its norm is sqrt(2) (asserted in test_vector).
unit_vector = Vector(1, 1)


def test_vector():
    """Tests the Vector class.

    Covers equality, addition/subtraction, the dot product, ``abs`` (norm), negation, in-place
    scaling via ``scalar_mul``, in-place normalisation via ``normalise`` and the ``within_space``
    check against list, numpy, torch and Vector bounds.

    ``normalise`` modifies its vector in place. It is therefore applied to deep copies of the
    module-level ``non_zero_vectors`` fixture, so running this test leaves the shared fixtures
    unchanged for other tests.
    """

    # Identities that must hold for every fixture vector, including the zero vector
    for v in vectors:
        assert v == v
        assert v != Vector(-1, -1)
        assert v + zero_vector == v

        # The norm is consistent with the dot product of a vector with itself
        assert math.sqrt(v * v) == abs(v)
        assert v * zero_vector == 0

        assert -(-v) == v
        assert (abs(-v)) == abs(v)

    assert abs(unit_vector) == math.sqrt(2)
    assert abs(zero_vector) == 0

    assert vectors[0] + vectors[1] == Vector(1, -1)
    assert vectors[0] - vectors[1] == Vector(3, 7)

    # scalar_mul scales the vector in place (its return value is not used)
    q = Vector(0.0, 0.0)
    q.scalar_mul(2.3)
    assert q == zero_vector

    q = Vector(1.0, 1.0)
    q.scalar_mul(-math.pi)
    assert q == Vector(-math.pi, -math.pi)

    # normalise() works in place, so operate on copies to keep the shared fixtures intact
    for v in copy.deepcopy(non_zero_vectors):
        v.normalise()
        assert abs(v) == pytest.approx(1, 1e-8)

        v.normalise(norm=2.5)
        assert abs(v) == pytest.approx(2.5, 1e-8)

    # The same bounds given as a nested list, a numpy array and a torch tensor
    space_list = [[-10, 10], [-10, 10]]
    space_np = np.array(space_list)
    space_torch = torch.from_numpy(space_np)
    # Bounds that exclude every fixture vector (the first has a degenerate second range)
    space_slice = [[-20, -10], [-10, -10]]
    space_small = [[-20, -10], [-10, -9]]
    for v in vectors:
        assert v.within_space(space_list)
        assert v.within_space(space_np)
        assert v.within_space(space_torch)
        assert not v.within_space(space_slice)
        assert not v.within_space(space_small)

    # A Vector is also accepted as the space argument
    space = Vector(3, 5)
    assert zero_vector.within_space(space)


def test_distances():
    """Tests the ``distance`` function in the non-periodic and periodic cases.

    Checks the distance to the origin against the vector norm, the ``as_tensor`` return type,
    wrapping on a small periodic space whose corners coincide, and one explicit value on that
    torus.
    """

    # Non-periodic case: the distance to the origin equals the norm, in either argument order
    for v in vectors:
        assert distance(v, zero_vector) == abs(v)
        assert distance(zero_vector, v) == abs(v)

        # With as_tensor=True the result must be a torch tensor holding the same value. Check the
        # type first so a wrong return type fails clearly. Compare with a tolerance because the
        # tensor dtype (possibly float32 -- not visible here) is chosen by ``distance``.
        d_tensor = distance(v, zero_vector, as_tensor=True)
        assert torch.is_tensor(d_tensor)
        assert float(d_tensor) == pytest.approx(abs(v))

    # Periodic case: the space is much larger than any fixture vector, so no wrapping occurs and
    # the periodic distance equals the plain norm
    large_space = [[-100, 100], [-100, 100]]
    for v in vectors:
        assert distance(v, zero_vector, periodic=True, space=large_space) == abs(v)
        assert distance(zero_vector, v, periodic=True, space=large_space) == abs(v)

    # The four corners of a periodic space are identified, so they are all 0 apart
    small_space = [[-2, 2], [-2, 2]]
    corners = [Vector(-2, -2), Vector(2, 2), Vector(-2, 2), Vector(2, -2)]
    for v in corners:
        for w in corners:
            assert distance(v, w, periodic=True, space=small_space) == 0

        # Each corner is half a period (2 of 4) from the origin along both axes, so wrapping
        # cannot shorten the path: the periodic and plain distances agree. The check does not
        # depend on ``w``, so it runs once per corner.
        assert distance(v, zero_vector, periodic=True, space=small_space) == distance(v, zero_vector)

    # Specific value on the torus: the points are 3 apart along each axis directly, but only
    # 4 - 3 = 1 apart when wrapping around, giving sqrt(1**2 + 1**2)
    v, w = Vector(-1.5, -1.5), Vector(1.5, 1.5)
    assert distance(v, w, periodic=True, space=small_space) == math.sqrt(2)