Coverage for orchestr_ant_ion / pipeline / tracking / centroid.py: 19%
63 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-19 08:36 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-19 08:36 +0000
1"""Simple centroid-based tracking utilities."""
3from __future__ import annotations
5from collections import deque
6from typing import TYPE_CHECKING
8import numpy as np
10from orchestr_ant_ion.pipeline.constants import (
11 TRACKER_DEFAULT_MAX_AGE_SECONDS,
12 TRACKER_DEFAULT_MAX_MATCH_DISTANCE,
13 TRACKER_DEFAULT_MAX_TRAIL_POINTS,
14)
17if TYPE_CHECKING:
18 from orchestr_ant_ion.pipeline.types import Track
class SimpleCentroidTracker:
    """Very lightweight centroid tracker for a single class (e.g. persons).

    Every call to :meth:`update` drops stale tracks, greedily matches the new
    centroids to the surviving tracks by squared Euclidean distance, and
    opens a fresh track for each centroid that was left unmatched.
    """

    def __init__(
        self,
        *,
        max_age_s: float = TRACKER_DEFAULT_MAX_AGE_SECONDS,
        max_match_dist_norm: float = TRACKER_DEFAULT_MAX_MATCH_DISTANCE,
        max_trail_points: int = TRACKER_DEFAULT_MAX_TRAIL_POINTS,
    ) -> None:
        """Initialize tracker configuration and internal state.

        Args:
            max_age_s: A track is removed once ``now_ts - last_seen_ts``
                exceeds this value (same units as the ``now_ts`` passed to
                :meth:`update`).
            max_match_dist_norm: Largest distance between a track's latest
                point and a detection for the pair to be eligible for
                matching (inclusive).
            max_trail_points: ``maxlen`` of each track's point deque.
        """
        self._max_age_s = float(max_age_s)
        self._max_match_dist_norm = float(max_match_dist_norm)
        # Matching compares squared distances, so square the threshold once.
        self._max_match_dist_norm_sq = self._max_match_dist_norm**2
        self._max_trail_points = int(max_trail_points)

        self._next_id = 1
        self._tracks: dict[int, Track] = {}

    def update(
        self, centroids_norm: list[tuple[float, float]], now_ts: float
    ) -> dict[int, Track]:
        """Update tracks with the latest normalized centroids.

        Args:
            centroids_norm: Centroids detected in the current frame.
            now_ts: Timestamp of the current frame.

        Returns:
            The tracker's internal ``track_id -> Track`` mapping (the live
            dict, not a copy).
        """
        self._expire_tracks(now_ts)

        if not centroids_norm:
            return self._tracks

        if not self._tracks:
            self._initialize_tracks(centroids_norm, now_ts)
            return self._tracks

        used_dets = self._associate_tracks(centroids_norm, now_ts)
        self._add_unmatched(centroids_norm, used_dets, now_ts)

        return self._tracks

    def _expire_tracks(self, now_ts: float) -> None:
        """Remove tracks that haven't been seen recently."""
        # Collect ids first: the dict must not change size while iterating.
        expired_ids = [
            tid
            for tid, tr in self._tracks.items()
            if (now_ts - tr.last_seen_ts) > self._max_age_s
        ]
        for tid in expired_ids:
            self._tracks.pop(tid, None)

    def _start_track(self, centroid: tuple[float, float], now_ts: float) -> None:
        """Register a new track seeded with ``centroid`` under the next id."""
        # ``Track`` is imported at module level only under ``TYPE_CHECKING``,
        # so the name does not exist at runtime. Import it here, where it is
        # actually instantiated; repeat imports are served from
        # ``sys.modules`` and a lazy import cannot create an import cycle at
        # module load time.
        from orchestr_ant_ion.pipeline.types import Track  # noqa: PLC0415

        track_id = self._next_id
        self._tracks[track_id] = Track(
            track_id=track_id,
            points_norm=deque([centroid], maxlen=self._max_trail_points),
            last_seen_ts=now_ts,
        )
        self._next_id = track_id + 1

    def _initialize_tracks(
        self, centroids_norm: list[tuple[float, float]], now_ts: float
    ) -> None:
        """Create initial tracks from the first set of detections."""
        for centroid in centroids_norm:
            self._start_track(centroid, now_ts)

    def _associate_tracks(
        self, centroids_norm: list[tuple[float, float]], now_ts: float
    ) -> set[int]:
        """Associate detections with existing tracks using greedy matching.

        The pairwise squared-distance matrix is computed in one vectorized
        step (n * m entries for n tracks and m detections). Pairs within the
        match threshold are then visited from closest to farthest, and a pair
        is accepted only if neither its track nor its detection has already
        been used. Sorting the candidate pairs adds a logarithmic factor on
        top of the n * m distance computation.

        Matched tracks get the detection appended to ``points_norm`` and
        their ``last_seen_ts`` set to ``now_ts``.

        Returns:
            Indices into ``centroids_norm`` of the detections that were
            matched to an existing track.
        """
        track_ids = list(self._tracks)
        if not track_ids or len(centroids_norm) == 0:
            return set()

        prev_points = np.asarray(
            [self._tracks[tid].points_norm[-1] for tid in track_ids],
            dtype=float,
        )
        curr_points = np.asarray(centroids_norm, dtype=float)

        # Broadcast to every (track, detection) pair and reduce over the
        # coordinate axis: result has one row per track, one column per
        # detection.
        delta = prev_points[:, np.newaxis, :] - curr_points[np.newaxis, :, :]
        dist_sq = np.sum(delta**2, axis=2)

        # ``np.nonzero`` yields in-range pairs in row-major order; a stable
        # sort by distance therefore breaks ties by (track, detection) index.
        track_idx, det_idx = np.nonzero(dist_sq <= self._max_match_dist_norm_sq)
        order = np.argsort(dist_sq[track_idx, det_idx], kind="stable")

        used_tracks: set[int] = set()
        used_dets: set[int] = set()
        for ti, di in zip(track_idx[order].tolist(), det_idx[order].tolist()):
            if ti in used_tracks or di in used_dets:
                continue
            track = self._tracks[track_ids[ti]]
            # Store the caller's original centroid, not the numpy row.
            track.points_norm.append(centroids_norm[di])
            track.last_seen_ts = now_ts
            used_tracks.add(ti)
            used_dets.add(di)

        return used_dets

    def _add_unmatched(
        self,
        centroids_norm: list[tuple[float, float]],
        used_dets: set[int],
        now_ts: float,
    ) -> None:
        """Create new tracks for unmatched detections."""
        for di, centroid in enumerate(centroids_norm):
            if di not in used_dets:
                self._start_track(centroid, now_ts)