Coverage for src/rtflite/attributes.py: 80%
336 statements
« prev ^ index » next coverage.py v7.10.4, created at 2025-08-17 01:22 +0000
« prev ^ index » next coverage.py v7.10.4, created at 2025-08-17 01:22 +0000
1import math
2from collections.abc import MutableSequence, Sequence
3from typing import Any, Tuple
5import narwhals as nw
6import polars as pl
7from pydantic import BaseModel, ConfigDict, Field, field_validator
9from rtflite.row import (
10 BORDER_CODES,
11 FORMAT_CODES,
12 TEXT_JUSTIFICATION_CODES,
13 VERTICAL_ALIGNMENT_CODES,
14 Border,
15 Cell,
16 Row,
17 TextContent,
18 Utils,
19)
20from rtflite.services.color_service import color_service
21from rtflite.strwidth import get_string_width
def _to_nested_list(v):
    """Normalize a scalar, sequence, DataFrame, or array into a nested list.

    Conversion rules, applied in order:

    * ``None`` is returned unchanged.
    * A scalar (``int``, ``str``, ``float``, ``bool``) becomes ``[[v]]``.
    * A list containing at least one scalar becomes a single row ``[v]``.
    * A list whose items are all lists is returned as is.
    * A tuple becomes a single column, one item per row.
    * A polars DataFrame, or any DataFrame narwhals can wrap, becomes a list
      of row lists.
    * An array-like object (has ``__array__`` and ``tolist``) is converted
      with ``tolist()``; a 0-d or 1-D result is wrapped so that it follows the
      same scalar / flat-list rules as above.

    Args:
        v: The value to normalize.

    Returns:
        The nested-list form of ``v``, or ``None`` when ``v`` is ``None``.
        DataFrame-like objects that cannot be converted are returned unchanged.

    Raises:
        TypeError: If ``v`` is a sequence that is neither a tuple, a list
            containing scalars, nor a list of lists.
    """
    scalar_types = (int, str, float, bool)

    if v is None:
        return None

    if isinstance(v, scalar_types):
        v = [[v]]

    if isinstance(v, Sequence):
        if isinstance(v, list) and any(isinstance(item, scalar_types) for item in v):
            v = [v]
        elif isinstance(v, list) and all(isinstance(item, list) for item in v):
            pass  # Already a nested list.
        elif isinstance(v, tuple):
            v = [[item] for item in v]
        else:
            raise TypeError("Invalid value type. Must be a list or tuple.")

    # DataFrame-like input: polars directly, everything else through narwhals.
    if hasattr(v, "__dataframe__") or hasattr(v, "columns"):
        if isinstance(v, pl.DataFrame):
            v = [list(row) for row in v.rows()]  # rows() yields tuples
        else:
            try:
                # narwhals' ``to_native()`` accepts no target type, so read the
                # rows from the narwhals frame itself instead of trying to
                # convert it to polars first.
                v = [list(row) for row in nw.from_native(v, eager_only=True).rows()]
            except Exception:
                # Best effort: an object that merely exposes ``columns`` but is
                # not a supported DataFrame is passed through unchanged.
                pass

    # Array-like input (e.g. numpy arrays).
    if hasattr(v, "__array__") and hasattr(v, "tolist"):
        converted = v.tolist()
        if not isinstance(converted, list):
            # 0-d array: tolist() returns a bare scalar.
            return [[converted]]
        if any(isinstance(item, scalar_types) for item in converted):
            # 1-D array: treat it as a single row, like a flat list.
            return [converted]
        return converted

    return v
# Maps each ``TextContent`` keyword argument to the attribute that supplies it.
_TEXT_CONTENT_KWARGS = {
    "font": "text_font",
    "size": "text_font_size",
    "format": "text_format",
    "color": "text_color",
    "background_color": "text_background_color",
    "justification": "text_justification",
    "indent_first": "text_indent_first",
    "indent_left": "text_indent_left",
    "indent_right": "text_indent_right",
    "space": "text_space",
    "space_before": "text_space_before",
    "space_after": "text_space_after",
    "convert": "text_convert",
    "hyphenation": "text_hyphenation",
}


def _leaf_values(values):
    """Yield the items of a flat list, or of every row of a list of lists."""
    if values and isinstance(values[0], list):
        for row in values:
            yield from row
    else:
        yield from values


def _ensure_valid_color(color, label):
    """Raise ``ValueError`` if ``color`` is non-empty and not a known color.

    ``label`` names the attribute in the error message (e.g. "text color").
    """
    # Allow empty string for "no color"
    if color and not color_service.validate_color(color):
        suggestions = color_service.get_color_suggestions(color, 3)
        suggestion_text = (
            f" Did you mean: {', '.join(suggestions)}?" if suggestions else ""
        )
        raise ValueError(f"Invalid {label}: '{color}'.{suggestion_text}")


class TextAttributes(BaseModel):
    """Base class for text-related attributes in RTF components.

    Each attribute may be given as a scalar, a flat list, or a list of lists.
    Scalars are wrapped in a one-element list before validation.
    """

    text_font: list[int] | list[list[int]] | None = Field(
        default=None, description="Font number for text"
    )

    @field_validator("text_font", mode="after")
    @classmethod
    def validate_text_font(cls, v):
        """Reject font numbers not listed by ``Utils._font_type()``."""
        if v is None:
            return v

        valid_fonts = Utils._font_type()["type"]
        for font in _leaf_values(v):
            if font not in valid_fonts:
                raise ValueError(f"Invalid font number: {font}")
        return v

    text_format: list[str] | list[list[str]] | None = Field(
        default=None,
        description="Text formatting (e.g. 'b' for 'bold', 'i' for'italic')",
    )

    @field_validator("text_format", mode="after")
    @classmethod
    def validate_text_format(cls, v):
        """Reject format strings containing a character not in ``FORMAT_CODES``."""
        if v is None:
            return v

        for format_spec in _leaf_values(v):
            # Each entry is a string of one-character format codes.
            for fmt in format_spec:
                if fmt not in FORMAT_CODES:
                    raise ValueError(f"Invalid text format: {fmt}")
        return v

    text_font_size: list[float] | list[list[float]] | None = Field(
        default=None, description="Font size in points"
    )

    @field_validator("text_font_size", mode="after")
    @classmethod
    def validate_text_font_size(cls, v):
        """Reject font sizes that are zero or negative."""
        if v is None:
            return v

        for size in _leaf_values(v):
            if size <= 0:
                raise ValueError(f"Invalid font size: {size}")
        return v

    text_color: list[str] | list[list[str]] | None = Field(
        default=None, description="Text color name or RGB value"
    )

    @field_validator("text_color", mode="after")
    @classmethod
    def validate_text_color(cls, v):
        """Reject non-empty colors that ``color_service`` does not accept."""
        if v is None:
            return v

        for color in _leaf_values(v):
            _ensure_valid_color(color, "text color")
        return v

    text_background_color: list[str] | list[list[str]] | None = Field(
        default=None, description="Background color name or RGB value"
    )

    @field_validator("text_background_color", mode="after")
    @classmethod
    def validate_text_background_color(cls, v):
        """Reject non-empty background colors ``color_service`` does not accept."""
        if v is None:
            return v

        for color in _leaf_values(v):
            _ensure_valid_color(color, "text background color")
        return v

    text_justification: list[str] | list[list[str]] | None = Field(
        default=None,
        description="Text alignment ('l'=left, 'c'=center, 'r'=right, 'j'=justify)",
    )

    @field_validator("text_justification", mode="after")
    @classmethod
    def validate_text_justification(cls, v):
        """Reject justification codes not in ``TEXT_JUSTIFICATION_CODES``."""
        if v is None:
            return v

        for justification in _leaf_values(v):
            if justification not in TEXT_JUSTIFICATION_CODES:
                raise ValueError(f"Invalid text justification: {justification}")
        return v

    text_indent_first: list[int] | list[list[int]] | None = Field(
        default=None, description="First line indent in twips"
    )
    text_indent_left: list[int] | list[list[int]] | None = Field(
        default=None, description="Left indent in twips"
    )
    text_indent_right: list[int] | list[list[int]] | None = Field(
        default=None, description="Right indent in twips"
    )
    text_space: list[int] | list[list[int]] | None = Field(
        default=None, description="Line spacing multiplier"
    )
    text_space_before: list[int] | list[list[int]] | None = Field(
        default=None, description="Space before paragraph in twips"
    )
    text_space_after: list[int] | list[list[int]] | None = Field(
        default=None, description="Space after paragraph in twips"
    )
    text_hyphenation: list[bool] | list[list[bool]] | None = Field(
        default=None, description="Enable automatic hyphenation"
    )
    text_convert: list[bool] | list[list[bool]] | None = Field(
        default=[True], description="Convert LaTeX commands to Unicode characters"
    )

    @field_validator(
        "text_font",
        "text_format",
        "text_font_size",
        "text_color",
        "text_background_color",
        "text_justification",
        "text_indent_first",
        "text_indent_left",
        "text_indent_right",
        "text_space",
        "text_space_before",
        "text_space_after",
        "text_hyphenation",
        "text_convert",
        mode="before",
    )
    @classmethod
    def convert_to_list(cls, v):
        """Convert single values to lists before validation."""
        if v is not None and isinstance(v, (int, str, float, bool)):
            return [v]
        return v

    def _encode_text(self, text: Sequence[str], method: str) -> str | list[str]:
        """Render ``text`` as RTF using this object's text attributes.

        Args:
            text: The text items to render; item ``i`` uses row ``i`` of each
                broadcast attribute.
            method: ``"paragraph"`` returns one RTF string per item.
                ``"line"`` joins the plain rendering of every item with
                ``\\line`` and returns a single string formatted with the
                attributes of the last item.

        Returns:
            A list of strings for ``"paragraph"``, a single string for ``"line"``.

        Raises:
            ValueError: If ``method`` is neither ``"paragraph"`` nor ``"line"``.
        """
        n_items = len(text)
        dim = (n_items, 1)

        # Build each broadcast wrapper once instead of once per item.
        broadcast = {
            kwarg: BroadcastValue(value=getattr(self, attr_name), dimension=dim)
            for kwarg, attr_name in _TEXT_CONTENT_KWARGS.items()
        }

        def build_content(content, row_idx):
            """Create a TextContent for ``content`` with the attributes of ``row_idx``."""
            return TextContent(
                text=content,
                **{kwarg: value.iloc(row_idx, 0) for kwarg, value in broadcast.items()},
            )

        text_components = [build_content(str(text[i]), i) for i in range(n_items)]

        if method == "paragraph":
            return [
                text_component._as_rtf(method="paragraph")
                for text_component in text_components
            ]

        if method == "line":
            line = "\\line".join(
                text_component._as_rtf(method="plain")
                for text_component in text_components
            )
            # The paragraph formatting comes from the last item. Clamp to 0 so
            # an empty ``text`` does not depend on an unbound loop variable.
            last_idx = max(n_items - 1, 0)
            return build_content(str(line), last_idx)._as_rtf(
                method="paragraph_format"
            )

        raise ValueError(f"Invalid method: {method}")

    def calculate_lines(
        self, text: str, available_width: float, row_idx: int = 0, col_idx: int = 0
    ) -> int:
        """
        Calculate number of lines needed for text given available width.

        Args:
            text: Text content to measure
            available_width: Available width in inches
            row_idx: Row index for attribute lookup (default: 0)
            col_idx: Column index for attribute lookup (default: 0)

        Returns:
            Number of lines needed (minimum 1)

        Raises:
            ValueError: If ``text_font`` or ``text_font_size`` is not set.
        """
        if not text or available_width <= 0:
            return 1

        # Create a dummy dimension large enough to contain the requested cell
        dim = (max(1, row_idx + 1), max(1, col_idx + 1))

        if self.text_font is None:
            raise ValueError("text_font must be set to calculate lines")
        font_number = BroadcastValue(value=self.text_font, dimension=dim).iloc(
            row_idx, col_idx
        )

        if self.text_font_size is None:
            raise ValueError("text_font_size must be set to calculate lines")
        font_size = BroadcastValue(value=self.text_font_size, dimension=dim).iloc(
            row_idx, col_idx
        )

        total_width = get_string_width(
            text=text, font=font_number, font_size=font_size, unit="in"
        )

        # Simple approximation: divide total width by available width and round up
        return max(1, int(math.ceil(total_width / available_width)))
class TableAttributes(TextAttributes):
    """Base class for table-related attributes in RTF components.

    Border, cell, and inherited text attributes are converted to nested lists
    (rows of columns) before validation via ``_to_nested_list``.
    """

    col_rel_width: list[float] | None = Field(
        default=None, description="Relative widths of table columns"
    )

    border_left: list[list[str]] = Field(
        default=[[""]], description="Left border style"
    )
    border_right: list[list[str]] = Field(
        default=[[""]], description="Right border style"
    )
    border_top: list[list[str]] = Field(default=[[""]], description="Top border style")
    border_bottom: list[list[str]] = Field(
        default=[[""]], description="Bottom border style"
    )
    border_first: list[list[str]] = Field(
        default=[[""]], description="First row border style"
    )
    border_last: list[list[str]] = Field(
        default=[[""]], description="Last row border style"
    )
    border_color_left: list[list[str]] = Field(
        default=[[""]], description="Left border color"
    )
    border_color_right: list[list[str]] = Field(
        default=[[""]], description="Right border color"
    )
    border_color_top: list[list[str]] = Field(
        default=[[""]], description="Top border color"
    )
    border_color_bottom: list[list[str]] = Field(
        default=[[""]], description="Bottom border color"
    )
    border_color_first: list[list[str]] = Field(
        default=[[""]], description="First row border color"
    )
    border_color_last: list[list[str]] = Field(
        default=[[""]], description="Last row border color"
    )

    @field_validator(
        "border_color_left",
        "border_color_right",
        "border_color_top",
        "border_color_bottom",
        "border_color_first",
        "border_color_last",
        mode="after",
    )
    @classmethod
    def validate_border_colors(cls, v):
        """Reject non-empty border colors that ``color_service`` does not accept."""
        if v is None:
            return v

        for row in v:
            for color in row:
                # Allow empty string for no color
                if color and not color_service.validate_color(color):
                    suggestions = color_service.get_color_suggestions(color, 3)
                    suggestion_text = (
                        f" Did you mean: {', '.join(suggestions)}?"
                        if suggestions
                        else ""
                    )
                    raise ValueError(
                        f"Invalid border color: '{color}'.{suggestion_text}"
                    )
        return v

    border_width: list[list[int]] = Field(
        default=[[15]], description="Border width in twips"
    )
    cell_height: list[list[float]] = Field(
        default=[[0.15]], description="Cell height in inches"
    )
    cell_justification: list[list[str]] = Field(
        default=[["l"]],
        description="Cell horizontal alignment ('l'=left, 'c'=center, 'r'=right, 'j'=justify)",
    )

    cell_vertical_justification: list[list[str]] = Field(
        default=[["center"]],
        description="Cell vertical alignment ('top', 'center', 'bottom')",
    )

    @field_validator("cell_vertical_justification", mode="after")
    @classmethod
    def validate_cell_vertical_justification(cls, v):
        """Reject vertical alignments not in ``VERTICAL_ALIGNMENT_CODES``."""
        if v is None:
            return v

        for row in v:
            for justification in row:
                if justification not in VERTICAL_ALIGNMENT_CODES:
                    raise ValueError(
                        f"Invalid cell vertical justification: {justification}"
                    )
        return v

    cell_nrow: list[list[int]] = Field(
        default=[[1]], description="Number of rows per cell"
    )

    @field_validator("col_rel_width", mode="before")
    @classmethod
    def convert_col_rel_width_to_list(cls, v):
        """Wrap a scalar ``col_rel_width`` in a one-element list."""
        if v is not None and isinstance(v, (int, str, float, bool)):
            return [v]
        return v

    @field_validator(
        "border_left",
        "border_right",
        "border_top",
        "border_bottom",
        "border_first",
        "border_last",
        "border_color_left",
        "border_color_right",
        "border_color_top",
        "border_color_bottom",
        "border_color_first",
        "border_color_last",
        "border_width",
        "cell_height",
        "cell_justification",
        "cell_vertical_justification",
        "cell_nrow",
        "text_font",
        "text_format",
        "text_font_size",
        "text_color",
        "text_background_color",
        "text_justification",
        "text_indent_first",
        "text_indent_left",
        "text_indent_right",
        "text_space",
        "text_space_before",
        "text_space_after",
        "text_hyphenation",
        "text_convert",
        mode="before",
    )
    @classmethod
    def convert_to_nested_list(cls, v):
        """Normalize the raw value to a nested list before validation."""
        return _to_nested_list(v)

    @field_validator(
        "col_rel_width", "border_width", "cell_height", "cell_nrow", mode="after"
    )
    @classmethod
    def validate_positive_value(cls, v, info):
        """Reject values that are zero or negative.

        ``info`` is the pydantic validation info object; its ``field_name``
        names the offending field in the error message.
        """
        if v is None or len(v) == 0:
            return v

        if isinstance(v[0], (list, tuple)):
            # 2D array
            has_non_positive = any(val <= 0 for row in v for val in row)
        else:
            # 1D array
            has_non_positive = any(val <= 0 for val in v)

        if has_non_positive:
            raise ValueError(f"{info.field_name.capitalize()} must be positive")
        return v

    @field_validator("cell_justification", mode="after")
    @classmethod
    def validate_cell_justification(cls, v):
        """Reject cell alignments not in ``TEXT_JUSTIFICATION_CODES``."""
        if v is None:
            return v

        for row in v:
            for justification in row:
                if justification not in TEXT_JUSTIFICATION_CODES:
                    raise ValueError(f"Invalid cell justification: {justification}")
        return v

    @field_validator(
        "border_left",
        "border_right",
        "border_top",
        "border_bottom",
        "border_first",
        "border_last",
        mode="after",
    )
    @classmethod
    def validate_border(cls, v, info):
        """Validate that all border styles are valid.

        ``info`` is the pydantic validation info object; its ``field_name``
        names the offending field in the error message.
        """
        if v is None:
            return v

        for row in v:
            for border in row:
                if border not in BORDER_CODES:
                    field_name = info.field_name.capitalize()
                    raise ValueError(
                        f"{field_name} with invalid border style: {border}"
                    )

        return v

    def _get_section_attributes(self, indices) -> dict:
        """Collect every text_/col_/border_/cell_ attribute at the given cells.

        Args:
            indices: Iterable of ``(row, col)`` pairs.

        Returns:
            A dict mapping each non-``None`` attribute name to the list of its
            broadcast values, one per ``(row, col)`` pair.
        """
        prefixes = ("text_", "col_", "border_", "cell_")
        attrs = {}
        for name in dir(self):
            if not name.startswith(prefixes):
                continue
            value = getattr(self, name)
            if not callable(value):
                attrs[name] = value

        # Materialize once so a one-shot iterator serves every attribute.
        positions = list(indices)

        section = {}
        for name, value in attrs.items():
            if value is None:
                continue
            broadcast = BroadcastValue(value=value, dimension=None)
            section[name] = [broadcast.iloc(row, col) for row, col in positions]
        return section

    def _encode(
        self, df: pl.DataFrame, col_widths: Sequence[float]
    ) -> MutableSequence[str]:
        """Render every row of ``df`` as RTF table rows.

        Args:
            df: The table data; its shape is the broadcast dimension.
            col_widths: Width of each column, indexed by column position.

        Returns:
            The RTF strings produced by each row's ``_as_rtf()``, concatenated
            in row order.
        """
        dim = df.shape
        n_rows, n_cols = dim

        # Attributes do not change while encoding, so wrap each one only once.
        broadcast_cache = {}

        def get_broadcast_value(attr_name, row_idx, col_idx=0):
            """Return the broadcast value of ``attr_name`` at the given cell."""
            if attr_name not in broadcast_cache:
                broadcast_cache[attr_name] = BroadcastValue(
                    value=getattr(self, attr_name), dimension=dim
                )
            return broadcast_cache[attr_name].iloc(row_idx, col_idx)

        if self.cell_nrow is None:
            # Estimate the wrapped line count of every cell. Nulls are measured
            # as empty text, matching how they are displayed below.
            self.cell_nrow = [
                [
                    self.calculate_lines(
                        text="" if value is None else str(value),
                        available_width=col_widths[j],
                        row_idx=i,
                        col_idx=j,
                    )
                    for j, value in enumerate(df.row(i))
                ]
                for i in range(n_rows)
            ]

        text_kwargs = (
            ("font", "text_font"),
            ("size", "text_font_size"),
            ("format", "text_format"),
            ("color", "text_color"),
            ("background_color", "text_background_color"),
            ("justification", "text_justification"),
            ("indent_first", "text_indent_first"),
            ("indent_left", "text_indent_left"),
            ("indent_right", "text_indent_right"),
            ("space", "text_space"),
            ("space_before", "text_space_before"),
            ("space_after", "text_space_after"),
            ("convert", "text_convert"),
            ("hyphenation", "text_hyphenation"),
        )

        rows: MutableSequence[str] = []
        for i in range(n_rows):
            row = df.row(i)
            cells = []

            for j in range(n_cols):
                # Only the last column carries a right border.
                if j == n_cols - 1:
                    border_right = Border(
                        style=get_broadcast_value("border_right", i, j)
                    )
                else:
                    border_right = None

                # Handle null values - display as empty string instead of "None"
                raw_value = row[j]
                cell_value = "" if raw_value is None else str(raw_value)

                cell = Cell(
                    text=TextContent(
                        text=cell_value,
                        **{
                            kwarg: get_broadcast_value(attr_name, i, j)
                            for kwarg, attr_name in text_kwargs
                        },
                    ),
                    width=col_widths[j],
                    border_left=Border(style=get_broadcast_value("border_left", i, j)),
                    border_right=border_right,
                    border_top=Border(style=get_broadcast_value("border_top", i, j)),
                    border_bottom=Border(
                        style=get_broadcast_value("border_bottom", i, j)
                    ),
                    vertical_justification=get_broadcast_value(
                        "cell_vertical_justification", i, j
                    ),
                )
                cells.append(cell)

            rtf_row = Row(
                row_cells=cells,
                justification=get_broadcast_value("cell_justification", i, 0),
                height=get_broadcast_value("cell_height", i, 0),
            )
            rows.extend(rtf_row._as_rtf())

        return rows
class BroadcastValue(BaseModel):
    """A value that is repeated (broadcast) to fill a table of a given size.

    ``value`` is normalized with ``_to_nested_list`` on construction. Lookups
    wrap around with modulo arithmetic, so a smaller value repeats across a
    larger table.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    value: Any = Field(
        ...,
        description="The value of the table, can be various types including DataFrame.",
    )

    dimension: Tuple[int, int] | None = Field(
        None, description="Dimensions of the table (rows, columns)"
    )

    @field_validator("value", mode="before")
    @classmethod
    def convert_value(cls, v):
        """Normalize the raw value to a nested list."""
        return _to_nested_list(v)

    @field_validator("dimension")
    @classmethod
    def validate_dimension(cls, v):
        """Require ``dimension`` to be ``None`` or ``(rows >= 0, cols > 0)``."""
        if v is None:
            return v

        if not isinstance(v, tuple) or len(v) != 2:
            raise TypeError("dimension must be a tuple of (rows, columns)")

        rows, cols = v
        if not isinstance(rows, int) or not isinstance(cols, int):
            raise TypeError("dimension values must be integers")

        if rows < 0 or cols <= 0:
            raise ValueError("rows must be non-negative and cols must be positive")

        return v

    def iloc(self, row_index: int, column_index: int) -> Any:
        """Return the broadcast value at ``(row_index, column_index)``.

        The row index wraps around the number of rows and the column index
        wraps around the length of the first row.

        Returns:
            The selected item, or ``None`` when ``value`` is ``None``.

        Raises:
            ValueError: If ``value`` has no rows, its first row is empty, or
                the wrapped indices fall outside a (ragged) row.
        """
        if self.value is None:
            return None

        row_count = len(self.value)
        if row_count == 0:
            raise ValueError("Invalid DataFrame index or slice: value has no rows")

        col_count = len(self.value[0])
        if col_count == 0:
            raise ValueError("Invalid DataFrame index or slice: value has no columns")

        try:
            return self.value[row_index % row_count][column_index % col_count]
        except IndexError as e:
            raise ValueError(f"Invalid DataFrame index or slice: {e}") from e

    def to_list(self) -> list | None:
        """Return ``value`` expanded to ``dimension``.

        Returns:
            ``None`` when ``value`` is ``None``; ``value`` itself when
            ``dimension`` is ``None``; otherwise a new nested list with
            ``dimension[0]`` rows and ``dimension[1]`` columns, built by
            repeating ``value`` and truncating.

        Raises:
            ValueError: If ``value`` is empty but at least one row is requested.
        """
        if self.value is None:
            return None

        if self.dimension is None:
            return self.value

        target_rows, target_cols = self.dimension
        row_count = len(self.value)
        col_count = len(self.value[0]) if row_count else 0

        if row_count == 0 or col_count == 0:
            # Nothing to repeat: only a zero-row result can be satisfied.
            if target_rows == 0:
                return []
            raise ValueError("Cannot broadcast an empty value")

        # Ceiling division: how many copies are needed to cover the target.
        row_repeats = max(1, (target_rows + row_count - 1) // row_count)
        col_repeats = max(1, (target_cols + col_count - 1) // col_count)

        tiled = [row * col_repeats for row in self.value] * row_repeats
        # Slicing each row creates fresh lists, so repeated rows do not alias.
        return [row[:target_cols] for row in tiled[:target_rows]]

    def update_row(self, row_index: int, row_value: list):
        """Expand ``value`` to ``dimension`` and replace one row.

        Returns the updated nested list, or ``None`` when ``value`` is ``None``.
        """
        if self.value is None:
            return None

        self.value = self.to_list()
        self.value[row_index] = row_value
        return self.value

    def update_column(self, column_index: int, column_value: list):
        """Expand ``value`` to ``dimension`` and replace one column.

        ``column_value[i]`` is written into row ``i``. Returns the updated
        nested list, or ``None`` when ``value`` is ``None``.
        """
        if self.value is None:
            return None

        self.value = self.to_list()
        for i, row in enumerate(self.value):
            row[column_index] = column_value[i]
        return self.value

    def update_cell(self, row_index: int, column_index: int, cell_value: Any):
        """Expand ``value`` to ``dimension`` and replace one cell.

        Returns the updated nested list, or ``None`` when ``value`` is ``None``.
        """
        if self.value is None:
            return None

        self.value = self.to_list()
        self.value[row_index][column_index] = cell_value
        return self.value