Coverage for src/rtflite/attributes.py: 78%

255 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-05-07 05:03 +0000

1from collections.abc import MutableSequence, Sequence 

2from typing import Any, Tuple 

3 

4import numpy as np 

5import pandas as pd 

6import polars as pl 

7from pydantic import BaseModel, ConfigDict, Field, field_validator 

8 

9from rtflite.row import ( 

10 BORDER_CODES, 

11 FORMAT_CODES, 

12 TEXT_JUSTIFICATION_CODES, 

13 VERTICAL_ALIGNMENT_CODES, 

14 Border, 

15 Cell, 

16 Row, 

17 TextContent, 

18 Utils, 

19) 

20from rtflite.strwidth import get_string_width 

21 

22 

def _to_nested_list(v):
    """Normalize an attribute value into a list of rows (a list of lists).

    Conversion rules:

    * ``None`` is returned unchanged.
    * A scalar (``str``, ``int``, ``float``, ``bool`` or a NumPy scalar)
      becomes ``[[v]]``. NumPy scalars are converted to the Python equivalent.
    * A NumPy array becomes nested lists. A 0-d array becomes ``[[x]]`` and a
      1-d array becomes a single row, matching the scalar and list rules.
    * A pandas DataFrame (or polars DataFrame, via pandas) becomes its
      row-wise list of values.
    * A list containing at least one scalar is one row: ``[a, b]`` becomes
      ``[[a, b]]``.
    * A list whose items are all lists is returned as-is (already rows).
    * A tuple is one column: ``(a, b)`` becomes ``[[a], [b]]``.
    * Anything else is returned unchanged; type validation is left to the
      caller.

    Raises:
        TypeError: if ``v`` is a sequence that matches none of the list/tuple
            shapes above (e.g. a ``range``, or a list of non-scalars that are
            not all lists).
    """
    if v is None:
        return None

    # NumPy integer/bool scalars are not subclasses of the built-in types, so
    # they are accepted explicitly alongside them.
    scalar_types = (str, int, float, bool, np.generic)

    if isinstance(v, scalar_types):
        return [[v.item() if isinstance(v, np.generic) else v]]

    if isinstance(v, np.ndarray):
        if v.ndim == 0:
            return [[v.item()]]
        if v.ndim == 1:
            # Treat a 1-d array like a flat list: one row.
            return [v.tolist()]
        return v.tolist()

    if isinstance(v, pd.DataFrame):
        return v.values.tolist()

    if isinstance(v, Sequence):
        if isinstance(v, list):
            if any(isinstance(item, scalar_types) for item in v):
                return [v]
            if all(isinstance(item, list) for item in v):
                return v
        elif isinstance(v, tuple):
            return [[item] for item in v]
        raise TypeError("Invalid value type. Must be a list or tuple.")

    if isinstance(v, pl.DataFrame):
        return v.to_pandas().values.tolist()

    return v

52 

53 

class TextAttributes(BaseModel):
    """Base class for text-related attributes in RTF components.

    Each ``text_*`` attribute is either ``None`` (unset) or a list. A scalar
    supplied at construction time is wrapped into a one-element list by
    ``convert_to_list``. When text is encoded with ``_encode``, list entries
    are applied line by line and recycled when the list is shorter than the
    text, so a one-element list applies to every line.
    """

    text_font: list[int] | None = Field(
        default=None, description="Font number for text"
    )

    @field_validator("text_font", mode="after")
    def validate_text_font(cls, v):
        """Reject font numbers that are not in ``Utils._font_type()["type"]``."""
        if v is None:
            return v

        # Look the font table up once instead of once per entry.
        valid_fonts = Utils._font_type()["type"]
        for font in v:
            if font not in valid_fonts:
                raise ValueError(f"Invalid font number: {font}")
        return v

    text_format: list[str] | None = Field(
        default=None,
        description="Text formatting (e.g. 'b' for 'bold', 'i' for'italic')",
    )

    @field_validator("text_format", mode="after")
    def validate_text_format(cls, v):
        """Reject any character of any entry that is not in ``FORMAT_CODES``."""
        if v is None:
            return v

        for fmt_string in v:
            for code in fmt_string:
                if code not in FORMAT_CODES:
                    raise ValueError(f"Invalid text format: {code}")
        return v

    text_font_size: list[float] | None = Field(
        default=None, description="Font size in points"
    )

    @field_validator("text_font_size", mode="after")
    def validate_text_font_size(cls, v):
        """Reject zero or negative font sizes."""
        if v is None:
            return v

        for size in v:
            if size <= 0:
                raise ValueError(f"Invalid font size: {size}")
        return v

    text_color: list[str] | None = Field(
        default=None, description="Text color name or RGB value"
    )
    text_background_color: list[str] | None = Field(
        default=None, description="Background color name or RGB value"
    )
    text_justification: list[str] | None = Field(
        default=None,
        description="Text alignment ('l'=left, 'c'=center, 'r'=right, 'j'=justify)",
    )

    @field_validator("text_justification", mode="after")
    def validate_text_justification(cls, v):
        """Reject justification codes not in ``TEXT_JUSTIFICATION_CODES``."""
        if v is None:
            return v

        for justification in v:
            if justification not in TEXT_JUSTIFICATION_CODES:
                raise ValueError(f"Invalid text justification: {justification}")
        return v

    text_indent_first: list[int] | None = Field(
        default=None, description="First line indent in twips"
    )
    text_indent_left: list[int] | None = Field(
        default=None, description="Left indent in twips"
    )
    text_indent_right: list[int] | None = Field(
        default=None, description="Right indent in twips"
    )
    text_space: list[int] | None = Field(
        default=None, description="Line spacing multiplier"
    )
    text_space_before: list[int] | None = Field(
        default=None, description="Space before paragraph in twips"
    )
    text_space_after: list[int] | None = Field(
        default=None, description="Space after paragraph in twips"
    )
    text_hyphenation: list[bool] | None = Field(
        default=None, description="Enable automatic hyphenation"
    )
    text_convert: list[bool] | None = Field(
        default=None, description="Convert special characters to RTF"
    )

    @field_validator(
        "text_font",
        "text_format",
        "text_font_size",
        "text_color",
        "text_background_color",
        "text_justification",
        "text_indent_first",
        "text_indent_left",
        "text_indent_right",
        "text_space",
        "text_space_before",
        "text_space_after",
        "text_hyphenation",
        "text_convert",
        mode="before",
    )
    def convert_to_list(cls, v):
        """Convert single values to lists before validation."""
        if v is not None and isinstance(v, (int, str, float, bool)):
            return [v]
        return v

    def _encode(self, text: Sequence[str], method: str) -> str | list[str]:
        """Render ``text`` as RTF using the ``text_*`` attributes.

        The ``i``-th line uses the ``i``-th entry of each list attribute. The
        list is recycled when it is shorter than ``text``.

        Args:
            text: Lines of text to encode.
            method: ``"paragraph"`` to emit each line as its own paragraph,
                or ``"line"`` to join every line with ``\\line`` inside one
                paragraph.

        Returns:
            For ``"paragraph"``: a list with one RTF string per line (empty
            when ``text`` is empty). For ``"line"``: a single RTF string,
            which is ``""`` when ``text`` is empty.

        Raises:
            ValueError: if ``method`` is not ``"paragraph"`` or ``"line"``.
        """
        if method not in ("paragraph", "line"):
            raise ValueError(f"Invalid method: {method}")

        dim = (len(text), 1)

        def value_at(attr_name: str, row_idx: int):
            """Return the value of ``attr_name`` that applies to line ``row_idx``."""
            values = getattr(self, attr_name)
            if values is None:
                return None
            if isinstance(values, list):
                if not values:
                    return None
                if not isinstance(values[0], list):
                    # A flat list holds one entry per line. Going through
                    # BroadcastValue would read it as a single *row*, so
                    # every line would get the first entry.
                    return values[row_idx % len(values)]
            return BroadcastValue(value=values, dimension=dim).iloc(row_idx, 0)

        def build(line: str, row_idx: int) -> TextContent:
            """Build the TextContent for ``line`` with line ``row_idx``'s attributes."""
            return TextContent(
                text=line,
                font=value_at("text_font", row_idx),
                size=value_at("text_font_size", row_idx),
                format=value_at("text_format", row_idx),
                color=value_at("text_color", row_idx),
                background_color=value_at("text_background_color", row_idx),
                justification=value_at("text_justification", row_idx),
                indent_first=value_at("text_indent_first", row_idx),
                indent_left=value_at("text_indent_left", row_idx),
                indent_right=value_at("text_indent_right", row_idx),
                space=value_at("text_space", row_idx),
                space_before=value_at("text_space_before", row_idx),
                space_after=value_at("text_space_after", row_idx),
                convert=value_at("text_convert", row_idx),
                hyphenation=value_at("text_hyphenation", row_idx),
            )

        if method == "paragraph":
            return [
                build(str(line), idx)._as_rtf(method="paragraph")
                for idx, line in enumerate(text)
            ]

        if not text:
            return ""

        joined = "\\line".join(
            build(str(line), idx)._as_rtf(method="plain")
            for idx, line in enumerate(text)
        )
        # All lines share one paragraph, so its paragraph-level formatting is
        # taken from the first line's attributes.
        return build(joined, 0)._as_rtf(method="paragraph_format")

236 

237 

class TableAttributes(TextAttributes):
    """Base class for table-related attributes in RTF components.

    Border and cell attributes are 2-D (a list of rows). They are normalized
    with ``_to_nested_list`` before validation. When a table is encoded, every
    attribute is broadcast to each cell's ``(row, column)`` position through
    ``BroadcastValue``.
    """

    col_rel_width: list[float] | None = Field(
        default=None, description="Relative widths of table columns"
    )

    border_left: list[list[str]] = Field(
        default=[[""]], description="Left border style"
    )
    border_right: list[list[str]] = Field(
        default=[[""]], description="Right border style"
    )
    border_top: list[list[str]] = Field(default=[[""]], description="Top border style")
    border_bottom: list[list[str]] = Field(
        default=[[""]], description="Bottom border style"
    )
    border_first: list[list[str]] = Field(
        default=[[""]], description="First row border style"
    )
    border_last: list[list[str]] = Field(
        default=[[""]], description="Last row border style"
    )
    border_color_left: list[list[str]] = Field(
        default=[[""]], description="Left border color"
    )
    border_color_right: list[list[str]] = Field(
        default=[[""]], description="Right border color"
    )
    border_color_top: list[list[str]] = Field(
        default=[[""]], description="Top border color"
    )
    border_color_bottom: list[list[str]] = Field(
        default=[[""]], description="Bottom border color"
    )
    border_color_first: list[list[str]] = Field(
        default=[[""]], description="First row border color"
    )
    border_color_last: list[list[str]] = Field(
        default=[[""]], description="Last row border color"
    )
    border_width: list[list[int]] = Field(
        default=[[15]], description="Border width in twips"
    )
    cell_height: list[list[float]] = Field(
        default=[[0.15]], description="Cell height in inches"
    )
    cell_justification: list[list[str]] = Field(
        default=[["l"]],
        description="Cell horizontal alignment ('l'=left, 'c'=center, 'r'=right, 'j'=justify)",
    )

    cell_vertical_justification: list[list[str]] = Field(
        default=[["center"]],
        description="Cell vertical alignment ('top', 'center', 'bottom')",
    )

    @field_validator("cell_vertical_justification", mode="after")
    def validate_cell_vertical_justification(cls, v):
        """Reject vertical alignments not in ``VERTICAL_ALIGNMENT_CODES``."""
        if v is None:
            return v

        for row in v:
            for justification in row:
                if justification not in VERTICAL_ALIGNMENT_CODES:
                    raise ValueError(
                        f"Invalid cell vertical justification: {justification}"
                    )
        return v

    cell_nrow: list[list[int]] = Field(
        default=[[1]], description="Number of rows per cell"
    )

    @field_validator("col_rel_width", mode="before")
    def convert_col_rel_width_to_list(cls, v):
        """Wrap a scalar ``col_rel_width`` into a one-element list.

        This deliberately does not reuse the name ``convert_to_list``. Pydantic
        lets a subclass validator with the same method name *replace* the
        parent's validator, which would drop the scalar-to-list conversion of
        every ``text_*`` field inherited from ``TextAttributes``.
        """
        if v is not None and isinstance(v, (int, str, float, bool)):
            return [v]
        return v

    @field_validator(
        "border_left",
        "border_right",
        "border_top",
        "border_bottom",
        "border_first",
        "border_last",
        "border_color_left",
        "border_color_right",
        "border_color_top",
        "border_color_bottom",
        "border_color_first",
        "border_color_last",
        "border_width",
        "cell_height",
        "cell_justification",
        "cell_vertical_justification",
        "cell_nrow",
        mode="before",
    )
    def convert_to_nested_list(cls, v):
        """Normalize 2-D attributes into a list of rows before validation."""
        return _to_nested_list(v)

    @field_validator(
        "col_rel_width", "border_width", "cell_height", "cell_nrow", mode="after"
    )
    def validate_positive_value(cls, v, info):
        """Reject zero or negative entries in size-like attributes.

        Handles both flat lists (``col_rel_width``) and lists of rows. Rows are
        walked explicitly rather than through ``np.array``, which rejects
        ragged input.
        """
        if v is None:
            return v

        flat = (
            item
            for entry in v
            for item in (entry if isinstance(entry, list) else [entry])
        )
        if any(item <= 0 for item in flat):
            raise ValueError(f"{info.field_name.capitalize()} must be positive")
        return v

    @field_validator("cell_justification", mode="after")
    def validate_cell_justification(cls, v):
        """Reject horizontal alignments not in ``TEXT_JUSTIFICATION_CODES``."""
        if v is None:
            return v

        for row in v:
            for justification in row:
                if justification not in TEXT_JUSTIFICATION_CODES:
                    raise ValueError(f"Invalid cell justification: {justification}")
        return v

    @field_validator(
        "border_left",
        "border_right",
        "border_top",
        "border_bottom",
        "border_first",
        "border_last",
        mode="after",
    )
    def validate_border(cls, v, info):
        """Validate that all border styles are valid (in ``BORDER_CODES``)."""
        if v is None:
            return v

        for row in v:
            for border in row:
                if border not in BORDER_CODES:
                    field_name = info.field_name.capitalize()
                    raise ValueError(
                        f"{field_name} with invalid border style: {border}"
                    )

        return v

    def _get_section_attributes(self, indices) -> dict:
        """Collect the broadcast attribute values for a set of cell positions.

        Args:
            indices: Iterable of ``(row, col)`` pairs. It may be a one-shot
                iterator; it is read once and reused for every attribute.

        Returns:
            A dict mapping each non-callable attribute whose name starts with
            ``text_``, ``col_``, ``border_`` or ``cell_`` to the list of its
            broadcast values at ``indices``, in the same order.
        """
        positions = list(indices)
        section = {}
        for name in dir(self):
            if not name.startswith(("text_", "col_", "border_", "cell_")):
                continue
            value = getattr(self, name)
            if callable(value):
                continue
            if positions:
                # Build the broadcaster once per attribute, not once per index.
                broadcast = BroadcastValue(value=value)
                section[name] = [broadcast.iloc(row, col) for row, col in positions]
            else:
                section[name] = []
        return section

    def _encode(
        self, df: pd.DataFrame, col_widths: Sequence[float]
    ) -> MutableSequence[str]:
        """Encode ``df`` as RTF table rows.

        Side effect: ``self.cell_nrow`` is replaced by a float array of shape
        ``df.shape``. Each entry is ``ceil(text width / column width)`` for
        that cell, as measured by ``get_string_width``.

        Args:
            df: Table content. Each cell is rendered with ``str``.
            col_widths: Width of each column, indexed by column position.

        Returns:
            The RTF strings produced by ``Row._as_rtf`` for every row, in
            order.
        """
        dim = df.shape
        n_rows, n_cols = dim
        broadcasters: dict[str, BroadcastValue] = {}

        def get_broadcast_value(attr_name, row_idx, col_idx=0):
            """Return ``attr_name`` broadcast to cell ``(row_idx, col_idx)``."""
            # Normalize each attribute once per call instead of once per cell.
            if attr_name not in broadcasters:
                broadcasters[attr_name] = BroadcastValue(
                    value=getattr(self, attr_name), dimension=dim
                )
            return broadcasters[attr_name].iloc(row_idx, col_idx)

        # Read cells positionally with ``iat`` so each value keeps its own
        # column dtype. A whole-row ``df.iloc[i]`` upcasts mixed int/float
        # rows, so integers would render as e.g. "1.0".
        texts = [[str(df.iat[i, j]) for j in range(n_cols)] for i in range(n_rows)]

        cell_nrow = np.zeros(dim)
        for i in range(n_rows):
            for j in range(n_cols):
                font_size = get_broadcast_value("text_font_size", i, j)
                cell_text_width = get_string_width(text=texts[i][j], font_size=font_size)
                cell_nrow[i, j] = np.ceil(cell_text_width / col_widths[j])
        self.cell_nrow = cell_nrow

        rows: MutableSequence[str] = []
        last_col = n_cols - 1
        for i in range(n_rows):
            cells = []
            for j in range(n_cols):
                # Only the last column is given an explicit right border.
                border_right = (
                    Border(style=get_broadcast_value("border_right", i, j))
                    if j == last_col
                    else None
                )
                cells.append(
                    Cell(
                        text=TextContent(
                            text=texts[i][j],
                            font=get_broadcast_value("text_font", i, j),
                            size=get_broadcast_value("text_font_size", i, j),
                            format=get_broadcast_value("text_format", i, j),
                            color=get_broadcast_value("text_color", i, j),
                            background_color=get_broadcast_value(
                                "text_background_color", i, j
                            ),
                            justification=get_broadcast_value(
                                "text_justification", i, j
                            ),
                            indent_first=get_broadcast_value(
                                "text_indent_first", i, j
                            ),
                            indent_left=get_broadcast_value("text_indent_left", i, j),
                            indent_right=get_broadcast_value(
                                "text_indent_right", i, j
                            ),
                            space=get_broadcast_value("text_space", i, j),
                            space_before=get_broadcast_value(
                                "text_space_before", i, j
                            ),
                            space_after=get_broadcast_value("text_space_after", i, j),
                            convert=get_broadcast_value("text_convert", i, j),
                            hyphenation=get_broadcast_value("text_hyphenation", i, j),
                        ),
                        width=col_widths[j],
                        border_left=Border(
                            style=get_broadcast_value("border_left", i, j)
                        ),
                        border_right=border_right,
                        border_top=Border(style=get_broadcast_value("border_top", i, j)),
                        border_bottom=Border(
                            style=get_broadcast_value("border_bottom", i, j)
                        ),
                        vertical_justification=get_broadcast_value(
                            "cell_vertical_justification", i, j
                        ),
                    )
                )
            rtf_row = Row(
                row_cells=cells,
                justification=get_broadcast_value("cell_justification", i, 0),
                height=get_broadcast_value("cell_height", i, 0),
            )
            rows.extend(rtf_row._as_rtf())

        return rows

488 

489 

class BroadcastValue(BaseModel):
    """Wrap an attribute value as a 2-D grid and broadcast it to a table shape.

    ``value`` is normalized to a list of rows by ``_to_nested_list``. Lookups
    recycle rows and columns cyclically, so ``[[x]]`` applies to every cell
    and a single row applies to every row. An empty value (``[]`` or a first
    row of ``[]``) is treated like ``None``: lookups and conversions return
    ``None``.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    value: list[list[Any]] | None = Field(
        ...,
        description="The value of the table, can be various types including DataFrame.",
    )

    dimension: Tuple[int, int] | None = Field(
        None, description="Dimensions of the table (rows, columns)"
    )

    @field_validator("value", mode="before")
    def convert_value(cls, v):
        """Normalize ``v`` into a list of rows before type validation."""
        return _to_nested_list(v)

    @field_validator("dimension")
    def validate_dimension(cls, v):
        """Require ``dimension`` to be ``None`` or two positive integers."""
        if v is None:
            return v

        if not isinstance(v, tuple) or len(v) != 2:
            raise TypeError("dimension must be a tuple of (rows, columns)")

        rows, cols = v
        if not isinstance(rows, int) or not isinstance(cols, int):
            raise TypeError("dimension values must be integers")

        if rows <= 0 or cols <= 0:
            raise ValueError("dimension values must be positive")

        return v

    def _is_empty(self) -> bool:
        """Return True when there is nothing to broadcast (None, [] or [[]])."""
        return not self.value or not self.value[0]

    def iloc(self, row_index: int, column_index: int) -> Any:
        """Return the value at ``(row_index, column_index)``, recycling cyclically.

        The column count is taken from the first row. Returns ``None`` when
        the value is empty.

        Raises:
            ValueError: if a row is shorter than the first row, so the
                recycled column index does not exist in it.
        """
        if self._is_empty():
            return None

        n_rows = len(self.value)
        n_cols = len(self.value[0])
        try:
            return self.value[row_index % n_rows][column_index % n_cols]
        except IndexError as e:
            raise ValueError(f"Invalid DataFrame index or slice: {e}") from e

    def to_list(self) -> list[list[Any]] | None:
        """Return the value tiled and cropped to ``dimension`` as new lists.

        When ``dimension`` is ``None``, a row-by-row copy of the value is
        returned unchanged in shape. Returns ``None`` when the value is empty.
        """
        if self._is_empty():
            return None

        if self.dimension is None:
            return [list(row) for row in self.value]

        target_rows, target_cols = self.dimension
        row_count, col_count = len(self.value), len(self.value[0])

        # Ceiling division: enough repeats to cover the target size.
        row_repeats = max(1, (target_rows + row_count - 1) // row_count)
        col_repeats = max(1, (target_cols + col_count - 1) // col_count)

        tiled = [row * col_repeats for row in self.value] * row_repeats
        # Slicing copies each row, so the repeated row objects are not aliased.
        return [row[:target_cols] for row in tiled[:target_rows]]

    def to_numpy(self) -> np.ndarray | None:
        """Return ``to_list()`` as a NumPy array, or ``None`` when empty."""
        data = self.to_list()
        return None if data is None else np.array(data)

    def to_pandas(self) -> pd.DataFrame | None:
        """Return ``to_list()`` as a pandas DataFrame, or ``None`` when empty."""
        data = self.to_list()
        return None if data is None else pd.DataFrame(data)

    def to_polars(self) -> pl.DataFrame | None:
        """Return ``to_list()`` as a polars DataFrame, or ``None`` when empty."""
        data = self.to_list()
        if data is None:
            return None
        # Polars reads a list of lists as columns unless told otherwise; the
        # data here is a list of rows.
        return pl.DataFrame(data, orient="row")

    def update_row(self, row_index: int, row_value: list):
        """Materialize the broadcast grid and replace one row.

        Returns the updated value, or ``None`` when the value is empty.
        """
        if self._is_empty():
            return None

        self.value = self.to_list()
        self.value[row_index] = row_value
        return self.value

    def update_column(self, column_index: int, column_value: list):
        """Materialize the broadcast grid and replace one column.

        ``column_value`` must have at least one entry per row. Returns the
        updated value, or ``None`` when the value is empty.
        """
        if self._is_empty():
            return None

        self.value = self.to_list()
        for i, row in enumerate(self.value):
            row[column_index] = column_value[i]
        return self.value

    def update_cell(self, row_index: int, column_index: int, cell_value: Any):
        """Materialize the broadcast grid and replace one cell.

        Returns the updated value, or ``None`` when the value is empty.
        """
        if self._is_empty():
            return None

        self.value = self.to_list()
        self.value[row_index][column_index] = cell_value
        return self.value