Coverage for src / rtflite / attributes.py: 77%
339 statements
« prev ^ index » next coverage.py v7.12.0, created at 2025-11-28 05:09 +0000
« prev ^ index » next coverage.py v7.12.0, created at 2025-11-28 05:09 +0000
1import math
2from collections.abc import MutableSequence, Sequence
3from typing import Any
5import narwhals as nw
6import polars as pl
7from pydantic import BaseModel, ConfigDict, Field, field_validator
9from rtflite.row import (
10 BORDER_CODES,
11 FORMAT_CODES,
12 TEXT_JUSTIFICATION_CODES,
13 VERTICAL_ALIGNMENT_CODES,
14 Border,
15 Cell,
16 Row,
17 TextContent,
18 Utils,
19)
20from rtflite.services.color_service import color_service
21from rtflite.strwidth import get_string_width
def _to_nested_list(v):
    """Normalize an attribute value into a nested (row-major) list.

    Conversion rules:

    * ``None`` is returned unchanged.
    * A scalar (``int``, ``str``, ``float``, ``bool``) becomes ``[[v]]``.
    * A list containing at least one scalar is treated as a single row:
      ``[1, 2]`` -> ``[[1, 2]]``.
    * A list whose items are all lists is returned unchanged.
    * A tuple is treated as a single column: ``(1, 2)`` -> ``[[1], [2]]``.
    * A polars DataFrame, or any other DataFrame narwhals can wrap eagerly,
      becomes a list of its rows (each row a list).
    * An array-like object exposing ``__array__`` and ``tolist`` (e.g. a numpy
      array) is converted with ``tolist()`` and the result is normalized again,
      so a 1-D array becomes a single row exactly like a flat list does.

    Any other object is returned unchanged.

    Raises:
        TypeError: For a sequence matching none of the list/tuple rules above
            (e.g. a list of tuples, or a ``range``).
    """
    if v is None:
        return None

    if isinstance(v, (int, str, float, bool)):
        v = [[v]]

    if isinstance(v, Sequence):
        if isinstance(v, list) and any(
            isinstance(item, (str, int, float, bool)) for item in v
        ):
            v = [v]
        elif isinstance(v, list) and all(isinstance(item, list) for item in v):
            pass  # already nested
        elif isinstance(v, tuple):
            v = [[item] for item in v]
        else:
            raise TypeError("Invalid value type. Must be a list or tuple.")

    # DataFrame-like objects: convert to a list of row lists.
    if hasattr(v, "__dataframe__") or hasattr(v, "columns"):
        if isinstance(v, pl.DataFrame):
            v = [list(row) for row in v.rows()]  # Convert tuples to lists
        else:
            try:
                nw_df = nw.from_native(v, eager_only=True)
            except TypeError:
                # narwhals does not recognize this object as an eager
                # DataFrame (e.g. it merely has a ``columns`` attribute):
                # leave it unchanged, as before.
                pass
            else:
                v = [list(row) for row in nw_df.rows()]  # Convert tuples to lists

    # Array-like objects (e.g. numpy arrays): convert, then normalize the
    # result so 1-D arrays and 0-d values end up nested like lists/scalars.
    if hasattr(v, "__array__") and hasattr(v, "tolist"):
        return _to_nested_list(v.tolist())

    return v
def _iter_attribute_values(v):
    """Yield each leaf value of a flat or nested attribute list.

    ``v`` is treated as nested (``[[a, b], [c]]``) when its first element is a
    list, and as flat (``[a, b, c]``) otherwise.
    """
    if v and isinstance(v[0], list):
        for row in v:
            yield from row
    else:
        yield from v


def _check_color_values(v, label):
    """Raise ``ValueError`` for the first invalid color in ``v``.

    Empty strings mean "no color" and are accepted. The message names the
    attribute via ``label`` and lists up to three suggestions from
    ``color_service``.
    """
    for color in _iter_attribute_values(v):
        # Allow empty string for "no color"
        if color and not color_service.validate_color(color):
            suggestions = color_service.get_color_suggestions(color, 3)
            suggestion_text = (
                f" Did you mean: {', '.join(suggestions)}?" if suggestions else ""
            )
            raise ValueError(f"Invalid {label}: '{color}'.{suggestion_text}")


class TextAttributes(BaseModel):
    """Base class for text-related attributes in RTF components.

    Each attribute accepts a scalar (wrapped into a one-element list by
    ``convert_to_list`` before validation), a flat list, or a nested list;
    ``None`` means the attribute is unset.
    """

    text_font: list[int] | list[list[int]] | None = Field(
        default=None, description="Font number for text"
    )
    text_format: list[str] | list[list[str]] | None = Field(
        default=None,
        description="Text formatting (e.g. 'b' for 'bold', 'i' for'italic')",
    )
    text_font_size: list[float] | list[list[float]] | None = Field(
        default=None, description="Font size in points"
    )
    text_color: list[str] | list[list[str]] | None = Field(
        default=None, description="Text color name or RGB value"
    )
    text_background_color: list[str] | list[list[str]] | None = Field(
        default=None, description="Background color name or RGB value"
    )
    text_justification: list[str] | list[list[str]] | None = Field(
        default=None,
        description="Text alignment ('l'=left, 'c'=center, 'r'=right, 'j'=justify)",
    )
    text_indent_first: list[int] | list[list[int]] | None = Field(
        default=None, description="First line indent in twips"
    )
    text_indent_left: list[int] | list[list[int]] | None = Field(
        default=None, description="Left indent in twips"
    )
    text_indent_right: list[int] | list[list[int]] | None = Field(
        default=None, description="Right indent in twips"
    )
    text_space: list[int] | list[list[int]] | None = Field(
        default=None, description="Line spacing multiplier"
    )
    text_space_before: list[int] | list[list[int]] | None = Field(
        default=None, description="Space before paragraph in twips"
    )
    text_space_after: list[int] | list[list[int]] | None = Field(
        default=None, description="Space after paragraph in twips"
    )
    text_hyphenation: list[bool] | list[list[bool]] | None = Field(
        default=None, description="Enable automatic hyphenation"
    )
    text_convert: list[bool] | list[list[bool]] | None = Field(
        default=[True], description="Convert LaTeX commands to Unicode characters"
    )

    @field_validator("text_font", mode="after")
    def validate_text_font(cls, v):
        """Check every font number against ``Utils._font_type()["type"]``."""
        if v is None:
            return v
        valid_fonts = Utils._font_type()["type"]
        for font in _iter_attribute_values(v):
            if font not in valid_fonts:
                raise ValueError(f"Invalid font number: {font}")
        return v

    @field_validator("text_format", mode="after")
    def validate_text_format(cls, v):
        """Check every character of every format string against FORMAT_CODES."""
        if v is None:
            return v
        for format_string in _iter_attribute_values(v):
            # A format string may combine several codes, e.g. "bi".
            for fmt in format_string:
                if fmt not in FORMAT_CODES:
                    raise ValueError(f"Invalid text format: {fmt}")
        return v

    @field_validator("text_font_size", mode="after")
    def validate_text_font_size(cls, v):
        """Reject zero or negative font sizes."""
        if v is None:
            return v
        for size in _iter_attribute_values(v):
            if size <= 0:
                raise ValueError(f"Invalid font size: {size}")
        return v

    @field_validator("text_color", mode="after")
    def validate_text_color(cls, v):
        """Validate text colors with ``color_service`` ("" means no color)."""
        if v is None:
            return v
        _check_color_values(v, "text color")
        return v

    @field_validator("text_background_color", mode="after")
    def validate_text_background_color(cls, v):
        """Validate background colors with ``color_service`` ("" means no color)."""
        if v is None:
            return v
        _check_color_values(v, "text background color")
        return v

    @field_validator("text_justification", mode="after")
    def validate_text_justification(cls, v):
        """Check every justification code against TEXT_JUSTIFICATION_CODES."""
        if v is None:
            return v
        for justification in _iter_attribute_values(v):
            if justification not in TEXT_JUSTIFICATION_CODES:
                raise ValueError(f"Invalid text justification: {justification}")
        return v

    @field_validator(
        "text_font",
        "text_format",
        "text_font_size",
        "text_color",
        "text_background_color",
        "text_justification",
        "text_indent_first",
        "text_indent_left",
        "text_indent_right",
        "text_space",
        "text_space_before",
        "text_space_after",
        "text_hyphenation",
        "text_convert",
        mode="before",
    )
    def convert_to_list(cls, v):
        """Convert single values to lists before validation."""
        if v is not None and isinstance(v, (int, str, float, bool)):
            return [v]
        return v

    def _encode_text(self, text: Sequence[str], method: str) -> str | list[str]:
        """Encode ``text`` as RTF using this object's ``text_*`` attributes.

        Attributes are broadcast over a ``(len(text), 1)`` grid, so line ``i``
        uses each attribute's value at ``(i, 0)``.

        Args:
            text: Lines of text to encode.
            method: ``"paragraph"`` to encode each line as its own paragraph;
                ``"line"`` to join the lines with ``\\line`` in one paragraph.

        Returns:
            A list of RTF paragraph strings for ``"paragraph"``; a single RTF
            string for ``"line"``.

        Raises:
            ValueError: If ``method`` is neither ``"paragraph"`` nor ``"line"``.
        """
        dim = (len(text), 1)

        # TextContent keyword -> attribute that supplies it, in lookup order.
        attribute_map = (
            ("font", "text_font"),
            ("size", "text_font_size"),
            ("format", "text_format"),
            ("color", "text_color"),
            ("background_color", "text_background_color"),
            ("justification", "text_justification"),
            ("indent_first", "text_indent_first"),
            ("indent_left", "text_indent_left"),
            ("indent_right", "text_indent_right"),
            ("space", "text_space"),
            ("space_before", "text_space_before"),
            ("space_after", "text_space_after"),
            ("convert", "text_convert"),
            ("hyphenation", "text_hyphenation"),
        )

        def attributes_at(row_idx: int) -> dict[str, Any]:
            """Broadcast every text attribute to ``(row_idx, 0)``."""
            return {
                keyword: BroadcastValue(
                    value=getattr(self, attr_name), dimension=dim
                ).iloc(row_idx, 0)
                for keyword, attr_name in attribute_map
            }

        text_components = [
            TextContent(text=str(text[i]), **attributes_at(i))
            for i in range(len(text))
        ]

        if method == "paragraph":
            return [
                component._as_rtf(method="paragraph") for component in text_components
            ]

        if method == "line":
            line = "\\line".join(
                component._as_rtf(method="plain") for component in text_components
            )
            # The joined paragraph takes its formatting from the last line's
            # attributes; use index 0 when ``text`` is empty so there is always
            # a valid index to broadcast to.
            last_idx = max(len(text) - 1, 0)
            return TextContent(text=str(line), **attributes_at(last_idx))._as_rtf(
                method="paragraph_format"
            )

        raise ValueError(f"Invalid method: {method}")

    def calculate_lines(
        self, text: str, available_width: float, row_idx: int = 0, col_idx: int = 0
    ) -> int:
        """
        Calculate number of lines needed for text given available width.

        Args:
            text: Text content to measure
            available_width: Available width in inches
            row_idx: Row index for attribute lookup (default: 0)
            col_idx: Column index for attribute lookup (default: 0)

        Returns:
            Number of lines needed (minimum 1)

        Raises:
            ValueError: If ``text`` is non-empty, ``available_width`` is
                positive, and ``text_font`` or ``text_font_size`` is unset.
        """
        if not text or available_width <= 0:
            return 1

        # Create a dummy dimension for broadcast lookup
        dim = (max(1, row_idx + 1), max(1, col_idx + 1))

        # Get font attributes using broadcast logic - raise error if None
        if self.text_font is None:
            raise ValueError("text_font must be set to calculate lines")
        font_number = BroadcastValue(value=self.text_font, dimension=dim).iloc(
            row_idx, col_idx
        )

        if self.text_font_size is None:
            raise ValueError("text_font_size must be set to calculate lines")
        font_size = BroadcastValue(value=self.text_font_size, dimension=dim).iloc(
            row_idx, col_idx
        )

        total_width = get_string_width(
            text=text, font=font_number, font_size=font_size, unit="in"
        )

        # Simple approximation: divide total width by available width and round up
        return max(1, math.ceil(total_width / available_width))
class TableAttributes(TextAttributes):
    """Base class for table-related attributes in RTF components.

    Adds column widths, border styles and colors, and cell layout settings to
    TextAttributes. The fields listed in ``convert_to_nested_list`` are
    normalized to nested (row-major) lists so they can be broadcast over a
    table with ``BroadcastValue``.
    """

    col_rel_width: list[float] | None = Field(
        default=None, description="Relative widths of table columns"
    )

    border_left: list[list[str]] = Field(
        default=[[""]], description="Left border style"
    )
    border_right: list[list[str]] = Field(
        default=[[""]], description="Right border style"
    )
    border_top: list[list[str]] = Field(default=[[""]], description="Top border style")
    border_bottom: list[list[str]] = Field(
        default=[[""]], description="Bottom border style"
    )
    border_first: list[list[str]] = Field(
        default=[[""]], description="First row border style"
    )
    border_last: list[list[str]] = Field(
        default=[[""]], description="Last row border style"
    )
    border_color_left: list[list[str]] = Field(
        default=[[""]], description="Left border color"
    )
    border_color_right: list[list[str]] = Field(
        default=[[""]], description="Right border color"
    )
    border_color_top: list[list[str]] = Field(
        default=[[""]], description="Top border color"
    )
    border_color_bottom: list[list[str]] = Field(
        default=[[""]], description="Bottom border color"
    )
    border_color_first: list[list[str]] = Field(
        default=[[""]], description="First row border color"
    )
    border_color_last: list[list[str]] = Field(
        default=[[""]], description="Last row border color"
    )

    border_width: list[list[int]] = Field(
        default=[[15]], description="Border width in twips"
    )
    cell_height: list[list[float]] = Field(
        default=[[0.15]], description="Cell height in inches"
    )
    cell_justification: list[list[str]] = Field(
        default=[["l"]],
        description=(
            "Cell horizontal alignment ('l'=left, 'c'=center, 'r'=right, 'j'=justify)"
        ),
    )
    cell_vertical_justification: list[list[str]] = Field(
        default=[["center"]],
        description="Cell vertical alignment ('top', 'center', 'bottom')",
    )
    cell_nrow: list[list[int]] = Field(
        default=[[1]], description="Number of rows per cell"
    )

    @field_validator(
        "border_color_left",
        "border_color_right",
        "border_color_top",
        "border_color_bottom",
        "border_color_first",
        "border_color_last",
        mode="after",
    )
    def validate_border_colors(cls, v):
        """Validate border colors with ``color_service`` ("" means no color)."""
        if v is None:
            return v

        for row in v:
            for color in row:
                # Allow empty string for no color
                if color and not color_service.validate_color(color):
                    suggestions = color_service.get_color_suggestions(color, 3)
                    suggestion_text = (
                        f" Did you mean: {', '.join(suggestions)}?"
                        if suggestions
                        else ""
                    )
                    raise ValueError(
                        f"Invalid border color: '{color}'.{suggestion_text}"
                    )
        return v

    @field_validator("cell_vertical_justification", mode="after")
    def validate_cell_vertical_justification(cls, v):
        """Check every vertical alignment against VERTICAL_ALIGNMENT_CODES."""
        if v is None:
            return v

        for row in v:
            for justification in row:
                if justification not in VERTICAL_ALIGNMENT_CODES:
                    raise ValueError(
                        f"Invalid cell vertical justification: {justification}"
                    )
        return v

    @field_validator("col_rel_width", mode="before")
    def convert_col_rel_width_to_list(cls, v):
        """Wrap a single column width into a one-element list."""
        if v is not None and isinstance(v, (int, str, float, bool)):
            return [v]
        return v

    @field_validator(
        "border_left",
        "border_right",
        "border_top",
        "border_bottom",
        "border_first",
        "border_last",
        "border_color_left",
        "border_color_right",
        "border_color_top",
        "border_color_bottom",
        "border_color_first",
        "border_color_last",
        "border_width",
        "cell_height",
        "cell_justification",
        "cell_vertical_justification",
        "cell_nrow",
        "text_font",
        "text_format",
        "text_font_size",
        "text_color",
        "text_background_color",
        "text_justification",
        "text_indent_first",
        "text_indent_left",
        "text_indent_right",
        "text_space",
        "text_space_before",
        "text_space_after",
        "text_hyphenation",
        "text_convert",
        mode="before",
    )
    def convert_to_nested_list(cls, v):
        """Normalize per-cell attributes to nested lists before validation."""
        return _to_nested_list(v)

    @field_validator(
        "col_rel_width", "border_width", "cell_height", "cell_nrow", mode="after"
    )
    def validate_positive_value(cls, v, info):
        """Reject zero or negative widths, heights and row counts.

        ``info`` is pydantic's ``ValidationInfo``; its ``field_name`` names the
        offending field in the error message.
        """
        if not v:
            # None (unset) or an empty list: nothing to check.
            return v
        if isinstance(v[0], (list, tuple)):
            values = (val for row in v for val in row)  # 2D array
        else:
            values = iter(v)  # 1D array
        if any(val <= 0 for val in values):
            raise ValueError(f"{info.field_name.capitalize()} must be positive")
        return v

    @field_validator("cell_justification", mode="after")
    def validate_cell_justification(cls, v):
        """Check every cell justification against TEXT_JUSTIFICATION_CODES."""
        if v is None:
            return v

        for row in v:
            for justification in row:
                if justification not in TEXT_JUSTIFICATION_CODES:
                    raise ValueError(f"Invalid cell justification: {justification}")
        return v

    @field_validator(
        "border_left",
        "border_right",
        "border_top",
        "border_bottom",
        "border_first",
        "border_last",
        mode="after",
    )
    def validate_border(cls, v, info):
        """Validate that all border styles are valid.

        ``info`` is pydantic's ``ValidationInfo``; its ``field_name`` names the
        offending field in the error message.
        """
        if v is None:
            return v

        for row in v:
            for border in row:
                if border not in BORDER_CODES:
                    field_name = info.field_name.capitalize()
                    raise ValueError(
                        f"{field_name} with invalid border style: {border}"
                    )

        return v

    def _get_section_attributes(self, indices) -> dict:
        """Collect text_/col_/border_/cell_ attributes broadcast to ``indices``.

        Args:
            indices: Iterable of ``(row, col)`` pairs; it is materialized once
                so every attribute sees all pairs (a generator is safe).

        Returns:
            Mapping of attribute name to the list of broadcast values, one per
            index pair. Attributes whose value is ``None`` are omitted.
        """
        index_pairs = list(indices)
        prefixes = ("text_", "col_", "border_", "cell_")

        attrs = {}
        for attr in dir(self):
            if not attr.startswith(prefixes):
                continue

            try:
                attr_value = getattr(self, attr)
            except AttributeError:
                continue

            if not callable(attr_value):
                attrs[attr] = attr_value

        # Broadcast attributes to section indices, excluding None values
        section = {}
        for attr, val in attrs.items():
            if val is None:
                continue
            broadcast = BroadcastValue(value=val, dimension=None)
            section[attr] = [broadcast.iloc(row, col) for row, col in index_pairs]
        return section

    def _encode(
        self, df: pl.DataFrame, col_widths: Sequence[float]
    ) -> MutableSequence[str]:
        """Encode ``df`` as RTF table rows using this object's attributes.

        Side effect: when ``df`` has at least one row and one column,
        ``self.cell_nrow`` is replaced with a grid of the line counts returned
        by ``calculate_lines`` for each cell.

        Args:
            df: Table content; ``None`` cells are rendered as empty strings.
            col_widths: Column widths (passed to ``calculate_lines`` as the
                available width in inches, and used as each cell's width).

        Returns:
            The RTF strings produced by each row's ``Row._as_rtf()``, in order.

        Raises:
            ValueError: From ``calculate_lines`` when a cell has text but
                ``text_font`` or ``text_font_size`` is unset.
        """
        dim = df.shape

        def get_broadcast_value(attr_name, row_idx, col_idx=0):
            """Get broadcast value for an attribute at specified indices."""
            attr_value = getattr(self, attr_name)
            return BroadcastValue(value=attr_value, dimension=dim).iloc(
                row_idx, col_idx
            )

        # Stringify every cell once; nulls display as "" rather than "None".
        cell_text = [
            ["" if raw is None else str(raw) for raw in df.row(i)]
            for i in range(dim[0])
        ]

        if dim[0] > 0 and dim[1] > 0:
            # Replace cell_nrow with a freshly built grid of per-cell line
            # counts (cell_nrow is a nested list, so it is rebuilt whole).
            widths = BroadcastValue(value=col_widths, dimension=dim)
            self.cell_nrow = [
                [
                    self.calculate_lines(
                        text=cell_text[i][j],
                        available_width=widths.iloc(i, j),
                        row_idx=i,
                        col_idx=j,
                    )
                    for j in range(dim[1])
                ]
                for i in range(dim[0])
            ]

        rows: MutableSequence[str] = []
        for i in range(dim[0]):
            cells = []

            for j in range(dim[1]):
                # Only the last column carries a right border.
                if j == dim[1] - 1:
                    border_right = Border(
                        style=get_broadcast_value("border_right", i, j)
                    )
                else:
                    border_right = None

                cell = Cell(
                    text=TextContent(
                        text=cell_text[i][j],
                        font=get_broadcast_value("text_font", i, j),
                        size=get_broadcast_value("text_font_size", i, j),
                        format=get_broadcast_value("text_format", i, j),
                        color=get_broadcast_value("text_color", i, j),
                        background_color=get_broadcast_value(
                            "text_background_color", i, j
                        ),
                        justification=get_broadcast_value("text_justification", i, j),
                        indent_first=get_broadcast_value("text_indent_first", i, j),
                        indent_left=get_broadcast_value("text_indent_left", i, j),
                        indent_right=get_broadcast_value("text_indent_right", i, j),
                        space=get_broadcast_value("text_space", i, j),
                        space_before=get_broadcast_value("text_space_before", i, j),
                        space_after=get_broadcast_value("text_space_after", i, j),
                        convert=get_broadcast_value("text_convert", i, j),
                        hyphenation=get_broadcast_value("text_hyphenation", i, j),
                    ),
                    width=col_widths[j],
                    border_left=Border(style=get_broadcast_value("border_left", i, j)),
                    border_right=border_right,
                    border_top=Border(style=get_broadcast_value("border_top", i, j)),
                    border_bottom=Border(
                        style=get_broadcast_value("border_bottom", i, j)
                    ),
                    vertical_justification=get_broadcast_value(
                        "cell_vertical_justification", i, j
                    ),
                )
                cells.append(cell)

            rtf_row = Row(
                row_cells=cells,
                justification=get_broadcast_value("cell_justification", i, 0),
                height=get_broadcast_value("cell_height", i, 0),
            )
            rows.extend(rtf_row._as_rtf())

        return rows
class BroadcastValue(BaseModel):
    """A value normalized to a nested list and indexed with wrap-around broadcasting.

    ``value`` is passed through ``_to_nested_list`` before validation, so
    scalars, lists, tuples, arrays and DataFrames all become a list of rows.
    ``iloc`` reduces indices modulo the value's size, and ``to_list`` expands
    the value to ``dimension`` by repeating rows and columns.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    value: Any = Field(
        ...,
        description="The value of the table, can be various types including DataFrame.",
    )

    dimension: tuple[int, int] | None = Field(
        None, description="Dimensions of the table (rows, columns)"
    )

    @field_validator("value", mode="before")
    def convert_value(cls, v):
        """Normalize the raw value to a nested list."""
        return _to_nested_list(v)

    @field_validator("dimension")
    def validate_dimension(cls, v):
        """Require a ``(rows, columns)`` int pair with rows >= 0 and columns > 0."""
        if v is None:
            return v

        if not isinstance(v, tuple) or len(v) != 2:
            raise TypeError("dimension must be a tuple of (rows, columns)")

        rows, cols = v
        if not isinstance(rows, int) or not isinstance(cols, int):
            raise TypeError("dimension values must be integers")

        if rows < 0 or cols <= 0:
            raise ValueError("rows must be non-negative and cols must be positive")

        return v

    def iloc(self, row_index: int, column_index: int) -> Any:
        """Return the element at ``(row_index, column_index)`` with wrap-around.

        The row index is taken modulo the number of rows and the column index
        modulo the length of the first row.

        Returns:
            The element, or ``None`` when ``value`` is ``None``.

        Raises:
            ValueError: If ``value`` is empty (``[]`` or ``[[]]``), or if the
                lookup raises ``IndexError`` (e.g. a row shorter than the
                first row).
        """
        if self.value is None:
            return None

        row_count = len(self.value)
        if row_count == 0 or len(self.value[0]) == 0:
            raise ValueError("Cannot index into an empty value")

        try:
            return self.value[row_index % row_count][
                column_index % len(self.value[0])
            ]
        except IndexError as e:
            raise ValueError(f"Invalid DataFrame index or slice: {e}") from e

    def to_list(self) -> list | None:
        """Return ``value`` expanded (by repetition) and cropped to ``dimension``.

        Returns ``None`` when ``value`` is ``None`` and ``value`` itself (not a
        copy) when ``dimension`` is ``None``.

        Raises:
            ValueError: If ``dimension`` is set but ``value`` is empty.
        """
        if self.value is None:
            return None

        if self.dimension is None:
            return self.value

        row_count = len(self.value)
        col_count = len(self.value[0]) if row_count else 0
        if row_count == 0 or col_count == 0:
            raise ValueError(
                f"Cannot broadcast an empty value to dimension {self.dimension}"
            )

        n_rows, n_cols = self.dimension
        # Ceiling division: how many copies cover the target size.
        row_repeats = max(1, (n_rows + row_count - 1) // row_count)
        col_repeats = max(1, (n_cols + col_count - 1) // col_count)

        # Tile each row horizontally, then the rows vertically; the final
        # slicing gives every output row its own list object.
        tiled = [row * col_repeats for row in self.value] * row_repeats
        return [row[:n_cols] for row in tiled[:n_rows]]

    def update_row(self, row_index: int, row_value: list):
        """Expand ``value`` via ``to_list`` and replace row ``row_index``.

        Returns the updated ``value``, or ``None`` when ``value`` is ``None``.
        """
        if self.value is None:
            return None

        self.value = self.to_list()
        self.value[row_index] = row_value
        return self.value

    def update_column(self, column_index: int, column_value: list):
        """Expand ``value`` via ``to_list`` and set column ``column_index``.

        ``column_value[i]`` is written into row ``i``. Returns the updated
        ``value``, or ``None`` when ``value`` is ``None``.
        """
        if self.value is None:
            return None

        self.value = self.to_list()
        for i, row in enumerate(self.value):
            row[column_index] = column_value[i]
        return self.value

    def update_cell(self, row_index: int, column_index: int, cell_value: Any):
        """Expand ``value`` via ``to_list`` and set a single cell.

        Returns the updated ``value``, or ``None`` when ``value`` is ``None``.
        """
        if self.value is None:
            return None

        self.value = self.to_list()
        self.value[row_index][column_index] = cell_value
        return self.value