Coverage for orchestr_ant_ion / pipeline / tracking / centroid.py: 19%

63 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-19 08:36 +0000

1"""Simple centroid-based tracking utilities.""" 

2 

3from __future__ import annotations 

4 

5from collections import deque 

6from typing import TYPE_CHECKING 

7 

8import numpy as np 

9 

10from orchestr_ant_ion.pipeline.constants import ( 

11 TRACKER_DEFAULT_MAX_AGE_SECONDS, 

12 TRACKER_DEFAULT_MAX_MATCH_DISTANCE, 

13 TRACKER_DEFAULT_MAX_TRAIL_POINTS, 

14) 

15 

16 

17if TYPE_CHECKING: 

18 from orchestr_ant_ion.pipeline.types import Track 

19 

20 

class SimpleCentroidTracker:
    """Very lightweight centroid tracker for a single class (e.g. persons).

    Each track keeps a bounded trail of centroids plus the timestamp at which
    it was last matched. On every ``update`` call, stale tracks are dropped,
    detections are greedily matched to the nearest existing track (closest
    pairs first, subject to a maximum distance), and any detection left
    unmatched starts a new track.
    """

    def __init__(
        self,
        *,
        max_age_s: float = TRACKER_DEFAULT_MAX_AGE_SECONDS,
        max_match_dist_norm: float = TRACKER_DEFAULT_MAX_MATCH_DISTANCE,
        max_trail_points: int = TRACKER_DEFAULT_MAX_TRAIL_POINTS,
    ) -> None:
        """Initialize tracker configuration and internal state.

        Args:
            max_age_s: A track is removed once ``now_ts - last_seen_ts``
                exceeds this value.
            max_match_dist_norm: Largest centroid-to-centroid distance at
                which a detection may be matched to an existing track.
            max_trail_points: Maximum number of centroids retained per track
                (used as the ``maxlen`` of each track's deque).
        """
        self._max_age_s = float(max_age_s)
        self._max_match_dist_norm = float(max_match_dist_norm)
        # Matching compares squared distances, so the threshold is squared
        # once here instead of taking a square root per pair.
        self._max_match_dist_norm_sq = self._max_match_dist_norm**2
        self._max_trail_points = int(max_trail_points)

        self._next_id = 1
        self._tracks: dict[int, Track] = {}

    def update(
        self, centroids_norm: list[tuple[float, float]], now_ts: float
    ) -> dict[int, Track]:
        """Update tracks with the latest normalized centroids.

        Args:
            centroids_norm: Detected centroids for the current frame.
            now_ts: Timestamp of the current frame, in the same units as
                ``max_age_s``.

        Returns:
            The tracker's internal ``track_id -> Track`` mapping (the live
            dict, not a copy).
        """
        self._expire_tracks(now_ts)

        if not centroids_norm:
            return self._tracks

        if not self._tracks:
            self._initialize_tracks(centroids_norm, now_ts)
            return self._tracks

        used_dets = self._associate_tracks(centroids_norm, now_ts)
        self._add_unmatched(centroids_norm, used_dets, now_ts)

        return self._tracks

    def _new_track(self, centroid: tuple[float, float], now_ts: float) -> Track:
        """Create, register and return a new track seeded with ``centroid``."""
        # ``Track`` is only imported under TYPE_CHECKING at module level, so
        # it must be imported here to exist at runtime; the function-scope
        # import also sidesteps any import cycle with the types module.
        from orchestr_ant_ion.pipeline.types import Track

        track = Track(
            track_id=self._next_id,
            points_norm=deque([centroid], maxlen=self._max_trail_points),
            last_seen_ts=now_ts,
        )
        self._tracks[self._next_id] = track
        self._next_id += 1
        return track

    def _expire_tracks(self, now_ts: float) -> None:
        """Remove tracks that haven't been seen recently."""
        # Collect ids first: the dict must not change size while iterating.
        expired_ids = [
            tid
            for tid, tr in self._tracks.items()
            if (now_ts - tr.last_seen_ts) > self._max_age_s
        ]
        for tid in expired_ids:
            self._tracks.pop(tid, None)

    def _initialize_tracks(
        self, centroids_norm: list[tuple[float, float]], now_ts: float
    ) -> None:
        """Create initial tracks from the first set of detections."""
        for centroid in centroids_norm:
            self._new_track(centroid, now_ts)

    def _associate_tracks(
        self, centroids_norm: list[tuple[float, float]], now_ts: float
    ) -> set[int]:
        """Associate detections with existing tracks using greedy matching.

        The pairwise squared-distance matrix is computed in one vectorized
        step (O(n*m) for n tracks and m detections). Pairs within the match
        threshold are then visited in ascending distance order, which adds a
        sort over at most n*m candidates; each track and each detection is
        used at most once.

        Returns:
            Indices into ``centroids_norm`` of the detections that were
            matched to an existing track.
        """
        track_ids = list(self._tracks)
        if len(track_ids) == 0 or len(centroids_norm) == 0:
            return set()

        prev_centroids = np.array(
            [self._tracks[tid].points_norm[-1] for tid in track_ids]
        )
        curr_centroids = np.array(centroids_norm)

        # Broadcast to shape (n_tracks, n_dets, dims), then reduce over dims.
        diff = prev_centroids[:, np.newaxis, :] - curr_centroids[np.newaxis, :, :]
        dist_sq_matrix = np.sum(diff**2, axis=2)

        # Candidate (track, detection) pairs within range, in row-major order.
        track_idx, det_idx = np.nonzero(
            dist_sq_matrix <= self._max_match_dist_norm_sq
        )
        # A stable sort keeps row-major order among equal distances, so ties
        # resolve deterministically (lowest track index, then detection).
        order = np.argsort(dist_sq_matrix[track_idx, det_idx], kind="stable")

        used_tracks: set[int] = set()
        used_dets: set[int] = set()

        # ``tolist`` yields plain Python ints for indexing and set membership.
        for ti, di in zip(track_idx[order].tolist(), det_idx[order].tolist()):
            if ti in used_tracks or di in used_dets:
                continue
            track = self._tracks[track_ids[ti]]
            track.points_norm.append(centroids_norm[di])
            track.last_seen_ts = now_ts
            used_tracks.add(ti)
            used_dets.add(di)

        return used_dets

    def _add_unmatched(
        self,
        centroids_norm: list[tuple[float, float]],
        used_dets: set[int],
        now_ts: float,
    ) -> None:
        """Create new tracks for unmatched detections."""
        for di, centroid in enumerate(centroids_norm):
            if di not in used_dets:
                self._new_track(centroid, now_ts)

137 self._next_id += 1