Coverage for src / rtflite / services / grouping_service.py: 79%

173 statements  

« prev     ^ index     » next       coverage.py v7.12.0, created at 2025-11-28 05:09 +0000

1""" 

2Enhanced group_by functionality for rtflite. 

3 

4This service implements r2rtf-compatible group_by behavior where values in 

5group_by columns are displayed only once per group, with subsequent rows 

6showing blank/suppressed values for better readability. 

7""" 

8 

9from collections.abc import Mapping, MutableSequence, Sequence 

10from typing import Any 

11 

12import polars as pl 

13 

14 

class GroupingService:
    """Service for handling group_by functionality with value suppression"""

    def __init__(self):
        # Stateless service: all methods operate purely on the DataFrames and
        # column lists passed in, so no per-instance state is needed.
        pass

20 

21 def enhance_group_by( 

22 self, df: pl.DataFrame, group_by: Sequence[str] | None = None 

23 ) -> pl.DataFrame: 

24 """Apply group_by value suppression to a DataFrame 

25 

26 Args: 

27 df: Input DataFrame 

28 group_by: List of column names to group by. Values will be suppressed 

29 for duplicate rows within groups. 

30 

31 Returns: 

32 DataFrame with group_by columns showing values only on first occurrence 

33 within each group 

34 

35 Raises: 

36 ValueError: If data is not properly sorted by group_by columns 

37 """ 

38 if not group_by or df.is_empty(): 

39 return df 

40 

41 # Validate that all group_by columns exist 

42 missing_cols = [col for col in group_by if col not in df.columns] 

43 if missing_cols: 

44 raise ValueError(f"group_by columns not found in DataFrame: {missing_cols}") 

45 

46 # Validate data sorting for group_by columns 

47 self.validate_data_sorting(df, group_by=group_by) 

48 

49 # Create a copy to avoid modifying original 

50 result_df = df.clone() 

51 

52 # Apply grouping logic based on number of group columns 

53 if len(group_by) == 1: 

54 result_df = self._suppress_single_column(result_df, group_by[0]) 

55 else: 

56 result_df = self._suppress_hierarchical_columns(result_df, group_by) 

57 

58 return result_df 

59 

    def _suppress_single_column(self, df: pl.DataFrame, column: str) -> pl.DataFrame:
        """Suppress duplicate values in a single group column

        Args:
            df: Input DataFrame
            column: Column name to suppress duplicates

        Returns:
            DataFrame with duplicate values replaced with null
        """
        # Create a mask for rows where the value is different from the previous row.
        # NOTE(review): shift(1) makes row 0 compare against null, which would
        # not register as a change, so the explicit int_range == 0 term is what
        # keeps the first row visible.
        is_first_occurrence = (df[column] != df[column].shift(1)) | (
            pl.int_range(df.height) == 0
        )  # First row is always shown

        # Create suppressed column by setting duplicates to null
        suppressed_values = (
            pl.when(is_first_occurrence).then(df[column]).otherwise(None)
        )

        # Replace the original column with suppressed version
        result_df = df.with_columns(suppressed_values.alias(column))

        return result_df

84 

85 def _suppress_hierarchical_columns( 

86 self, df: pl.DataFrame, group_by: Sequence[str] 

87 ) -> pl.DataFrame: 

88 """Suppress duplicate values in hierarchical group columns 

89 

90 For multiple group columns, values are suppressed hierarchically: 

91 - First column: only shows when it changes 

92 - Second column: shows when first column changes OR when it changes 

93 - And so on... 

94 

95 Args: 

96 df: Input DataFrame 

97 group_by: List of column names in hierarchical order 

98 

99 Returns: 

100 DataFrame with hierarchical value suppression 

101 """ 

102 result_df = df.clone() 

103 

104 for i, column in enumerate(group_by): 

105 # For hierarchical grouping, a value should be shown if: 

106 # 1. It's the first row, OR 

107 # 2. Any of the higher-level group columns have changed, OR 

108 # 3. This column's value has changed 

109 

110 conditions = [] 

111 

112 # First row condition 

113 conditions.append(pl.int_range(df.height) == 0) 

114 

115 # Higher-level columns changed condition 

116 for higher_col in group_by[:i]: 

117 conditions.append(pl.col(higher_col) != pl.col(higher_col).shift(1)) 

118 

119 # This column changed condition 

120 conditions.append(pl.col(column) != pl.col(column).shift(1)) 

121 

122 # Combine all conditions with OR 

123 should_show = conditions[0] 

124 for condition in conditions[1:]: 

125 should_show = should_show | condition 

126 

127 # Apply suppression 

128 suppressed_values = ( 

129 pl.when(should_show).then(pl.col(column)).otherwise(None) 

130 ) 

131 result_df = result_df.with_columns(suppressed_values.alias(column)) 

132 

133 return result_df 

134 

135 def restore_page_context( 

136 self, 

137 suppressed_df: pl.DataFrame, 

138 original_df: pl.DataFrame, 

139 group_by: Sequence[str], 

140 page_start_indices: Sequence[int], 

141 ) -> pl.DataFrame: 

142 """Restore group context at the beginning of new pages 

143 

144 When content spans multiple pages, the first row of each new page 

145 should show the group values for context, even if they were suppressed 

146 in the continuous flow. 

147 

148 Args: 

149 suppressed_df: DataFrame with group_by suppression applied 

150 original_df: Original DataFrame with all values 

151 group_by: List of group columns 

152 page_start_indices: List of row indices where new pages start 

153 

154 Returns: 

155 DataFrame with group context restored at page boundaries 

156 """ 

157 if not group_by or not page_start_indices: 

158 return suppressed_df 

159 

160 result_df = suppressed_df.clone() 

161 

162 # For each page start, restore the group values from original data 

163 for page_start_idx in page_start_indices: 

164 if page_start_idx < len(original_df): 

165 # Create updates for each group column 

166 for col in group_by: 

167 # Get the original value for this row 

168 original_value = original_df[col][page_start_idx] 

169 

170 # Update the result DataFrame at this position 

171 # Create a mask for this specific row 

172 mask = pl.int_range(len(result_df)) == page_start_idx 

173 

174 # Update the column value where the mask is true 

175 result_df = result_df.with_columns( 

176 pl.when(mask) 

177 .then(pl.lit(original_value)) 

178 .otherwise(pl.col(col)) 

179 .alias(col) 

180 ) 

181 

182 return result_df 

183 

184 def get_group_structure( 

185 self, df: pl.DataFrame, group_by: Sequence[str] 

186 ) -> Mapping[str, Any]: 

187 """Analyze the group structure of a DataFrame 

188 

189 Args: 

190 df: Input DataFrame 

191 group_by: List of group columns 

192 

193 Returns: 

194 Dictionary with group structure information 

195 """ 

196 if not group_by or df.is_empty(): 

197 return {"groups": 0, "structure": {}} 

198 

199 # Count unique combinations at each level 

200 structure = {} 

201 

202 for i, _col in enumerate(group_by): 

203 level_cols = group_by[: i + 1] 

204 unique_combinations = df.select(level_cols).unique().height 

205 structure[f"level_{i + 1}"] = { 

206 "columns": level_cols, 

207 "unique_combinations": unique_combinations, 

208 } 

209 

210 # Overall statistics 

211 total_groups = df.select(group_by).unique().height 

212 

213 return { 

214 "total_groups": total_groups, 

215 "levels": len(group_by), 

216 "structure": structure, 

217 } 

218 

219 def validate_group_by_columns( 

220 self, df: pl.DataFrame, group_by: Sequence[str] 

221 ) -> Sequence[str]: 

222 """Validate group_by columns and return any issues 

223 

224 Args: 

225 df: Input DataFrame 

226 group_by: List of group columns to validate 

227 

228 Returns: 

229 List of validation issues (empty if all valid) 

230 """ 

231 issues: MutableSequence[str] = [] 

232 

233 if not group_by: 

234 return issues 

235 

236 # Check if columns exist 

237 missing_cols = [col for col in group_by if col not in df.columns] 

238 if missing_cols: 

239 issues.append(f"Missing columns: {missing_cols}") 

240 

241 # Check for empty DataFrame 

242 if df.is_empty(): 

243 issues.append("DataFrame is empty") 

244 

245 # Check for columns with all null values 

246 for col in group_by: 

247 if col in df.columns: 

248 null_count = df[col].null_count() 

249 if null_count == df.height: 

250 issues.append(f"Column '{col}' contains only null values") 

251 

252 return issues 

253 

    def validate_data_sorting(
        self,
        df: pl.DataFrame,
        group_by: Sequence[str] | None = None,
        page_by: Sequence[str] | None = None,
        subline_by: Sequence[str] | None = None,
    ) -> None:
        """Validate that data is properly sorted for grouping operations

        Based on r2rtf logic: ensures data is sorted by all grouping variables
        in the correct order for proper group_by, page_by, and subline_by functionality.

        Note that "sorted" here means *contiguous*: rows sharing the same
        grouping values must be adjacent, but no particular (e.g.
        alphabetical) ordering of the groups themselves is required.

        Args:
            df: Input DataFrame to validate
            group_by: List of group_by columns (optional)
            page_by: List of page_by columns (optional)
            subline_by: List of subline_by columns (optional)

        Raises:
            ValueError: If data is not properly sorted or
                if there are overlapping columns
        """
        if df.is_empty():
            return

        # Collect all grouping variables
        all_grouping_vars: list[str] = []

        # Add variables in priority order (page_by, subline_by, group_by)
        if page_by:
            all_grouping_vars.extend(page_by)
        if subline_by:
            all_grouping_vars.extend(subline_by)
        if group_by:
            all_grouping_vars.extend(group_by)

        if not all_grouping_vars:
            return  # No grouping variables to validate

        # Check for overlapping variables between different grouping types
        self._validate_no_overlapping_grouping_vars(group_by, page_by, subline_by)

        # Remove duplicates while preserving order
        unique_vars = []
        seen = set()
        for var in all_grouping_vars:
            if var not in seen:
                unique_vars.append(var)
                seen.add(var)

        # Validate all grouping columns exist
        missing_cols = [col for col in unique_vars if col not in df.columns]
        if missing_cols:
            raise ValueError(f"Grouping columns not found in DataFrame: {missing_cols}")

        # Check if groups are contiguous (values in same group are together)
        # This ensures proper grouping behavior without requiring alphabetical sorting

        # For each grouping variable, check if its groups are contiguous
        for i, var in enumerate(unique_vars):
            # Get the values for this variable and all previous variables
            group_cols = unique_vars[: i + 1]

            # Create a key for each row based on grouping columns up to this level
            if i == 0:
                # For the first variable, just check if its values are contiguous:
                # walk the column once and fail if a value reappears after some
                # different value was seen in between.
                values = df[var].to_list()
                current_value = values[0]
                seen_values = {current_value}

                for j in range(1, len(values)):
                    if values[j] != current_value:
                        if values[j] in seen_values:
                            # Found a value that appeared before but with
                            # different values in between
                            raise ValueError(
                                f"Data is not properly grouped by '{var}'. "
                                "Values with the same "
                                f"'{var}' must be contiguous. Found "
                                f"'{values[j]}' at position {j} but it also "
                                "appeared earlier. Please reorder your data so "
                                f"that all rows with the same '{var}' are "
                                "together."
                            )
                        current_value = values[j]
                        seen_values.add(current_value)
            else:
                # For subsequent variables, check contiguity within parent groups
                # Create a composite key from all grouping variables up to this level
                # Handle null values by first converting to string with null handling
                # ("__NULL__" stands in for nulls so they still form a group key).
                df_with_key = df.with_columns(
                    [
                        pl.col(col)
                        .cast(pl.Utf8)
                        .fill_null("__NULL__")
                        .alias(f"_str_{col}")
                        for col in group_cols
                    ]
                )

                # Create the group key from the string columns
                str_cols = [f"_str_{col}" for col in group_cols]
                df_with_key = df_with_key.with_columns(
                    pl.concat_str(str_cols, separator="|").alias("_group_key")
                )

                # Same contiguity walk as level 0, but over composite keys.
                group_keys = df_with_key["_group_key"].to_list()
                current_key = group_keys[0]
                seen_keys = {current_key}

                for j in range(1, len(group_keys)):
                    if group_keys[j] != current_key:
                        if group_keys[j] in seen_keys:
                            # Found a group that appeared before
                            group_values = df.row(j, named=True)
                            key_parts = [
                                f"{col}='{group_values[col]}'" for col in group_cols
                            ]
                            key_desc = ", ".join(key_parts)

                            raise ValueError(
                                "Data is not properly grouped. "
                                f"Group with {key_desc} appears in multiple "
                                "non-contiguous sections. Please reorder your "
                                "data so that rows with the same grouping "
                                "values are together."
                            )
                        current_key = group_keys[j]
                        seen_keys.add(current_key)

384 def validate_subline_formatting_consistency( 

385 self, df: pl.DataFrame, subline_by: Sequence[str], rtf_body 

386 ) -> Sequence[str]: 

387 """Validate that formatting is consistent within each column after broadcasting 

388 

389 When using subline_by, we need to ensure that after broadcasting formatting 

390 attributes, each remaining column (after removing subline_by columns) has 

391 consistent formatting values. Otherwise, different rows within a subline 

392 group would have different formatting. 

393 

394 Args: 

395 df: Input DataFrame 

396 subline_by: List of subline_by columns 

397 rtf_body: RTFBody instance with formatting attributes 

398 

399 Returns: 

400 List of warning messages 

401 """ 

402 warnings: MutableSequence[str] = [] 

403 

404 if not subline_by or df.is_empty(): 

405 return warnings 

406 

407 # Get the columns that will remain after removing subline_by columns 

408 remaining_cols = [col for col in df.columns if col not in subline_by] 

409 if not remaining_cols: 

410 return warnings 

411 

412 num_cols = len(remaining_cols) 

413 num_rows = df.height 

414 

415 # Format attributes to check 

416 format_attributes = [ 

417 "text_format", 

418 "text_justification", 

419 "text_font_size", 

420 "text_color", 

421 "border_top", 

422 "border_bottom", 

423 "border_left", 

424 "border_right", 

425 "border_color_top", 

426 "border_color_bottom", 

427 "border_color_left", 

428 "border_color_right", 

429 ] 

430 

431 for attr_name in format_attributes: 

432 if hasattr(rtf_body, attr_name): 

433 attr_value = getattr(rtf_body, attr_name) 

434 if attr_value is None: 

435 continue 

436 

437 # Use BroadcastValue to expand the attribute to full matrix 

438 from ..attributes import BroadcastValue 

439 

440 try: 

441 broadcast_obj = BroadcastValue( 

442 value=attr_value, dimension=(num_rows, num_cols) 

443 ) 

444 broadcasted = broadcast_obj.to_list() 

445 except Exception: 

446 # If broadcasting fails, skip this attribute 

447 continue 

448 

449 # Skip if broadcasting returned None 

450 if broadcasted is None: 

451 continue 

452 

453 # Check each column for consistency 

454 for col_idx in range(num_cols): 

455 # Get all values for this column 

456 col_values = [ 

457 broadcasted[row_idx][col_idx] for row_idx in range(num_rows) 

458 ] 

459 

460 # Filter out None and empty string values 

461 meaningful_values = [v for v in col_values if v not in [None, ""]] 

462 if not meaningful_values: 

463 continue 

464 

465 # Check if all values in this column are the same 

466 unique_values = set(meaningful_values) 

467 if len(unique_values) > 1: 

468 col_name = remaining_cols[col_idx] 

469 warnings.append( 

470 "Column " 

471 f"'{col_name}' has inconsistent {attr_name} values " 

472 f"{list(unique_values)} after broadcasting. When " 

473 "using subline_by, formatting should be consistent " 

474 "within each column to ensure uniform appearance " 

475 "within subline groups." 

476 ) 

477 

478 return warnings 

479 

480 def _validate_no_overlapping_grouping_vars( 

481 self, 

482 group_by: Sequence[str] | None = None, 

483 page_by: Sequence[str] | None = None, 

484 subline_by: Sequence[str] | None = None, 

485 ) -> None: 

486 """Validate that grouping variables don't overlap between different types 

487 

488 Based on r2rtf validation logic to prevent conflicts between 

489 group_by, page_by, and subline_by parameters. 

490 

491 Args: 

492 group_by: List of group_by columns (optional) 

493 page_by: List of page_by columns (optional) 

494 subline_by: List of subline_by columns (optional) 

495 

496 Raises: 

497 ValueError: If there are overlapping variables between grouping types 

498 """ 

499 # Convert None to empty lists for easier processing 

500 group_by = group_by or [] 

501 page_by = page_by or [] 

502 subline_by = subline_by or [] 

503 

504 # Check for overlaps between each pair 

505 overlaps = [] 

506 

507 # group_by vs page_by 

508 group_page_overlap = set(group_by) & set(page_by) 

509 if group_page_overlap: 

510 overlaps.append(f"group_by and page_by: {sorted(group_page_overlap)}") 

511 

512 # group_by vs subline_by 

513 group_subline_overlap = set(group_by) & set(subline_by) 

514 if group_subline_overlap: 

515 overlaps.append(f"group_by and subline_by: {sorted(group_subline_overlap)}") 

516 

517 # page_by vs subline_by 

518 page_subline_overlap = set(page_by) & set(subline_by) 

519 if page_subline_overlap: 

520 overlaps.append(f"page_by and subline_by: {sorted(page_subline_overlap)}") 

521 

522 if overlaps: 

523 overlap_details = "; ".join(overlaps) 

524 raise ValueError( 

525 "Overlapping variables found between grouping parameters: " 

526 f"{overlap_details}. Each variable can only be used in one " 

527 "grouping parameter (group_by, page_by, or subline_by)." 

528 ) 

529 

530 

# Module-level singleton for convenient shared access; the service is
# stateless, so a single instance is safe to share.
grouping_service = GroupingService()