Coverage for orchestr_ant_ion/yolo/core/postprocess.py: 0%

170 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-19 08:36 +0000

1"""Post-processing utilities for YOLO model outputs.""" 

2 

3from __future__ import annotations 

4 

5from dataclasses import dataclass 

6from typing import TYPE_CHECKING 

7 

8import numpy as np 

9from loguru import logger 

10 

11from orchestr_ant_ion.pipeline.constants import POSTPROCESS_DEFAULT_CONF_THRESHOLD 

12from orchestr_ant_ion.yolo.core.constants import CLASS_NAMES 

13 

14 

15if TYPE_CHECKING: 

16 from collections.abc import Sequence 

17 

18 

@dataclass
class DecodeConfig:
    """Configuration for post-processing model outputs.

    Carries the preprocessing parameters (scale and padding) needed to map
    detections from model-input coordinates back to original image
    coordinates, plus decoding options.
    """

    # Scale factor applied to the image during preprocessing; detections
    # are divided by this to undo it.
    scale: float
    # Horizontal padding (pixels) added during preprocessing.
    pad_x: int
    # Vertical padding (pixels) added during preprocessing.
    pad_y: int
    # Model input dimensions as (height, width).
    input_size: tuple[int, int]
    # Minimum confidence for a detection to be kept.
    conf_threshold: float = POSTPROCESS_DEFAULT_CONF_THRESHOLD
    # When True, log the first few decoded boxes/scores for debugging.
    debug_boxes: bool = False

29 

30 

31def _softmax(scores: np.ndarray) -> np.ndarray: 

32 """Apply softmax normalization to a score array. 

33 

34 Args: 

35 scores: Raw score array (any shape). 

36 

37 Returns: 

38 Normalized probability array with same shape. 

39 """ 

40 shifted = scores - np.max(scores) 

41 exp_scores = np.exp(shifted) 

42 return exp_scores / np.sum(exp_scores) 

43 

44 

45def _xywh_to_xyxy(boxes: np.ndarray) -> np.ndarray: 

46 """Convert bounding boxes from center format to corner format. 

47 

48 Args: 

49 boxes: Array of shape (N, 4) with [cx, cy, w, h] format. 

50 

51 Returns: 

52 Array of shape (N, 4) with [x1, y1, x2, y2] format. 

53 """ 

54 x, y, w, h = boxes.T 

55 x1 = x - w / 2 

56 y1 = y - h / 2 

57 x2 = x + w / 2 

58 y2 = y + h / 2 

59 return np.stack([x1, y1, x2, y2], axis=1) 

60 

61 

62def _sigmoid(values: np.ndarray) -> np.ndarray: 

63 """Apply sigmoid activation to values. 

64 

65 Args: 

66 values: Input array (any shape). 

67 

68 Returns: 

69 Sigmoid-activated array with same shape. 

70 """ 

71 return 1.0 / (1.0 + np.exp(-values)) 

72 

73 

74def _squeeze_to_2d(arr: np.ndarray) -> np.ndarray: 

75 """Reduce array dimensions to 2D by removing singleton dimensions. 

76 

77 Args: 

78 arr: Input array with potential singleton dimensions. 

79 

80 Returns: 

81 2D array with singleton dimensions removed. 

82 """ 

83 data = np.asarray(arr) 

84 data = np.squeeze(data) 

85 if data.ndim == 3 and data.shape[0] == 1: 

86 data = data[0] 

87 return data 

88 

89 

90def _looks_like_xywh(boxes: np.ndarray) -> bool: 

91 """Detect if boxes are in center format (xywh) vs corner format (xyxy). 

92 

93 Heuristic: if many boxes have x2 < x1 or y2 < y1, they're likely xywh. 

94 

95 Args: 

96 boxes: Array of shape (N, 4). 

97 

98 Returns: 

99 True if boxes appear to be in center format. 

100 """ 

101 if boxes.size == 0: 

102 return False 

103 x2_lt_x1 = np.mean(boxes[:, 2] < boxes[:, 0]) 

104 y2_lt_y1 = np.mean(boxes[:, 3] < boxes[:, 1]) 

105 return bool((x2_lt_x1 > 0.3) or (y2_lt_y1 > 0.3)) 

106 

107 

108def _decode_classification(output: np.ndarray) -> dict | None: 

109 """Decode classification model output to class prediction. 

110 

111 Args: 

112 output: Model output array. 

113 

114 Returns: 

115 Dictionary with class_id, score, and label, or None if not classification. 

116 """ 

117 if output.ndim == 1 or (output.ndim == 2 and output.shape[0] == 1): 

118 scores = output if output.ndim == 1 else output[0] 

119 scores = scores.astype(np.float32) 

120 if np.any(scores < 0) or np.any(scores > 1): 

121 scores = _softmax(scores) 

122 class_id = int(np.argmax(scores)) 

123 score = float(scores[class_id]) 

124 label = CLASS_NAMES[class_id] if class_id < len(CLASS_NAMES) else str(class_id) 

125 return { 

126 "class_id": class_id, 

127 "score": score, 

128 "label": label, 

129 } 

130 return None 

131 

132 

def _log_debug_boxes(
    boxes: np.ndarray,
    scores: np.ndarray,
    class_ids: np.ndarray,
) -> None:
    """Log the first 3 decoded boxes and their scores for debugging.

    Args:
        boxes: Bounding box array.
        scores: Confidence score array.
        class_ids: Class ID array.
    """
    box_preview = boxes[:3].round(2).tolist()
    score_class_preview = list(
        zip(
            scores[:3].round(3).tolist(),
            class_ids[:3].tolist(),
            strict=False,
        )
    )
    logger.info("Decoded boxes (first 3): {}", box_preview)
    logger.info("Decoded scores/classes (first 3): {}", score_class_preview)

159 

160 

161def _unscale_and_collect( 

162 boxes: np.ndarray, 

163 scores: np.ndarray, 

164 class_ids: np.ndarray, 

165 scale: float, 

166 pad_x: int, 

167 pad_y: int, 

168 conf_threshold: float, 

169) -> list[dict]: 

170 """Unscale bounding boxes and collect detections above the threshold. 

171 

172 Args: 

173 boxes: Bounding boxes in input image coordinates. 

174 scores: Confidence scores for each detection. 

175 class_ids: Class IDs for each detection. 

176 scale: Scale factor applied during preprocessing. 

177 pad_x: Horizontal padding applied during preprocessing. 

178 pad_y: Vertical padding applied during preprocessing. 

179 conf_threshold: Minimum confidence threshold. 

180 

181 Returns: 

182 List of detection dictionaries with bbox, score, and class_id. 

183 """ 

184 detections: list[dict] = [] 

185 for box, score, class_id in zip(boxes, scores, class_ids, strict=False): 

186 if score < conf_threshold: 

187 continue 

188 x1, y1, x2, y2 = box 

189 x1 = (x1 - pad_x) / scale 

190 y1 = (y1 - pad_y) / scale 

191 x2 = (x2 - pad_x) / scale 

192 y2 = (y2 - pad_y) / scale 

193 detections.append( 

194 { 

195 "bbox": [int(x1), int(y1), int(x2), int(y2)], 

196 "score": float(score), 

197 "class_id": int(class_id), 

198 } 

199 ) 

200 return detections 

201 

202 

203def _prepare_boxes( 

204 boxes: np.ndarray, 

205 input_size: tuple[int, int], 

206 convert_xywh: bool = True, 

207) -> np.ndarray: 

208 """Scale boxes to pixel coordinates and optionally convert format. 

209 

210 Args: 

211 boxes: Raw bounding boxes. 

212 input_size: (height, width) of input image. 

213 convert_xywh: Whether to convert from xywh to xyxy format. 

214 

215 Returns: 

216 Processed bounding boxes in pixel coordinates. 

217 """ 

218 height, width = input_size 

219 if np.max(boxes) <= 1.5: 

220 boxes = boxes * np.array([width, height, width, height], dtype=np.float32) 

221 

222 if convert_xywh and _looks_like_xywh(boxes): 

223 boxes = _xywh_to_xyxy(boxes) 

224 

225 return boxes 

226 

227 

228def _decode_triplet_outputs( 

229 outputs: Sequence[np.ndarray], 

230 input_size: tuple[int, int], 

231 scale: float, 

232 pad_x: int, 

233 pad_y: int, 

234 conf_threshold: float, 

235 *, 

236 debug_boxes: bool, 

237) -> list[dict] | None: 

238 """Decode triplet output format (boxes, scores, class_ids). 

239 

240 Args: 

241 outputs: Model output tensors. 

242 input_size: Input image dimensions. 

243 scale: Preprocessing scale factor. 

244 pad_x: Horizontal padding. 

245 pad_y: Vertical padding. 

246 conf_threshold: Confidence threshold. 

247 debug_boxes: Whether to log debug info. 

248 

249 Returns: 

250 List of detections or None if not triplet format. 

251 """ 

252 if len(outputs) < 3: 

253 return None 

254 

255 boxes = _squeeze_to_2d(outputs[0]) 

256 scores = _squeeze_to_2d(outputs[1]) 

257 class_ids = _squeeze_to_2d(outputs[2]) 

258 

259 if not (boxes.ndim == 2 and boxes.shape[-1] == 4): 

260 return None 

261 

262 if scores.ndim > 1: 

263 scores = scores.reshape(-1) 

264 if class_ids.ndim > 1: 

265 class_ids = class_ids.reshape(-1) 

266 

267 boxes = _prepare_boxes(boxes, input_size, convert_xywh=True) 

268 

269 if debug_boxes: 

270 _log_debug_boxes(boxes, scores, class_ids) 

271 

272 return _unscale_and_collect( 

273 boxes, scores, class_ids, scale, pad_x, pad_y, conf_threshold 

274 ) 

275 

276 

277def _decode_pair_outputs( 

278 outputs: Sequence[np.ndarray], 

279 input_size: tuple[int, int], 

280 scale: float, 

281 pad_x: int, 

282 pad_y: int, 

283 conf_threshold: float, 

284 *, 

285 debug_boxes: bool, 

286) -> list[dict] | None: 

287 """Decode pair output format (scores, boxes). 

288 

289 Args: 

290 outputs: Model output tensors. 

291 input_size: Input image dimensions. 

292 scale: Preprocessing scale factor. 

293 pad_x: Horizontal padding. 

294 pad_y: Vertical padding. 

295 conf_threshold: Confidence threshold. 

296 debug_boxes: Whether to log debug info. 

297 

298 Returns: 

299 List of detections or None if not pair format. 

300 """ 

301 if len(outputs) < 2: 

302 return None 

303 

304 scores = _squeeze_to_2d(outputs[0]) 

305 boxes = _squeeze_to_2d(outputs[1]) 

306 

307 if not (scores.ndim == 2 and boxes.ndim == 2 and boxes.shape[1] == 4): 

308 return None 

309 

310 height, width = input_size 

311 probs = _softmax(scores) if scores.shape[1] > 1 else scores 

312 class_ids = np.argmax(probs, axis=1) 

313 confs = probs[np.arange(len(class_ids)), class_ids] 

314 

315 if np.min(boxes) < 0.0 or np.max(boxes) > 1.5: 

316 boxes = _sigmoid(boxes) 

317 boxes = boxes * np.array([width, height, width, height], dtype=np.float32) 

318 boxes = _xywh_to_xyxy(boxes) 

319 

320 if debug_boxes: 

321 _log_debug_boxes(boxes, confs, class_ids) 

322 

323 return _unscale_and_collect( 

324 boxes, confs, class_ids, scale, pad_x, pad_y, conf_threshold 

325 ) 

326 

327 

328def _normalize_output_data(output: np.ndarray) -> np.ndarray: 

329 """Normalize output data to standard 2D format. 

330 

331 Args: 

332 output: Raw model output. 

333 

334 Returns: 

335 Normalized 2D array. 

336 """ 

337 data = output 

338 if data.ndim == 3: 

339 if data.shape[0] == 1: 

340 data = data[0] 

341 elif data.shape[0] > 1 and data.shape[-1] == 1: 

342 data = np.squeeze(data, axis=-1) 

343 

344 if data.ndim == 2 and data.shape[0] < data.shape[1] and data.shape[1] >= 6: 

345 data = data.T 

346 return data 

347 

348 

349def _decode_scores_and_boxes( 

350 data: np.ndarray, 

351 input_size: tuple[int, int], 

352) -> tuple[np.ndarray, np.ndarray, np.ndarray] | None: 

353 """Extract boxes, scores, and class IDs from normalized output. 

354 

355 Args: 

356 data: Normalized 2D output data. 

357 input_size: Input image dimensions. 

358 

359 Returns: 

360 Tuple of (boxes, scores, class_ids) or None if invalid. 

361 """ 

362 if data.ndim != 2 or data.shape[1] < 6: 

363 return None 

364 

365 channels = data.shape[1] 

366 height, width = input_size 

367 

368 if channels == 6: 

369 boxes = data[:, :4] 

370 scores = data[:, 4] 

371 class_ids = data[:, 5].astype(int) 

372 else: 

373 boxes = data[:, :4] 

374 if channels >= 85: 

375 obj = data[:, 4] 

376 class_scores = data[:, 5:] 

377 class_ids = np.argmax(class_scores, axis=1) 

378 scores = obj * class_scores[np.arange(len(class_ids)), class_ids] 

379 else: 

380 class_scores = data[:, 4:] 

381 class_ids = np.argmax(class_scores, axis=1) 

382 scores = class_scores[np.arange(len(class_ids)), class_ids] 

383 

384 if np.max(boxes) <= 1.5: 

385 boxes = boxes * np.array([width, height, width, height], dtype=np.float32) 

386 

387 if channels != 6 or _looks_like_xywh(boxes): 

388 boxes = _xywh_to_xyxy(boxes) 

389 

390 return boxes, scores, class_ids 

391 

392 

def _decode_generic_output(
    output: np.ndarray,
    input_size: tuple[int, int],
    scale: float,
    pad_x: int,
    pad_y: int,
    conf_threshold: float,
    *,
    debug_boxes: bool,
) -> list[dict]:
    """Decode generic YOLO output format.

    Args:
        output: Raw model output.
        input_size: Input image dimensions.
        scale: Preprocessing scale factor.
        pad_x: Horizontal padding.
        pad_y: Vertical padding.
        conf_threshold: Confidence threshold.
        debug_boxes: Whether to log debug info.

    Returns:
        List of detections (empty when the output shape is unrecognized).
    """
    normalized = _normalize_output_data(output)
    decoded = _decode_scores_and_boxes(normalized, input_size)
    if decoded is None:
        return []

    boxes, scores, class_ids = decoded
    if debug_boxes:
        _log_debug_boxes(boxes, scores, class_ids)
    return _unscale_and_collect(
        boxes, scores, class_ids, scale, pad_x, pad_y, conf_threshold
    )

429 

430 

def postprocess(
    outputs: Sequence[np.ndarray],
    config: DecodeConfig,
    *,
    debug_output: bool = False,
) -> tuple[list, dict | None]:
    """Parse model outputs for detection or classification models.

    Decoders are tried from most to least specific: classification head,
    triplet (boxes/scores/ids), pair (scores/boxes), and finally the
    generic single-tensor layout.

    Args:
        outputs: Model output tensors.
        config: Decoding configuration.
        debug_output: Whether to log output shapes.

    Returns:
        Tuple of (detections list, classification dict or None).
    """
    if outputs is None or len(outputs) == 0:
        return [], None

    if debug_output:
        logger.info("Model outputs: {}", [np.asarray(out).shape for out in outputs])

    primary = _squeeze_to_2d(outputs[0])
    if primary.ndim == 0:
        return [], None

    classification = _decode_classification(primary)
    if classification is not None:
        return [], classification

    # Multi-tensor layouts: each decoder returns None when the output
    # shapes don't match its expected format.
    for decoder in (_decode_triplet_outputs, _decode_pair_outputs):
        decoded = decoder(
            outputs,
            config.input_size,
            config.scale,
            config.pad_x,
            config.pad_y,
            config.conf_threshold,
            debug_boxes=config.debug_boxes,
        )
        if decoded is not None:
            return decoded, None

    detections = _decode_generic_output(
        primary,
        config.input_size,
        config.scale,
        config.pad_x,
        config.pad_y,
        config.conf_threshold,
        debug_boxes=config.debug_boxes,
    )
    return detections, None