Coverage for src/rtflite/services/grouping_service.py: 79%
173 statements
« prev ^ index » next coverage.py v7.10.5, created at 2025-08-25 22:35 +0000
« prev ^ index » next coverage.py v7.10.5, created at 2025-08-25 22:35 +0000
1"""
2Enhanced group_by functionality for rtflite.
4This service implements r2rtf-compatible group_by behavior where values in
5group_by columns are displayed only once per group, with subsequent rows
6showing blank/suppressed values for better readability.
7"""
9from collections.abc import Mapping, MutableSequence, Sequence
10from typing import Any
12import polars as pl
class GroupingService:
    """Service for handling group_by functionality with value suppression.

    Implements r2rtf-compatible ``group_by`` display: a group column's value
    is shown on the first row of its group and replaced with null on the rows
    that follow. Also provides validation helpers shared by the ``group_by``,
    ``page_by`` and ``subline_by`` options.
    """

    def __init__(self):
        # The service holds no state.
        pass

    def enhance_group_by(
        self, df: pl.DataFrame, group_by: Sequence[str] | None = None
    ) -> pl.DataFrame:
        """Apply group_by value suppression to a DataFrame.

        Args:
            df: Input DataFrame.
            group_by: Column names to group by, outermost level first. Values
                are replaced with null on rows that continue the group of the
                previous row.

        Returns:
            A DataFrame whose group_by columns show a value only on the first
            row of each group. ``df`` itself is returned when ``group_by`` is
            empty/None or ``df`` has no rows.

        Raises:
            ValueError: If a group_by column is missing from ``df``, or if the
                rows of a group are not contiguous (see ``validate_data_sorting``).
        """
        if not group_by or df.is_empty():
            return df

        missing_cols = [col for col in group_by if col not in df.columns]
        if missing_cols:
            raise ValueError(f"group_by columns not found in DataFrame: {missing_cols}")

        # Suppression is only meaningful when each group occupies one contiguous run.
        self.validate_data_sorting(df, group_by=group_by)

        # The helpers below build new frames via with_columns; ``df`` is not mutated.
        if len(group_by) == 1:
            return self._suppress_single_column(df, group_by[0])
        return self._suppress_hierarchical_columns(df, group_by)

    @staticmethod
    def _changed_from_previous_row(column: str) -> pl.Expr:
        """Return an expression that is True where ``column`` differs from the row above.

        ``ne_missing`` treats null as an ordinary value, so a transition to or
        from null counts as a change. Plain ``!=`` yields null in that case,
        and ``pl.when`` treats null as False, which would wrongly hide the
        first value after a null.
        """
        return pl.col(column).ne_missing(pl.col(column).shift(1))

    def _suppress_single_column(self, df: pl.DataFrame, column: str) -> pl.DataFrame:
        """Suppress repeated values in a single group column.

        Args:
            df: Input DataFrame.
            column: Column name whose repeated values are suppressed.

        Returns:
            DataFrame in which ``column`` is null on every row whose value
            equals the previous row's value; the first row is always kept.
        """
        # A single column is just a one-level hierarchy.
        return self._suppress_hierarchical_columns(df, [column])

    def _suppress_hierarchical_columns(
        self, df: pl.DataFrame, group_by: Sequence[str]
    ) -> pl.DataFrame:
        """Suppress repeated values in hierarchical group columns.

        A value in column ``group_by[i]`` is shown when:
        - it is on the first row, OR
        - any higher-level column ``group_by[:i]`` changed from the previous row, OR
        - this column changed from the previous row.
        Otherwise it is replaced with null.

        Args:
            df: Input DataFrame.
            group_by: Column names in hierarchical order (outermost first).
                Repeated names are only processed once.

        Returns:
            DataFrame with hierarchical value suppression applied.
        """
        # Cumulative "a new group starts on this row" condition. Each level ORs
        # in its own change, so a change at an outer level resets inner levels.
        starts_group = pl.int_range(df.height) == 0
        suppressed = []
        for column in dict.fromkeys(group_by):
            starts_group = starts_group | self._changed_from_previous_row(column)
            suppressed.append(
                pl.when(starts_group).then(pl.col(column)).otherwise(None).alias(column)
            )

        # One with_columns call evaluates every expression against the ORIGINAL
        # values. Applying them one at a time would make the "outer level
        # changed" checks compare already-suppressed (null) values and hide
        # inner values at the start of a new outer group.
        return df.with_columns(suppressed)

    def restore_page_context(
        self,
        suppressed_df: pl.DataFrame,
        original_df: pl.DataFrame,
        group_by: Sequence[str],
        page_start_indices: Sequence[int],
    ) -> pl.DataFrame:
        """Restore group context at the beginning of new pages.

        When content spans multiple pages, the first row of each new page
        should show the group values for context, even if they were suppressed
        in the continuous flow.

        Args:
            suppressed_df: DataFrame with group_by suppression applied.
            original_df: Original DataFrame with all values.
            group_by: List of group columns.
            page_start_indices: Row indices where new pages start. Indices at
                or beyond ``original_df``'s length are ignored.

        Returns:
            DataFrame with group values copied from ``original_df`` into the
            rows at ``page_start_indices``.
        """
        if not group_by or not page_start_indices:
            return suppressed_df

        result_df = suppressed_df.clone()
        columns = list(dict.fromkeys(group_by))
        row_index = pl.int_range(result_df.height)

        for page_start_idx in page_start_indices:
            if page_start_idx >= original_df.height:
                continue
            is_page_start = row_index == page_start_idx
            # All group columns for this page start are restored in one pass.
            result_df = result_df.with_columns(
                [
                    pl.when(is_page_start)
                    .then(pl.lit(original_df[col][page_start_idx]))
                    .otherwise(pl.col(col))
                    .alias(col)
                    for col in columns
                ]
            )

        return result_df

    def get_group_structure(
        self, df: pl.DataFrame, group_by: Sequence[str]
    ) -> Mapping[str, Any]:
        """Analyze the group structure of a DataFrame.

        Args:
            df: Input DataFrame.
            group_by: List of group columns.

        Returns:
            Dictionary with keys ``total_groups`` (distinct combinations of all
            group columns), ``levels`` (number of group columns) and
            ``structure`` (per-level ``columns`` and ``unique_combinations``).
            When ``group_by`` is empty or ``df`` has no rows, the counts are 0
            and ``structure`` is empty; the legacy ``groups`` key is included
            in that case for backward compatibility.
        """
        if not group_by or df.is_empty():
            # Same keys as the non-empty result (plus the legacy "groups"), so
            # callers can read "total_groups"/"levels" unconditionally.
            return {
                "groups": 0,
                "total_groups": 0,
                "levels": len(group_by) if group_by else 0,
                "structure": {},
            }

        structure: dict[str, Any] = {}
        for depth in range(1, len(group_by) + 1):
            level_cols = group_by[:depth]
            structure[f"level_{depth}"] = {
                "columns": level_cols,
                "unique_combinations": df.select(level_cols).unique().height,
            }

        return {
            "total_groups": df.select(group_by).unique().height,
            "levels": len(group_by),
            "structure": structure,
        }

    def validate_group_by_columns(
        self, df: pl.DataFrame, group_by: Sequence[str]
    ) -> Sequence[str]:
        """Validate group_by columns and return any issues.

        Args:
            df: Input DataFrame.
            group_by: List of group columns to validate.

        Returns:
            List of human-readable issues (empty if all valid): missing
            columns, an empty DataFrame, and columns whose values are all null.
        """
        issues: MutableSequence[str] = []

        if not group_by:
            return issues

        missing_cols = [col for col in group_by if col not in df.columns]
        if missing_cols:
            issues.append(f"Missing columns: {missing_cols}")

        if df.is_empty():
            issues.append("DataFrame is empty")

        for col in group_by:
            if col in df.columns and df[col].null_count() == df.height:
                issues.append(f"Column '{col}' contains only null values")

        return issues

    @staticmethod
    def _first_repeated_run(keys: Sequence[Any]) -> int | None:
        """Return the index where a key reappears after a different key.

        Returns None when every distinct key occupies one contiguous run.
        Keys must be hashable.
        """
        if not keys:
            return None
        previous = keys[0]
        seen = {previous}
        for index in range(1, len(keys)):
            key = keys[index]
            if key != previous:
                if key in seen:
                    return index
                seen.add(key)
                previous = key
        return None

    def validate_data_sorting(
        self,
        df: pl.DataFrame,
        group_by: Sequence[str] | None = None,
        page_by: Sequence[str] | None = None,
        subline_by: Sequence[str] | None = None,
    ) -> None:
        """Validate that data is properly grouped for grouping operations.

        Based on r2rtf logic. Alphabetical sorting is not required. The
        grouping variables are ordered page_by, subline_by, then group_by.
        For every leading prefix of that list, rows that share the same values
        must form a single contiguous run.

        Args:
            df: Input DataFrame to validate.
            group_by: List of group_by columns (optional).
            page_by: List of page_by columns (optional).
            subline_by: List of subline_by columns (optional).

        Raises:
            ValueError: If the grouping parameters share a column, if a
                grouping column is missing from ``df``, or if rows of a group
                are not contiguous.
        """
        if df.is_empty():
            return

        # Priority order: page_by, then subline_by, then group_by.
        all_grouping_vars: list[str] = [
            *(page_by or []),
            *(subline_by or []),
            *(group_by or []),
        ]
        if not all_grouping_vars:
            return  # No grouping variables to validate

        self._validate_no_overlapping_grouping_vars(group_by, page_by, subline_by)

        # Remove duplicates while preserving order.
        unique_vars = list(dict.fromkeys(all_grouping_vars))

        missing_cols = [col for col in unique_vars if col not in df.columns]
        if missing_cols:
            raise ValueError(f"Grouping columns not found in DataFrame: {missing_cols}")

        for depth in range(1, len(unique_vars) + 1):
            group_cols = unique_vars[:depth]
            # Row tuples compare native values. This avoids the collisions a
            # string-joined key would have, e.g. ("a|b", "c") vs ("a", "b|c"),
            # or a null vs a literal placeholder string.
            keys = df.select(group_cols).rows()
            position = self._first_repeated_run(keys)
            if position is None:
                continue

            if depth == 1:
                var = group_cols[0]
                value = keys[position][0]
                raise ValueError(
                    f"Data is not properly grouped by '{var}'. "
                    f"Values with the same '{var}' must be contiguous (together). "
                    f"Found '{value}' at position {position} but it also appeared earlier. "
                    f"Please reorder your data so that all rows with the same '{var}' are together."
                )

            key_desc = ", ".join(
                f"{col}='{value}'" for col, value in zip(group_cols, keys[position])
            )
            raise ValueError(
                f"Data is not properly grouped. "
                f"Group with {key_desc} appears in multiple non-contiguous sections. "
                f"Please reorder your data so that rows with the same grouping values are together."
            )

    def validate_subline_formatting_consistency(
        self, df: pl.DataFrame, subline_by: Sequence[str], rtf_body: Any
    ) -> Sequence[str]:
        """Validate that formatting is consistent within each column after broadcasting.

        When using subline_by, we need to ensure that after broadcasting
        formatting attributes, each remaining column (after removing subline_by
        columns) has consistent formatting values. Otherwise, different rows
        within a subline group would have different formatting.

        Args:
            df: Input DataFrame.
            subline_by: List of subline_by columns.
            rtf_body: RTFBody instance with formatting attributes.

        Returns:
            List of warning messages, one per (attribute, column) pair with
            more than one distinct non-empty value.
        """
        warnings: MutableSequence[str] = []

        if not subline_by or df.is_empty():
            return warnings

        remaining_cols = [col for col in df.columns if col not in subline_by]
        if not remaining_cols:
            return warnings

        # Function-level import kept as in the original module (presumably to
        # avoid an import cycle -- confirm before moving it to module level).
        from ..attributes import BroadcastValue

        num_cols = len(remaining_cols)
        num_rows = df.height

        format_attributes = [
            "text_format",
            "text_justification",
            "text_font_size",
            "text_color",
            "border_top",
            "border_bottom",
            "border_left",
            "border_right",
            "border_color_top",
            "border_color_bottom",
            "border_color_left",
            "border_color_right",
        ]

        for attr_name in format_attributes:
            attr_value = getattr(rtf_body, attr_name, None)
            if attr_value is None:
                continue

            try:
                broadcasted = BroadcastValue(
                    value=attr_value, dimension=(num_rows, num_cols)
                ).to_list()
            except Exception:
                # Best effort: an attribute that cannot be broadcast to this
                # shape is skipped rather than failing the whole check.
                continue

            if broadcasted is None:
                continue

            for col_idx, col_name in enumerate(remaining_cols):
                col_values = [broadcasted[row_idx][col_idx] for row_idx in range(num_rows)]
                # dict.fromkeys de-duplicates while keeping first-seen order,
                # so the warning text is deterministic (set order is not).
                distinct_values = list(
                    dict.fromkeys(v for v in col_values if v not in [None, ""])
                )
                if len(distinct_values) > 1:
                    warnings.append(
                        f"Column '{col_name}' has inconsistent {attr_name} values {distinct_values} "
                        f"after broadcasting. When using subline_by, formatting should be consistent "
                        f"within each column to ensure uniform appearance within subline groups."
                    )

        return warnings

    def _validate_no_overlapping_grouping_vars(
        self,
        group_by: Sequence[str] | None = None,
        page_by: Sequence[str] | None = None,
        subline_by: Sequence[str] | None = None,
    ) -> None:
        """Validate that grouping variables don't overlap between different types.

        Based on r2rtf validation logic to prevent conflicts between
        group_by, page_by, and subline_by parameters.

        Args:
            group_by: List of group_by columns (optional).
            page_by: List of page_by columns (optional).
            subline_by: List of subline_by columns (optional).

        Raises:
            ValueError: If any column appears in more than one grouping parameter.
        """
        column_sets = {
            "group_by": set(group_by or []),
            "page_by": set(page_by or []),
            "subline_by": set(subline_by or []),
        }
        pairs = (
            ("group_by", "page_by"),
            ("group_by", "subline_by"),
            ("page_by", "subline_by"),
        )

        overlaps = []
        for first, second in pairs:
            shared = column_sets[first] & column_sets[second]
            if shared:
                overlaps.append(f"{first} and {second}: {sorted(shared)}")

        if overlaps:
            overlap_details = "; ".join(overlaps)
            raise ValueError(
                f"Overlapping variables found between grouping parameters: {overlap_details}. "
                f"Each variable can only be used in one grouping parameter (group_by, page_by, or subline_by)."
            )
# Shared module-level instance. GroupingService keeps no per-instance state,
# so callers can use this one instead of constructing their own.
grouping_service: GroupingService = GroupingService()