Coverage for orchestr_ant_ion / yolo / core / postprocess.py: 0%
170 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-19 08:36 +0000
1"""Post-processing utilities for YOLO model outputs."""
3from __future__ import annotations
5from dataclasses import dataclass
6from typing import TYPE_CHECKING
8import numpy as np
9from loguru import logger
11from orchestr_ant_ion.pipeline.constants import POSTPROCESS_DEFAULT_CONF_THRESHOLD
12from orchestr_ant_ion.yolo.core.constants import CLASS_NAMES
15if TYPE_CHECKING:
16 from collections.abc import Sequence
@dataclass
class DecodeConfig:
    """Configuration for post-processing model outputs."""

    # Scale factor applied to the image during preprocessing; boxes are
    # divided by this when mapped back to original-image coordinates.
    scale: float
    # Horizontal padding (pixels) added during preprocessing.
    pad_x: int
    # Vertical padding (pixels) added during preprocessing.
    pad_y: int
    # Model input dimensions as (height, width) — unpacked in that order
    # by the box-preparation helpers.
    input_size: tuple[int, int]
    # Minimum confidence for a detection to be kept.
    conf_threshold: float = POSTPROCESS_DEFAULT_CONF_THRESHOLD
    # When True, log the first few decoded boxes/scores for debugging.
    debug_boxes: bool = False
31def _softmax(scores: np.ndarray) -> np.ndarray:
32 """Apply softmax normalization to a score array.
34 Args:
35 scores: Raw score array (any shape).
37 Returns:
38 Normalized probability array with same shape.
39 """
40 shifted = scores - np.max(scores)
41 exp_scores = np.exp(shifted)
42 return exp_scores / np.sum(exp_scores)
45def _xywh_to_xyxy(boxes: np.ndarray) -> np.ndarray:
46 """Convert bounding boxes from center format to corner format.
48 Args:
49 boxes: Array of shape (N, 4) with [cx, cy, w, h] format.
51 Returns:
52 Array of shape (N, 4) with [x1, y1, x2, y2] format.
53 """
54 x, y, w, h = boxes.T
55 x1 = x - w / 2
56 y1 = y - h / 2
57 x2 = x + w / 2
58 y2 = y + h / 2
59 return np.stack([x1, y1, x2, y2], axis=1)
62def _sigmoid(values: np.ndarray) -> np.ndarray:
63 """Apply sigmoid activation to values.
65 Args:
66 values: Input array (any shape).
68 Returns:
69 Sigmoid-activated array with same shape.
70 """
71 return 1.0 / (1.0 + np.exp(-values))
74def _squeeze_to_2d(arr: np.ndarray) -> np.ndarray:
75 """Reduce array dimensions to 2D by removing singleton dimensions.
77 Args:
78 arr: Input array with potential singleton dimensions.
80 Returns:
81 2D array with singleton dimensions removed.
82 """
83 data = np.asarray(arr)
84 data = np.squeeze(data)
85 if data.ndim == 3 and data.shape[0] == 1:
86 data = data[0]
87 return data
90def _looks_like_xywh(boxes: np.ndarray) -> bool:
91 """Detect if boxes are in center format (xywh) vs corner format (xyxy).
93 Heuristic: if many boxes have x2 < x1 or y2 < y1, they're likely xywh.
95 Args:
96 boxes: Array of shape (N, 4).
98 Returns:
99 True if boxes appear to be in center format.
100 """
101 if boxes.size == 0:
102 return False
103 x2_lt_x1 = np.mean(boxes[:, 2] < boxes[:, 0])
104 y2_lt_y1 = np.mean(boxes[:, 3] < boxes[:, 1])
105 return bool((x2_lt_x1 > 0.3) or (y2_lt_y1 > 0.3))
def _decode_classification(output: np.ndarray) -> dict | None:
    """Decode classification model output to class prediction.

    Args:
        output: Model output array.

    Returns:
        Dictionary with class_id, score, and label, or None if not classification.
    """
    is_vector = output.ndim == 1
    is_single_row = output.ndim == 2 and output.shape[0] == 1
    if not (is_vector or is_single_row):
        return None
    scores = (output if is_vector else output[0]).astype(np.float32)
    # Values outside [0, 1] look like raw logits; normalize to probabilities.
    if np.any(scores < 0) or np.any(scores > 1):
        scores = _softmax(scores)
    class_id = int(np.argmax(scores))
    if class_id < len(CLASS_NAMES):
        label = CLASS_NAMES[class_id]
    else:
        label = str(class_id)
    return {
        "class_id": class_id,
        "score": float(scores[class_id]),
        "label": label,
    }
def _log_debug_boxes(
    boxes: np.ndarray,
    scores: np.ndarray,
    class_ids: np.ndarray,
) -> None:
    """Log the first 3 decoded boxes and their scores for debugging.

    Args:
        boxes: Bounding box array.
        scores: Confidence score array.
        class_ids: Class ID array.
    """
    box_preview = boxes[:3].round(2).tolist()
    logger.info("Decoded boxes (first 3): {}", box_preview)
    score_class_pairs = list(
        zip(
            scores[:3].round(3).tolist(),
            class_ids[:3].tolist(),
            strict=False,
        )
    )
    logger.info("Decoded scores/classes (first 3): {}", score_class_pairs)
161def _unscale_and_collect(
162 boxes: np.ndarray,
163 scores: np.ndarray,
164 class_ids: np.ndarray,
165 scale: float,
166 pad_x: int,
167 pad_y: int,
168 conf_threshold: float,
169) -> list[dict]:
170 """Unscale bounding boxes and collect detections above the threshold.
172 Args:
173 boxes: Bounding boxes in input image coordinates.
174 scores: Confidence scores for each detection.
175 class_ids: Class IDs for each detection.
176 scale: Scale factor applied during preprocessing.
177 pad_x: Horizontal padding applied during preprocessing.
178 pad_y: Vertical padding applied during preprocessing.
179 conf_threshold: Minimum confidence threshold.
181 Returns:
182 List of detection dictionaries with bbox, score, and class_id.
183 """
184 detections: list[dict] = []
185 for box, score, class_id in zip(boxes, scores, class_ids, strict=False):
186 if score < conf_threshold:
187 continue
188 x1, y1, x2, y2 = box
189 x1 = (x1 - pad_x) / scale
190 y1 = (y1 - pad_y) / scale
191 x2 = (x2 - pad_x) / scale
192 y2 = (y2 - pad_y) / scale
193 detections.append(
194 {
195 "bbox": [int(x1), int(y1), int(x2), int(y2)],
196 "score": float(score),
197 "class_id": int(class_id),
198 }
199 )
200 return detections
def _prepare_boxes(
    boxes: np.ndarray,
    input_size: tuple[int, int],
    convert_xywh: bool = True,
) -> np.ndarray:
    """Scale boxes to pixel coordinates and optionally convert format.

    Args:
        boxes: Raw bounding boxes.
        input_size: (height, width) of input image.
        convert_xywh: Whether to convert from xywh to xyxy format.

    Returns:
        Processed bounding boxes in pixel coordinates.
    """
    height, width = input_size
    # Values no larger than 1.5 look normalized; rescale to pixel units.
    if np.max(boxes) <= 1.5:
        pixel_scale = np.array([width, height, width, height], dtype=np.float32)
        boxes = boxes * pixel_scale
    if convert_xywh and _looks_like_xywh(boxes):
        boxes = _xywh_to_xyxy(boxes)
    return boxes
def _decode_triplet_outputs(
    outputs: Sequence[np.ndarray],
    input_size: tuple[int, int],
    scale: float,
    pad_x: int,
    pad_y: int,
    conf_threshold: float,
    *,
    debug_boxes: bool,
) -> list[dict] | None:
    """Decode triplet output format (boxes, scores, class_ids).

    Args:
        outputs: Model output tensors.
        input_size: Input image dimensions.
        scale: Preprocessing scale factor.
        pad_x: Horizontal padding.
        pad_y: Vertical padding.
        conf_threshold: Confidence threshold.
        debug_boxes: Whether to log debug info.

    Returns:
        List of detections or None if not triplet format.
    """
    if len(outputs) < 3:
        return None

    boxes, scores, class_ids = (_squeeze_to_2d(out) for out in outputs[:3])

    # First tensor must be an (N, 4) box array for this layout to apply.
    if boxes.ndim != 2 or boxes.shape[-1] != 4:
        return None

    scores = scores.reshape(-1) if scores.ndim > 1 else scores
    class_ids = class_ids.reshape(-1) if class_ids.ndim > 1 else class_ids

    boxes = _prepare_boxes(boxes, input_size, convert_xywh=True)

    if debug_boxes:
        _log_debug_boxes(boxes, scores, class_ids)

    return _unscale_and_collect(
        boxes, scores, class_ids, scale, pad_x, pad_y, conf_threshold
    )
def _decode_pair_outputs(
    outputs: Sequence[np.ndarray],
    input_size: tuple[int, int],
    scale: float,
    pad_x: int,
    pad_y: int,
    conf_threshold: float,
    *,
    debug_boxes: bool,
) -> list[dict] | None:
    """Decode pair output format (scores, boxes).

    Fix: class scores were previously normalized with a softmax over the
    *entire* (N, C) matrix, so each detection's confidence depended on every
    other detection and shrank toward zero as N grew. Softmax is now applied
    per row, so each detection's class probabilities sum to 1 independently.
    (The argmax class IDs were unaffected; only the confidences change.)

    Args:
        outputs: Model output tensors.
        input_size: Input image dimensions.
        scale: Preprocessing scale factor.
        pad_x: Horizontal padding.
        pad_y: Vertical padding.
        conf_threshold: Confidence threshold.
        debug_boxes: Whether to log debug info.

    Returns:
        List of detections or None if not pair format.
    """
    if len(outputs) < 2:
        return None

    scores = _squeeze_to_2d(outputs[0])
    boxes = _squeeze_to_2d(outputs[1])

    if not (scores.ndim == 2 and boxes.ndim == 2 and boxes.shape[1] == 4):
        return None

    height, width = input_size
    if scores.shape[1] > 1:
        # Row-wise softmax, max-shifted per row for numerical stability.
        shifted = scores - np.max(scores, axis=1, keepdims=True)
        exp_scores = np.exp(shifted)
        probs = exp_scores / np.sum(exp_scores, axis=1, keepdims=True)
    else:
        probs = scores
    class_ids = np.argmax(probs, axis=1)
    confs = probs[np.arange(len(class_ids)), class_ids]

    # Out-of-range box values look like raw logits: squash to (0, 1) first.
    if np.min(boxes) < 0.0 or np.max(boxes) > 1.5:
        boxes = _sigmoid(boxes)
    boxes = boxes * np.array([width, height, width, height], dtype=np.float32)
    boxes = _xywh_to_xyxy(boxes)

    if debug_boxes:
        _log_debug_boxes(boxes, confs, class_ids)

    return _unscale_and_collect(
        boxes, confs, class_ids, scale, pad_x, pad_y, conf_threshold
    )
328def _normalize_output_data(output: np.ndarray) -> np.ndarray:
329 """Normalize output data to standard 2D format.
331 Args:
332 output: Raw model output.
334 Returns:
335 Normalized 2D array.
336 """
337 data = output
338 if data.ndim == 3:
339 if data.shape[0] == 1:
340 data = data[0]
341 elif data.shape[0] > 1 and data.shape[-1] == 1:
342 data = np.squeeze(data, axis=-1)
344 if data.ndim == 2 and data.shape[0] < data.shape[1] and data.shape[1] >= 6:
345 data = data.T
346 return data
def _decode_scores_and_boxes(
    data: np.ndarray,
    input_size: tuple[int, int],
) -> tuple[np.ndarray, np.ndarray, np.ndarray] | None:
    """Extract boxes, scores, and class IDs from normalized output.

    Args:
        data: Normalized 2D output data.
        input_size: Input image dimensions.

    Returns:
        Tuple of (boxes, scores, class_ids) or None if invalid.
    """
    if data.ndim != 2 or data.shape[1] < 6:
        return None

    channels = data.shape[1]
    height, width = input_size
    boxes = data[:, :4]

    if channels == 6:
        # Rows look like [box(4), score, class].
        scores = data[:, 4]
        class_ids = data[:, 5].astype(int)
    elif channels >= 85:
        # Rows look like [box(4), objectness, per-class scores...].
        objectness = data[:, 4]
        class_scores = data[:, 5:]
        class_ids = np.argmax(class_scores, axis=1)
        scores = objectness * class_scores[np.arange(len(class_ids)), class_ids]
    else:
        # Rows look like [box(4), per-class scores...] with no objectness.
        class_scores = data[:, 4:]
        class_ids = np.argmax(class_scores, axis=1)
        scores = class_scores[np.arange(len(class_ids)), class_ids]

    # Normalized coordinates (all <= 1.5) are rescaled to pixel units.
    if np.max(boxes) <= 1.5:
        boxes = boxes * np.array([width, height, width, height], dtype=np.float32)

    # 6-channel boxes may already be corner format; convert only if needed.
    if channels != 6 or _looks_like_xywh(boxes):
        boxes = _xywh_to_xyxy(boxes)

    return boxes, scores, class_ids
def _decode_generic_output(
    output: np.ndarray,
    input_size: tuple[int, int],
    scale: float,
    pad_x: int,
    pad_y: int,
    conf_threshold: float,
    *,
    debug_boxes: bool,
) -> list[dict]:
    """Decode generic YOLO output format.

    Args:
        output: Raw model output.
        input_size: Input image dimensions.
        scale: Preprocessing scale factor.
        pad_x: Horizontal padding.
        pad_y: Vertical padding.
        conf_threshold: Confidence threshold.
        debug_boxes: Whether to log debug info.

    Returns:
        List of detections.
    """
    normalized = _normalize_output_data(output)
    decoded = _decode_scores_and_boxes(normalized, input_size)
    if decoded is None:
        return []

    boxes, scores, class_ids = decoded
    if debug_boxes:
        _log_debug_boxes(boxes, scores, class_ids)
    return _unscale_and_collect(
        boxes, scores, class_ids, scale, pad_x, pad_y, conf_threshold
    )
def postprocess(
    outputs: Sequence[np.ndarray],
    config: DecodeConfig,
    *,
    debug_output: bool = False,
) -> tuple[list, dict | None]:
    """Parse model outputs for detection or classification models.

    Decoders are tried from most to least specific: classification, triplet
    (boxes/scores/ids), pair (scores/boxes), then the generic single-tensor
    layout.

    Args:
        outputs: Model output tensors.
        config: Decoding configuration.
        debug_output: Whether to log output shapes.

    Returns:
        Tuple of (detections list, classification dict or None).
    """
    no_detections: list[dict] = []

    if outputs is None or len(outputs) == 0:
        return no_detections, None

    if debug_output:
        logger.info("Model outputs: {}", [np.asarray(out).shape for out in outputs])

    primary = _squeeze_to_2d(outputs[0])
    if primary.ndim == 0:
        return no_detections, None

    classification = _decode_classification(primary)
    if classification is not None:
        return no_detections, classification

    # Multi-tensor layouts share the same argument wiring; try each in turn.
    for decoder in (_decode_triplet_outputs, _decode_pair_outputs):
        result = decoder(
            outputs,
            config.input_size,
            config.scale,
            config.pad_x,
            config.pad_y,
            config.conf_threshold,
            debug_boxes=config.debug_boxes,
        )
        if result is not None:
            return result, None

    detections = _decode_generic_output(
        primary,
        config.input_size,
        config.scale,
        config.pad_x,
        config.pad_y,
        config.conf_threshold,
        debug_boxes=config.debug_boxes,
    )
    return detections, None