Coverage for src/rtflite/services/grouping_service.py: 78%

172 statements  

coverage.py v7.10.3, created at 2025-08-14 16:35 +0000

1""" 

2Enhanced group_by functionality for rtflite. 

3 

4This service implements r2rtf-compatible group_by behavior where values in 

5group_by columns are displayed only once per group, with subsequent rows 

6showing blank/suppressed values for better readability. 

7""" 

8 

9from typing import Any, Dict 

10 

11import polars as pl 

12 

13 

class GroupingService:
    """Service for handling group_by functionality with value suppression"""

    def __init__(self):
        pass

    def enhance_group_by(
        self, df: pl.DataFrame, group_by: list[str] | None = None
    ) -> pl.DataFrame:
        """Apply group_by value suppression to a DataFrame

        Args:
            df: Input DataFrame
            group_by: List of column names to group by. Values will be suppressed
                for duplicate rows within groups.

        Returns:
            DataFrame with group_by columns showing values only on first
            occurrence within each group

        Raises:
            ValueError: If data is not properly sorted by group_by columns
        """
        if not group_by or df.is_empty():
            return df

        # Validate that all group_by columns exist
        missing_cols = [col for col in group_by if col not in df.columns]
        if missing_cols:
            raise ValueError(f"group_by columns not found in DataFrame: {missing_cols}")

        # Validate data sorting for group_by columns
        self.validate_data_sorting(df, group_by=group_by)

        # Create a copy to avoid modifying the original
        result_df = df.clone()

        # Apply grouping logic based on the number of group columns
        if len(group_by) == 1:
            result_df = self._suppress_single_column(result_df, group_by[0])
        else:
            result_df = self._suppress_hierarchical_columns(result_df, group_by)

        return result_df
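
    # A minimal usage sketch (illustrative; the column names are hypothetical):
    #
    #     df = pl.DataFrame({"group": ["A", "A", "B", "B"], "n": [1, 2, 3, 4]})
    #     out = grouping_service.enhance_group_by(df, group_by=["group"])
    #     # out["group"] is now ["A", None, "B", None]: each group label appears
    #     # only on its first row, matching r2rtf's group_by display behavior.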

    def _suppress_single_column(self, df: pl.DataFrame, column: str) -> pl.DataFrame:
        """Suppress duplicate values in a single group column

        Args:
            df: Input DataFrame
            column: Column name in which to suppress duplicates

        Returns:
            DataFrame with duplicate values replaced with null
        """
        # Mask rows where the value differs from the previous row
        is_first_occurrence = (df[column] != df[column].shift(1)) | (
            pl.int_range(df.height) == 0
        )  # First row is always shown

        # Create the suppressed column by setting duplicates to null
        suppressed_values = (
            pl.when(is_first_occurrence).then(df[column]).otherwise(None)
        )

        # Replace the original column with the suppressed version
        result_df = df.with_columns(suppressed_values.alias(column))

        return result_df
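
    # How the shift-based mask behaves (a sketch, not part of the module API):
    #
    #     s = pl.Series(["A", "A", "B"])
    #     s != s.shift(1)  # -> [null, false, true]; OR-ing with "row index == 0"
    #                      # turns the leading null into an explicit "show".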

    def _suppress_hierarchical_columns(
        self, df: pl.DataFrame, group_by: list[str]
    ) -> pl.DataFrame:
        """Suppress duplicate values in hierarchical group columns

        For multiple group columns, values are suppressed hierarchically:
        - First column: only shows when it changes
        - Second column: shows when the first column changes OR when it changes
        - And so on...

        Args:
            df: Input DataFrame
            group_by: List of column names in hierarchical order

        Returns:
            DataFrame with hierarchical value suppression
        """
        result_df = df.clone()

        for i, column in enumerate(group_by):
            # For hierarchical grouping, a value should be shown if:
            # 1. It's the first row, OR
            # 2. Any of the higher-level group columns have changed, OR
            # 3. This column's value has changed

            conditions = []

            # First row condition
            conditions.append(pl.int_range(df.height) == 0)

            # Higher-level columns changed condition
            for higher_col in group_by[:i]:
                conditions.append(pl.col(higher_col) != pl.col(higher_col).shift(1))

            # This column changed condition
            conditions.append(pl.col(column) != pl.col(column).shift(1))

            # Combine all conditions with OR
            should_show = conditions[0]
            for condition in conditions[1:]:
                should_show = should_show | condition

            # Apply suppression
            suppressed_values = (
                pl.when(should_show).then(pl.col(column)).otherwise(None)
            )
            result_df = result_df.with_columns(suppressed_values.alias(column))

        return result_df
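
    # Hierarchical suppression sketch (hypothetical two-level data):
    #
    #     df = pl.DataFrame({
    #         "site": ["S1", "S1", "S1", "S2"],
    #         "arm":  ["A",  "A",  "B",  "A"],
    #     })
    #     out = grouping_service.enhance_group_by(df, group_by=["site", "arm"])
    #     # site -> ["S1", None, None, "S2"]
    #     # arm  -> ["A",  None, "B",  "A"]  ("A" reappears because site changed)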

    def restore_page_context(
        self,
        suppressed_df: pl.DataFrame,
        original_df: pl.DataFrame,
        group_by: list[str],
        page_start_indices: list[int],
    ) -> pl.DataFrame:
        """Restore group context at the beginning of new pages

        When content spans multiple pages, the first row of each new page
        should show the group values for context, even if they were suppressed
        in the continuous flow.

        Args:
            suppressed_df: DataFrame with group_by suppression applied
            original_df: Original DataFrame with all values
            group_by: List of group columns
            page_start_indices: List of row indices where new pages start

        Returns:
            DataFrame with group context restored at page boundaries
        """
        if not group_by or not page_start_indices:
            return suppressed_df

        result_df = suppressed_df.clone()

        # For each page start, restore the group values from the original data
        for page_start_idx in page_start_indices:
            if page_start_idx < len(original_df):
                # Restore each group column at this row
                for col in group_by:
                    # Get the original value for this row
                    original_value = original_df[col][page_start_idx]

                    # Create a mask that matches only this specific row
                    mask = pl.int_range(len(result_df)) == page_start_idx

                    # Update the column value where the mask is true
                    result_df = result_df.with_columns(
                        pl.when(mask)
                        .then(pl.lit(original_value))
                        .otherwise(pl.col(col))
                        .alias(col)
                    )

        return result_df
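
    # Page-boundary sketch (assumes row 2 starts a new page):
    #
    #     suppressed = grouping_service.enhance_group_by(df, group_by=["group"])
    #     paged = grouping_service.restore_page_context(
    #         suppressed, df, group_by=["group"], page_start_indices=[2]
    #     )
    #     # Row 2 shows its group value again, even if it was suppressed in the
    #     # continuous flow, so the new page is self-describing.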

    def get_group_structure(
        self, df: pl.DataFrame, group_by: list[str]
    ) -> dict[str, Any]:
        """Analyze the group structure of a DataFrame

        Args:
            df: Input DataFrame
            group_by: List of group columns

        Returns:
            Dictionary with group structure information
        """
        if not group_by or df.is_empty():
            # Use the same keys as the populated return value below
            return {"total_groups": 0, "levels": 0, "structure": {}}

        # Count unique combinations at each level
        structure = {}

        for i, col in enumerate(group_by):
            level_cols = group_by[: i + 1]
            unique_combinations = df.select(level_cols).unique().height
            structure[f"level_{i + 1}"] = {
                "columns": level_cols,
                "unique_combinations": unique_combinations,
            }

        # Overall statistics
        total_groups = df.select(group_by).unique().height

        return {
            "total_groups": total_groups,
            "levels": len(group_by),
            "structure": structure,
        }
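
    # Shape of the returned dictionary for group_by=["site", "arm"], assuming a
    # hypothetical table with 2 sites and 4 site/arm combinations:
    #
    #     {
    #         "total_groups": 4,
    #         "levels": 2,
    #         "structure": {
    #             "level_1": {"columns": ["site"], "unique_combinations": 2},
    #             "level_2": {"columns": ["site", "arm"], "unique_combinations": 4},
    #         },
    #     }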

    def validate_group_by_columns(
        self, df: pl.DataFrame, group_by: list[str]
    ) -> list[str]:
        """Validate group_by columns and return any issues

        Args:
            df: Input DataFrame
            group_by: List of group columns to validate

        Returns:
            List of validation issues (empty if all valid)
        """
        issues: list[str] = []

        if not group_by:
            return issues

        # Check if columns exist
        missing_cols = [col for col in group_by if col not in df.columns]
        if missing_cols:
            issues.append(f"Missing columns: {missing_cols}")

        # Check for empty DataFrame
        if df.is_empty():
            issues.append("DataFrame is empty")

        # Check for columns with all null values
        for col in group_by:
            if col in df.columns:
                null_count = df[col].null_count()
                if null_count == df.height:
                    issues.append(f"Column '{col}' contains only null values")

        return issues

    def validate_data_sorting(
        self,
        df: pl.DataFrame,
        group_by: list[str] | None = None,
        page_by: list[str] | None = None,
        subline_by: list[str] | None = None,
    ) -> None:
        """Validate that data is properly sorted for grouping operations

        Based on r2rtf logic: ensures data is sorted by all grouping variables
        in the correct order for proper group_by, page_by, and subline_by
        functionality.

        Args:
            df: Input DataFrame to validate
            group_by: List of group_by columns (optional)
            page_by: List of page_by columns (optional)
            subline_by: List of subline_by columns (optional)

        Raises:
            ValueError: If data is not properly sorted or if there are
                overlapping columns
        """
        if df.is_empty():
            return

        # Collect all grouping variables in priority order
        # (page_by, subline_by, group_by)
        all_grouping_vars = []
        if page_by:
            all_grouping_vars.extend(page_by)
        if subline_by:
            all_grouping_vars.extend(subline_by)
        if group_by:
            all_grouping_vars.extend(group_by)

        if not all_grouping_vars:
            return  # No grouping variables to validate

        # Check for overlapping variables between different grouping types
        self._validate_no_overlapping_grouping_vars(group_by, page_by, subline_by)

        # Remove duplicates while preserving order
        unique_vars = []
        seen = set()
        for var in all_grouping_vars:
            if var not in seen:
                unique_vars.append(var)
                seen.add(var)

        # Validate that all grouping columns exist
        missing_cols = [col for col in unique_vars if col not in df.columns]
        if missing_cols:
            raise ValueError(f"Grouping columns not found in DataFrame: {missing_cols}")

        # Check that groups are contiguous (rows in the same group are together).
        # This ensures proper grouping behavior without requiring alphabetical
        # sorting.
        for i, var in enumerate(unique_vars):
            # Consider this variable together with all higher-level variables
            group_cols = unique_vars[: i + 1]

            if i == 0:
                # For the first variable, just check if its values are contiguous
                values = df[var].to_list()
                current_value = values[0]
                seen_values = {current_value}

                for j in range(1, len(values)):
                    if values[j] != current_value:
                        if values[j] in seen_values:
                            # A value reappeared after other values intervened
                            raise ValueError(
                                f"Data is not properly grouped by '{var}'. "
                                f"Values with the same '{var}' must be contiguous. "
                                f"Found '{values[j]}' at position {j} but it also "
                                f"appeared earlier. Please reorder your data so "
                                f"that all rows with the same '{var}' are together."
                            )
                        current_value = values[j]
                        seen_values.add(current_value)
            else:
                # For subsequent variables, check contiguity within parent groups
                # by building a composite key from all grouping variables up to
                # this level. Cast to string first so null values can be handled.
                df_with_key = df.with_columns(
                    [
                        pl.col(col)
                        .cast(pl.Utf8)
                        .fill_null("__NULL__")
                        .alias(f"_str_{col}")
                        for col in group_cols
                    ]
                )

                # Create the group key from the string columns
                str_cols = [f"_str_{col}" for col in group_cols]
                df_with_key = df_with_key.with_columns(
                    pl.concat_str(str_cols, separator="|").alias("_group_key")
                )

                group_keys = df_with_key["_group_key"].to_list()
                current_key = group_keys[0]
                seen_keys = {current_key}

                for j in range(1, len(group_keys)):
                    if group_keys[j] != current_key:
                        if group_keys[j] in seen_keys:
                            # A group reappeared after other groups intervened
                            group_values = df.row(j, named=True)
                            key_parts = [
                                f"{col}='{group_values[col]}'" for col in group_cols
                            ]
                            key_desc = ", ".join(key_parts)

                            raise ValueError(
                                f"Data is not properly grouped. Group with "
                                f"{key_desc} appears in multiple non-contiguous "
                                f"sections. Please reorder your data so that rows "
                                f"with the same grouping values are together."
                            )
                        current_key = group_keys[j]
                        seen_keys.add(current_key)
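
    # Contiguity, not sort order, is what is enforced. A sketch:
    #
    #     ok  = pl.DataFrame({"g": ["B", "B", "A"]})   # unsorted but contiguous
    #     bad = pl.DataFrame({"g": ["A", "B", "A"]})   # "A" appears in two runs
    #     grouping_service.validate_data_sorting(ok, group_by=["g"])   # passes
    #     grouping_service.validate_data_sorting(bad, group_by=["g"])  # ValueError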

    def validate_subline_formatting_consistency(
        self, df: pl.DataFrame, subline_by: list[str], rtf_body
    ) -> list[str]:
        """Validate that formatting is consistent within each column after
        broadcasting

        When using subline_by, we need to ensure that, after broadcasting
        formatting attributes, each remaining column (after removing subline_by
        columns) has consistent formatting values. Otherwise, different rows
        within a subline group would have different formatting.

        Args:
            df: Input DataFrame
            subline_by: List of subline_by columns
            rtf_body: RTFBody instance with formatting attributes

        Returns:
            List of warning messages
        """
        warnings: list[str] = []

        if not subline_by or df.is_empty():
            return warnings

        # Get the columns that will remain after removing subline_by columns
        remaining_cols = [col for col in df.columns if col not in subline_by]
        if not remaining_cols:
            return warnings

        num_cols = len(remaining_cols)
        num_rows = df.height

        # Format attributes to check
        format_attributes = [
            "text_format",
            "text_justification",
            "text_font_size",
            "text_color",
            "border_top",
            "border_bottom",
            "border_left",
            "border_right",
            "border_color_top",
            "border_color_bottom",
            "border_color_left",
            "border_color_right",
        ]

        from ..attributes import BroadcastValue

        for attr_name in format_attributes:
            if hasattr(rtf_body, attr_name):
                attr_value = getattr(rtf_body, attr_name)
                if attr_value is None:
                    continue

                # Use BroadcastValue to expand the attribute to a full matrix
                try:
                    broadcast_obj = BroadcastValue(
                        value=attr_value, dimension=(num_rows, num_cols)
                    )
                    broadcasted = broadcast_obj.to_list()
                except Exception:
                    # If broadcasting fails, skip this attribute
                    continue

                # Skip if broadcasting returned None
                if broadcasted is None:
                    continue

                # Check each column for consistency
                for col_idx in range(num_cols):
                    # Get all values for this column
                    col_values = [
                        broadcasted[row_idx][col_idx] for row_idx in range(num_rows)
                    ]

                    # Filter out None and empty string values
                    meaningful_values = [v for v in col_values if v not in [None, ""]]
                    if not meaningful_values:
                        continue

                    # Check if all values in this column are the same
                    unique_values = set(meaningful_values)
                    if len(unique_values) > 1:
                        col_name = remaining_cols[col_idx]
                        warnings.append(
                            f"Column '{col_name}' has inconsistent {attr_name} "
                            f"values {list(unique_values)} after broadcasting. "
                            f"When using subline_by, formatting should be "
                            f"consistent within each column to ensure uniform "
                            f"appearance within subline groups."
                        )

        return warnings
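
    # Sketch of the failure mode this guards against (the attribute shape and
    # constructor usage are illustrative, not a confirmed RTFBody signature):
    #
    #     body = RTFBody(text_format=[["b", ""], ["i", ""]])  # per-cell formats
    #     grouping_service.validate_subline_formatting_consistency(
    #         df, subline_by=["site"], rtf_body=body
    #     )
    #     # Column 0 mixes "b" and "i" across rows, so a warning is returned for
    #     # that column's text_format.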

    def _validate_no_overlapping_grouping_vars(
        self,
        group_by: list[str] | None = None,
        page_by: list[str] | None = None,
        subline_by: list[str] | None = None,
    ) -> None:
        """Validate that grouping variables don't overlap between different types

        Based on r2rtf validation logic to prevent conflicts between
        group_by, page_by, and subline_by parameters.

        Args:
            group_by: List of group_by columns (optional)
            page_by: List of page_by columns (optional)
            subline_by: List of subline_by columns (optional)

        Raises:
            ValueError: If there are overlapping variables between grouping types
        """
        # Convert None to empty lists for easier processing
        group_by = group_by or []
        page_by = page_by or []
        subline_by = subline_by or []

        # Check for overlaps between each pair
        overlaps = []

        # group_by vs page_by
        group_page_overlap = set(group_by) & set(page_by)
        if group_page_overlap:
            overlaps.append(f"group_by and page_by: {sorted(group_page_overlap)}")

        # group_by vs subline_by
        group_subline_overlap = set(group_by) & set(subline_by)
        if group_subline_overlap:
            overlaps.append(
                f"group_by and subline_by: {sorted(group_subline_overlap)}"
            )

        # page_by vs subline_by
        page_subline_overlap = set(page_by) & set(subline_by)
        if page_subline_overlap:
            overlaps.append(
                f"page_by and subline_by: {sorted(page_subline_overlap)}"
            )

        if overlaps:
            overlap_details = "; ".join(overlaps)
            raise ValueError(
                f"Overlapping variables found between grouping parameters: "
                f"{overlap_details}. Each variable can only be used in one "
                f"grouping parameter (group_by, page_by, or subline_by)."
            )


# Create a singleton instance for easy access
grouping_service = GroupingService()
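
# Typical entry point (a sketch; the import path assumes the src/rtflite
# layout shown in this report, and real callers live elsewhere in rtflite):
#
#     from rtflite.services.grouping_service import grouping_service
#     suppressed = grouping_service.enhance_group_by(df, group_by=["site", "arm"])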