Coverage for src/rtflite/attributes.py: 78%
255 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-07 05:03 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-07 05:03 +0000
from collections.abc import MutableSequence, Sequence
from typing import Any, Tuple

import numpy as np
import pandas as pd
import polars as pl
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator

from rtflite.row import (
    BORDER_CODES,
    FORMAT_CODES,
    TEXT_JUSTIFICATION_CODES,
    VERTICAL_ALIGNMENT_CODES,
    Border,
    Cell,
    Row,
    TextContent,
    Utils,
)
from rtflite.strwidth import get_string_width
23def _to_nested_list(v):
24 if v is None:
25 return None
27 if isinstance(v, (int, str, float, bool)):
28 v = [[v]]
30 if isinstance(v, Sequence):
31 if isinstance(v, list) and any(
32 isinstance(item, (str, int, float, bool)) for item in v
33 ):
34 v = [v]
35 elif isinstance(v, list) and all(isinstance(item, list) for item in v):
36 v = v
37 elif isinstance(v, tuple):
38 v = [[item] for item in v]
39 else:
40 raise TypeError("Invalid value type. Must be a list or tuple.")
42 if isinstance(v, pd.DataFrame):
43 v = v.values.tolist()
45 if isinstance(v, pl.DataFrame):
46 v = v.to_pandas().values.tolist()
48 if isinstance(v, np.ndarray):
49 v = v.tolist()
51 return v
class TextAttributes(BaseModel):
    """Base class for text-related attributes in RTF components.

    Every ``text_*`` field holds a list of values (or ``None``). A single
    scalar passed for any of these fields is wrapped into a one-element list
    by ``convert_to_list`` before the type and value validators run.
    """

    text_font: list[int] | None = Field(
        default=None, description="Font number for text"
    )

    @field_validator("text_font", mode="after")
    def validate_text_font(cls, v):
        """Reject font numbers not present in ``Utils._font_type()["type"]``."""
        if v is None:
            return v

        # Look the font table up once rather than once per font.
        valid_fonts = Utils._font_type()["type"]
        for font in v:
            if font not in valid_fonts:
                raise ValueError(f"Invalid font number: {font}")
        return v

    text_format: list[str] | None = Field(
        default=None,
        description="Text formatting (e.g. 'b' for 'bold', 'i' for'italic')",
    )

    @field_validator("text_format", mode="after")
    def validate_text_format(cls, v):
        """Check every character of every entry against ``FORMAT_CODES``.

        Characters are checked individually, so one entry may combine several
        format codes.
        """
        if v is None:
            return v

        for format_string in v:
            for fmt in format_string:
                if fmt not in FORMAT_CODES:
                    raise ValueError(f"Invalid text format: {fmt}")
        return v

    text_font_size: list[float] | None = Field(
        default=None, description="Font size in points"
    )

    @field_validator("text_font_size", mode="after")
    def validate_text_font_size(cls, v):
        """Reject zero or negative font sizes."""
        if v is None:
            return v

        for size in v:
            if size <= 0:
                raise ValueError(f"Invalid font size: {size}")
        return v

    text_color: list[str] | None = Field(
        default=None, description="Text color name or RGB value"
    )
    text_background_color: list[str] | None = Field(
        default=None, description="Background color name or RGB value"
    )
    text_justification: list[str] | None = Field(
        default=None,
        description="Text alignment ('l'=left, 'c'=center, 'r'=right, 'j'=justify)",
    )

    @field_validator("text_justification", mode="after")
    def validate_text_justification(cls, v):
        """Reject justification codes not in ``TEXT_JUSTIFICATION_CODES``."""
        if v is None:
            return v

        for justification in v:
            if justification not in TEXT_JUSTIFICATION_CODES:
                raise ValueError(f"Invalid text justification: {justification}")
        return v

    text_indent_first: list[int] | None = Field(
        default=None, description="First line indent in twips"
    )
    text_indent_left: list[int] | None = Field(
        default=None, description="Left indent in twips"
    )
    text_indent_right: list[int] | None = Field(
        default=None, description="Right indent in twips"
    )
    text_space: list[int] | None = Field(
        default=None, description="Line spacing multiplier"
    )
    text_space_before: list[int] | None = Field(
        default=None, description="Space before paragraph in twips"
    )
    text_space_after: list[int] | None = Field(
        default=None, description="Space after paragraph in twips"
    )
    text_hyphenation: list[bool] | None = Field(
        default=None, description="Enable automatic hyphenation"
    )
    text_convert: list[bool] | None = Field(
        default=None, description="Convert special characters to RTF"
    )

    @field_validator(
        "text_font",
        "text_format",
        "text_font_size",
        "text_color",
        "text_background_color",
        "text_justification",
        "text_indent_first",
        "text_indent_left",
        "text_indent_right",
        "text_space",
        "text_space_before",
        "text_space_after",
        "text_hyphenation",
        "text_convert",
        mode="before",
    )
    def convert_to_list(cls, v):
        """Convert single values to lists before validation."""
        if v is not None and isinstance(v, (int, str, float, bool)):
            return [v]
        return v

    def _encode(self, text: Sequence[str], method: str) -> str | list[str]:
        """Render ``text`` as RTF using the attributes of this model.

        Args:
            text: Lines of text. Line ``i`` uses the attribute values that
                ``BroadcastValue`` yields at row ``i``, column 0, for a
                ``len(text)`` x 1 table.
            method: ``"paragraph"`` returns one RTF paragraph per line;
                ``"line"`` joins the lines with ``\\line`` into a single
                paragraph.

        Returns:
            A list of strings for ``"paragraph"``; a single string for
            ``"line"``.

        Raises:
            ValueError: If ``method`` is unknown, or if ``method="line"`` is
                requested with empty ``text``.
        """
        dim = (len(text), 1)

        # Build each attribute's BroadcastValue at most once per call. They
        # are created lazily, so empty ``text`` never builds one (a zero-row
        # dimension fails BroadcastValue validation).
        broadcasts: dict[str, BroadcastValue] = {}

        def get_broadcast_value(attr_name, row_idx, col_idx=0):
            """Return attribute ``attr_name`` broadcast to ``(row_idx, col_idx)``."""
            if attr_name not in broadcasts:
                broadcasts[attr_name] = BroadcastValue(
                    value=getattr(self, attr_name), dimension=dim
                )
            return broadcasts[attr_name].iloc(row_idx, col_idx)

        # NOTE(review): a flat-list attribute such as ["c", "l"] is stored by
        # BroadcastValue as a single row, so iloc(i, 0) returns its first
        # element for every line. Confirm whether per-line values are intended.
        def build_text_content(content, row_idx):
            """Create a TextContent for ``content`` with the attributes of ``row_idx``."""
            return TextContent(
                text=str(content),
                font=get_broadcast_value("text_font", row_idx),
                size=get_broadcast_value("text_font_size", row_idx),
                format=get_broadcast_value("text_format", row_idx),
                color=get_broadcast_value("text_color", row_idx),
                background_color=get_broadcast_value("text_background_color", row_idx),
                justification=get_broadcast_value("text_justification", row_idx),
                indent_first=get_broadcast_value("text_indent_first", row_idx),
                indent_left=get_broadcast_value("text_indent_left", row_idx),
                indent_right=get_broadcast_value("text_indent_right", row_idx),
                space=get_broadcast_value("text_space", row_idx),
                space_before=get_broadcast_value("text_space_before", row_idx),
                space_after=get_broadcast_value("text_space_after", row_idx),
                convert=get_broadcast_value("text_convert", row_idx),
                hyphenation=get_broadcast_value("text_hyphenation", row_idx),
            )

        text_components = [build_text_content(text[i], i) for i in range(dim[0])]

        if method == "paragraph":
            return [
                text_component._as_rtf(method="paragraph")
                for text_component in text_components
            ]

        if method == "line":
            if not text_components:
                raise ValueError("method='line' requires at least one line of text")
            line = "\\line".join(
                text_component._as_rtf(method="plain")
                for text_component in text_components
            )
            # Paragraph-level formatting uses the attributes of the last line.
            return build_text_content(line, dim[0] - 1)._as_rtf(
                method="paragraph_format"
            )

        raise ValueError(f"Invalid method: {method}")
class TableAttributes(TextAttributes):
    """Base class for table-related attributes in RTF components.

    ``col_rel_width`` is a flat list. The ``border_*``, ``cell_*`` and
    ``border_width`` fields are nested lists (rows of values) that are
    broadcast over the table by ``BroadcastValue`` when encoding.
    """

    col_rel_width: list[float] | None = Field(
        default=None, description="Relative widths of table columns"
    )

    border_left: list[list[str]] = Field(
        default=[[""]], description="Left border style"
    )
    border_right: list[list[str]] = Field(
        default=[[""]], description="Right border style"
    )
    border_top: list[list[str]] = Field(default=[[""]], description="Top border style")
    border_bottom: list[list[str]] = Field(
        default=[[""]], description="Bottom border style"
    )
    border_first: list[list[str]] = Field(
        default=[[""]], description="First row border style"
    )
    border_last: list[list[str]] = Field(
        default=[[""]], description="Last row border style"
    )
    border_color_left: list[list[str]] = Field(
        default=[[""]], description="Left border color"
    )
    border_color_right: list[list[str]] = Field(
        default=[[""]], description="Right border color"
    )
    border_color_top: list[list[str]] = Field(
        default=[[""]], description="Top border color"
    )
    border_color_bottom: list[list[str]] = Field(
        default=[[""]], description="Bottom border color"
    )
    border_color_first: list[list[str]] = Field(
        default=[[""]], description="First row border color"
    )
    border_color_last: list[list[str]] = Field(
        default=[[""]], description="Last row border color"
    )
    border_width: list[list[int]] = Field(
        default=[[15]], description="Border width in twips"
    )
    cell_height: list[list[float]] = Field(
        default=[[0.15]], description="Cell height in inches"
    )
    cell_justification: list[list[str]] = Field(
        default=[["l"]],
        description="Cell horizontal alignment ('l'=left, 'c'=center, 'r'=right, 'j'=justify)",
    )

    cell_vertical_justification: list[list[str]] = Field(
        default=[["center"]],
        description="Cell vertical alignment ('top', 'center', 'bottom')",
    )

    @field_validator("cell_vertical_justification", mode="after")
    def validate_cell_vertical_justification(cls, v):
        """Reject codes not in ``VERTICAL_ALIGNMENT_CODES``."""
        if v is None:
            return v

        for row in v:
            for justification in row:
                if justification not in VERTICAL_ALIGNMENT_CODES:
                    raise ValueError(
                        f"Invalid cell vertical justification: {justification}"
                    )
        return v

    cell_nrow: list[list[int]] = Field(
        default=[[1]], description="Number of rows per cell"
    )

    @field_validator("col_rel_width", mode="before")
    def convert_col_rel_width_to_list(cls, v):
        """Wrap a scalar column width into a one-element list.

        Deliberately not named ``convert_to_list``: pydantic collects
        validators by attribute name, so reusing the parent's name would
        replace ``TextAttributes.convert_to_list`` and stop scalar ``text_*``
        values from being wrapped on this class.
        """
        if v is not None and isinstance(v, (int, str, float, bool)):
            return [v]
        return v

    @field_validator(
        "border_left",
        "border_right",
        "border_top",
        "border_bottom",
        "border_first",
        "border_last",
        "border_color_left",
        "border_color_right",
        "border_color_top",
        "border_color_bottom",
        "border_color_first",
        "border_color_last",
        "border_width",
        "cell_height",
        "cell_justification",
        "cell_vertical_justification",
        "cell_nrow",
        mode="before",
    )
    def convert_to_nested_list(cls, v):
        """Normalize scalars, sequences, arrays and DataFrames into rows."""
        return _to_nested_list(v)

    @field_validator(
        "col_rel_width", "border_width", "cell_height", "cell_nrow", mode="after"
    )
    def validate_positive_value(cls, v, info: ValidationInfo):
        """Reject zero or negative entries in flat or nested numeric lists.

        The field name comes from ``info.field_name``; the model class has no
        ``__field_name__`` attribute.
        """
        if v is None:
            return v

        # Flatten one level by hand so ragged nested lists are checked too
        # (numpy refuses to build an array from ragged rows).
        values = (
            item
            for entry in v
            for item in (entry if isinstance(entry, list) else (entry,))
        )
        if any(item <= 0 for item in values):
            raise ValueError(f"{info.field_name.capitalize()} must be positive")
        return v

    @field_validator("cell_justification", mode="after")
    def validate_cell_justification(cls, v):
        """Reject codes not in ``TEXT_JUSTIFICATION_CODES``."""
        if v is None:
            return v

        for row in v:
            for justification in row:
                if justification not in TEXT_JUSTIFICATION_CODES:
                    raise ValueError(f"Invalid cell justification: {justification}")
        return v

    @field_validator(
        "border_left",
        "border_right",
        "border_top",
        "border_bottom",
        "border_first",
        "border_last",
        mode="after",
    )
    def validate_border(cls, v, info: ValidationInfo):
        """Validate that all border styles are valid."""
        if v is None:
            return v

        for row in v:
            for border in row:
                if border not in BORDER_CODES:
                    field_name = info.field_name.capitalize()
                    raise ValueError(
                        f"{field_name} with invalid border style: {border}"
                    )

        return v

    def _get_section_attributes(self, indices) -> dict:
        """Collect every text_/col_/border_/cell_ attribute at the given positions.

        Args:
            indices: Iterable of ``(row, col)`` pairs.

        Returns:
            A dict mapping each non-callable attribute whose name starts with
            ``text_``, ``col_``, ``border_`` or ``cell_`` to the list of its
            broadcast values at ``indices``.
        """
        attrs = {}
        for attr in dir(self):
            if attr.startswith(("text_", "col_", "border_", "cell_")):
                value = getattr(self, attr)
                if not callable(value):
                    attrs[attr] = value

        # Broadcast each attribute once, then read it at every index.
        section = {}
        for attr, val in attrs.items():
            broadcast = BroadcastValue(value=val)
            section[attr] = [broadcast.iloc(row, col) for row, col in indices]
        return section

    def _encode(
        self, df: pd.DataFrame, col_widths: Sequence[float]
    ) -> MutableSequence[str]:
        """Encode ``df`` as RTF table rows.

        Args:
            df: Table content. Cell ``(i, j)`` is rendered from ``df.iloc[i]``
                and takes the attribute values broadcast to ``(i, j)``.
            col_widths: Per-column width passed to ``Cell(width=...)``.

        Returns:
            The lines produced by ``Row._as_rtf()`` for every row, in order.
        """
        dim = df.shape

        # One BroadcastValue per source, built lazily on first use. Rebuilding
        # it per cell re-ran pydantic validation (and, for ``df``, a full
        # DataFrame-to-list conversion) every time.
        broadcasts: dict[str, BroadcastValue] = {}

        def broadcast(key, value):
            """Return the cached BroadcastValue for ``key``, building it from ``value``."""
            if key not in broadcasts:
                broadcasts[key] = BroadcastValue(value=value, dimension=dim)
            return broadcasts[key]

        def get_broadcast_value(attr_name, row_idx, col_idx=0):
            """Return attribute ``attr_name`` broadcast to ``(row_idx, col_idx)``."""
            return broadcast(attr_name, getattr(self, attr_name)).iloc(
                row_idx, col_idx
            )

        if self.cell_nrow is None:
            # NOTE(review): cell_nrow is declared as a non-optional nested list
            # with default [[1]], so this branch presumably only runs if None
            # was assigned after validation. TODO confirm it is still needed.
            self.cell_nrow = np.zeros(dim)

            for i in range(dim[0]):
                for j in range(dim[1]):
                    text = str(broadcast("__df__", df).iloc(i, j))
                    font_size = get_broadcast_value("text_font_size", i, j)
                    col_width = broadcast("__col_widths__", col_widths).iloc(i, j)
                    cell_text_width = get_string_width(text=text, font_size=font_size)
                    self.cell_nrow[i, j] = np.ceil(cell_text_width / col_width)

        rows: MutableSequence[str] = []
        for i in range(dim[0]):
            row = df.iloc[i]
            cells = []

            for j in range(dim[1]):
                col = df.columns[j]

                # Only the last cell of a row carries a right border.
                if j == dim[1] - 1:
                    border_right = Border(
                        style=get_broadcast_value("border_right", i, j)
                    )
                else:
                    border_right = None

                cell = Cell(
                    text=TextContent(
                        text=str(row[col]),
                        font=get_broadcast_value("text_font", i, j),
                        size=get_broadcast_value("text_font_size", i, j),
                        format=get_broadcast_value("text_format", i, j),
                        color=get_broadcast_value("text_color", i, j),
                        background_color=get_broadcast_value(
                            "text_background_color", i, j
                        ),
                        justification=get_broadcast_value("text_justification", i, j),
                        indent_first=get_broadcast_value("text_indent_first", i, j),
                        indent_left=get_broadcast_value("text_indent_left", i, j),
                        indent_right=get_broadcast_value("text_indent_right", i, j),
                        space=get_broadcast_value("text_space", i, j),
                        space_before=get_broadcast_value("text_space_before", i, j),
                        space_after=get_broadcast_value("text_space_after", i, j),
                        convert=get_broadcast_value("text_convert", i, j),
                        hyphenation=get_broadcast_value("text_hyphenation", i, j),
                    ),
                    width=col_widths[j],
                    border_left=Border(style=get_broadcast_value("border_left", i, j)),
                    border_right=border_right,
                    border_top=Border(style=get_broadcast_value("border_top", i, j)),
                    border_bottom=Border(
                        style=get_broadcast_value("border_bottom", i, j)
                    ),
                    vertical_justification=get_broadcast_value(
                        "cell_vertical_justification", i, j
                    ),
                )
                cells.append(cell)
            rtf_row = Row(
                row_cells=cells,
                justification=get_broadcast_value("cell_justification", i, 0),
                height=get_broadcast_value("cell_height", i, 0),
            )
            rows.extend(rtf_row._as_rtf())

        return rows
class BroadcastValue(BaseModel):
    """A 2-D table of values that is read with index recycling (broadcasting).

    ``value`` is normalized to a list of rows by ``_to_nested_list``.
    ``iloc`` reduces indices modulo the table size, and the ``to_*`` methods
    expand the table to ``dimension`` rows and columns.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    value: list[list[Any]] | None = Field(
        ...,
        description="The value of the table, can be various types including DataFrame.",
    )

    dimension: Tuple[int, int] | None = Field(
        None, description="Dimensions of the table (rows, columns)"
    )

    @field_validator("value", mode="before")
    def convert_value(cls, v):
        """Normalize scalars, sequences, arrays and DataFrames into rows."""
        return _to_nested_list(v)

    @field_validator("dimension")
    def validate_dimension(cls, v):
        """Require ``None`` or a tuple of two positive integers."""
        if v is None:
            return v

        if not isinstance(v, tuple) or len(v) != 2:
            raise TypeError("dimension must be a tuple of (rows, columns)")

        rows, cols = v
        if not isinstance(rows, int) or not isinstance(cols, int):
            raise TypeError("dimension values must be integers")

        if rows <= 0 or cols <= 0:
            raise ValueError("dimension values must be positive")

        return v

    def _require_non_empty(self) -> None:
        """Raise ``ValueError`` if ``value`` has no rows or an empty first row.

        Without this guard the modulo arithmetic below divides by zero.
        """
        if not self.value or not self.value[0]:
            raise ValueError("Cannot broadcast an empty value")

    def iloc(self, row_index: int, column_index: int) -> Any:
        """Return the element at ``(row_index, column_index)``, recycling indices.

        The row index is reduced modulo the number of rows and the column
        index modulo the length of the first row.

        Raises:
            ValueError: If ``value`` is empty, or if the selected row is
                shorter than the reduced column index.
        """
        if self.value is None:
            return None

        self._require_non_empty()
        try:
            return self.value[row_index % len(self.value)][
                column_index % len(self.value[0])
            ]
        except IndexError as e:
            raise ValueError(f"Invalid DataFrame index or slice: {e}") from e

    def to_list(self) -> list[list[Any]] | None:
        """Expand ``value`` to exactly ``dimension`` rows and columns.

        Rows and columns are repeated as needed and then truncated. When
        ``dimension`` is ``None``, a row-wise copy of ``value`` is returned.

        Raises:
            ValueError: If ``value`` is empty and must be expanded.
        """
        if self.value is None:
            return None

        if self.dimension is None:
            return [list(row) for row in self.value]

        self._require_non_empty()
        n_rows, n_cols = self.dimension
        row_count, col_count = len(self.value), len(self.value[0])

        row_repeats = max(1, (n_rows + row_count - 1) // row_count)
        col_repeats = max(1, (n_cols + col_count - 1) // col_count)

        value = [row * col_repeats for row in self.value] * row_repeats
        return [row[:n_cols] for row in value[:n_rows]]

    def to_numpy(self) -> np.ndarray | None:
        """Return ``to_list()`` as a numpy array, or ``None`` if there is no value."""
        if self.value is None:
            return None

        return np.array(self.to_list())

    def to_pandas(self) -> pd.DataFrame | None:
        """Return ``to_list()`` as a pandas DataFrame, or ``None`` if there is no value."""
        if self.value is None:
            return None

        return pd.DataFrame(self.to_list())

    def to_polars(self) -> pl.DataFrame | None:
        """Return ``to_list()`` as a polars DataFrame, or ``None`` if there is no value."""
        if self.value is None:
            return None

        return pl.DataFrame(self.to_list())

    def update_row(self, row_index: int, row_value: list):
        """Expand ``value`` via ``to_list()``, then replace row ``row_index``."""
        if self.value is None:
            return None

        self.value = self.to_list()
        self.value[row_index] = row_value
        return self.value

    def update_column(self, column_index: int, column_value: list):
        """Expand ``value`` via ``to_list()``, then set column ``column_index``.

        ``column_value[i]`` is written into row ``i``, so it needs at least as
        many entries as there are rows.
        """
        if self.value is None:
            return None

        self.value = self.to_list()
        for i, row in enumerate(self.value):
            row[column_index] = column_value[i]
        return self.value

    def update_cell(self, row_index: int, column_index: int, cell_value: Any):
        """Expand ``value`` via ``to_list()``, then set one cell."""
        if self.value is None:
            return None

        self.value = self.to_list()
        self.value[row_index][column_index] = cell_value
        return self.value