initial commit

This commit is contained in:
Jacob Lifshay 2024-10-24 21:42:22 -07:00
commit 77b0ce2c3d
Signed by: programmerjake
SSH key fingerprint: SHA256:B1iRVvUJkvd7upMIiMqn6OyxvD2SgJkAH3ZnUOj6z+c
8 changed files with 1289 additions and 0 deletions

View file

View file

@ -0,0 +1,392 @@
from __future__ import annotations
from collections import defaultdict
from dataclasses import dataclass, field
from functools import cached_property
import sys
from typing import ClassVar, Iterable, Iterator, NewType, TypeAlias, TypeVar, assert_never, overload
from xml.etree import ElementTree
import enum
from pdfminer.high_level import extract_pages
from pdfminer.layout import LTChar, LTLine, LTPage, LTRect, LTTextBox
from parse_powerisa_pdf.quad_tree import QuadTree
from parse_powerisa_pdf.set_by_id import SetById
@dataclass(unsafe_hash=True, frozen=True)
class Font:
font_name: str
size: float
__KNOWN_NAMES: ClassVar[dict[Font, str]]
@cached_property
def space_width(self) -> float:
match self:
case Font.INSTR_HEADER:
return 3.12
case _:
return self.size * 0.31
@cached_property
def line_height(self) -> float:
match self:
case Font.INSTR_HEADER:
return 10.961
case _:
return self.size * 1.1
@classmethod
def __iter__(cls) -> Iterator[Font]:
return iter(cls.__KNOWN_NAMES.keys())
@property
def known_name(self) -> None | str:
return self.__KNOWN_NAMES.get(self)
@classmethod
def _register_known_fonts(cls) -> None:
cls.INSTR_HEADER = Font(font_name='YDJYQV+DejaVuSansCondensed-BoldOblique', size=9.963)
cls.PAGE_HEADER = Font(font_name='MJBFWM+DejaVuSansCondensed', size=9.963)
cls.PAGE_FOOTER = Font(font_name='MJBFWM+DejaVuSansCondensed', size=4.981)
cls.INSTR_DESC = Font(font_name='MJBFWM+DejaVuSansCondensed', size=8.966)
cls.INSTR_DESC_ITALIC = Font(font_name='CGMSHV+DejaVuSansCondensed-Oblique', size=8.966)
cls.INSTR_DESC_BOLD = Font(font_name='NHUPPK+DejaVuSansCondensed-Bold', size=8.966)
cls.INSTR_DESC_BOLD_ITALIC = Font(font_name='YDJYQV+DejaVuSansCondensed-BoldOblique', size=8.966)
cls.INSTR_DESC_SUBSCRIPT = Font(font_name='MJBFWM+DejaVuSansCondensed', size=5.978)
cls.INSTR_FIELD_BIT_NUMS = Font(font_name='MJBFWM+DejaVuSansCondensed', size=7.97)
cls.INSTR_EXT_MNEMONIC = Font(font_name='APUYSQ+zcoN-Regular', size=8.966)
cls.INSTR_CODE = Font(font_name='APUYSQ+zcoN-Regular', size=7.97)
cls.INSTR_CODE_SYM = Font(font_name='RRFUNA+CMSY8', size=7.97)
cls.INSTR_CODE_NE_EQ_SIGN = Font(font_name='HPXOZC+CMSS8', size=7.97)
cls.INSTR_CODE_SUBSCRIPT = Font(font_name='APUYSQ+zcoN-Regular', size=5.978)
cls.__KNOWN_NAMES = {}
for name, value in cls.__dict__.items():
if name[0].isupper() and isinstance(value, cls):
assert value not in cls.__KNOWN_NAMES, f"duplicate known font: {value}"
cls.__KNOWN_NAMES[value] = name
old_repr = cls.__repr__
def __repr__(self: cls) -> str:
known_name = self.known_name
if known_name is not None:
return f"<{self.__class__.__name__}.{known_name}: {old_repr(self)}>"
return old_repr(self)
cls.__repr__ = __repr__
del cls._register_known_fonts
Font._register_known_fonts()
@dataclass(unsafe_hash=True, frozen=True)
class Char:
font: Font
text: str
adv: float
min_x: float
min_y: float
max_x: float
max_y: float
def top_down_left_to_right_sort_key(self):
return -self.min_y, self.min_x
@property
def width(self) -> float:
return self.max_x - self.min_x
@property
def height(self) -> float:
return self.max_y - self.min_y
@dataclass()
class Parser:
def parse_pdf(self, file: str, page_numbers: range | None = None):
for page in extract_pages(file, page_numbers=page_numbers):
PageParser(parser=self, page_id=page.pageid).parse_page(page)
COLUMN_SPLIT_X = 300.0
@dataclass()
class ParsedTextLine:
element: ElementTree.Element
regular_min_y: float
fonts: TextLineFonts
chars: list[Char]
def __str__(self) -> str:
return ElementTree.tostring(self.element, encoding="unicode")
_T = TypeVar("_T")
@dataclass(unsafe_hash=True, frozen=True)
class TextLineFonts:
regular: Font
italic: Font | None = None
bold: Font | None = None
bold_italic: Font | None = None
def get_font(self, part_kind: TextLineFontKind, default: _T=None) -> _T | Font:
match part_kind:
case TextLineFontKind.REGULAR:
retval = self.regular
case TextLineFontKind.ITALIC:
retval = self.italic
case TextLineFontKind.BOLD:
retval = self.bold
case TextLineFontKind.BOLD_ITALIC:
retval = self.bold_italic
case _:
assert_never(part_kind)
if retval is None:
return default
return retval
@cached_property
def __font_to_kind_map(self) -> dict[Font, TextLineFontKind]:
retval = {}
for kind in TextLineFontKind:
font = self.get_font(kind)
if font is None:
continue
assert font not in retval, \
f"duplicate font: kind={kind} old_kind={retval[font]} font={font}"
retval[font] = kind
return retval
def get_kind(self, font: Font, default: _T=None) -> _T | TextLineFontKind:
return self.__font_to_kind_map.get(font, default)
class TextLineFontKind(enum.Enum):
REGULAR = "regular"
ITALIC = "italic"
BOLD = "bold"
BOLD_ITALIC = "bold_italic"
@cached_property
def text_line_tags(self) -> tuple[str, ...]:
match self:
case TextLineFontKind.REGULAR:
return ()
case TextLineFontKind.ITALIC:
return "i",
case TextLineFontKind.BOLD:
return "b",
case TextLineFontKind.BOLD_ITALIC:
return "b", "i"
case _:
assert_never(self)
class PageParseFailed(Exception):
pass
class ElementBodyBuilder:
def __init__(self, containing_element: ElementTree.Element):
self.__containing_element = containing_element
self.__stack: list[ElementTree.Element] = []
self.__text_buffer: list[str] = []
def __shrink_stack(self, new_len: int):
while new_len < len(self.__stack):
self.__flush_text_buffer()
self.__stack.pop()
def set_tag_stack(self, tag_stack: Iterable[str]):
new_len = 0
for i, tag in enumerate(tag_stack):
new_len = i + 1
if i >= len(self.__stack):
self.__flush_text_buffer()
self.__stack.append(ElementTree.SubElement(self.__insert_point(), tag))
elif self.__stack[i].tag != tag:
self.__shrink_stack(new_len)
self.__shrink_stack(new_len)
def write_text(self, text: str):
self.__text_buffer.append(text)
def __insert_point(self) -> ElementTree.Element:
if len(self.__stack) != 0:
return self.__stack[-1]
return self.__containing_element
def __flush_text_buffer(self):
if len(self.__text_buffer) == 0:
return
insert_point = self.__insert_point()
text = "".join(self.__text_buffer)
self.__text_buffer.clear()
if len(insert_point) != 0:
element = insert_point[-1]
element.tail = (element.tail or "") + text
else:
insert_point.text = (insert_point.text or "") + text
def __enter__(self) -> ElementBodyBuilder:
return self
def __exit__(self, exc_type, exc_value, traceback):
self.flush()
def flush(self):
self.set_tag_stack(())
self.__flush_text_buffer()
@dataclass()
class PageParser:
parser: Parser
page_id: int
qt: QuadTree[Char | LTLine | LTRect] = field(default_factory=QuadTree)
unprocessed_chars: defaultdict[Font, SetById[Char]] = field(
default_factory=lambda: defaultdict(SetById[Char]))
unprocessed_non_text: SetById[LTLine | LTRect] = field(
default_factory=SetById[LTLine | LTRect])
def parse_page(self, page: LTPage):
for component in page:
if isinstance(component, (LTLine, LTRect)):
self.qt.insert(component.x0, component.y0, component)
continue
if not isinstance(component, LTTextBox):
print(f"ignoring: {component}")
continue
for text_line in component:
for element in text_line:
if not isinstance(element, LTChar):
continue
char = Char(
text=element.get_text(),
font=Font(font_name=element.fontname, size=round(element.size, 3)),
adv=element.adv,
min_x=element.x0,
min_y=element.y0,
max_x=element.x1,
max_y=element.y1,
)
self.qt.insert(char.min_x, char.min_y, char)
self.unprocessed_chars[char.font].add(char)
for i in self.unprocessed_chars.values():
i.sort(key=Char.top_down_left_to_right_sort_key)
for font, chars in self.unprocessed_chars.items():
print()
print(font)
text = ""
char = None
for char in chars:
text += char.text
print(repr(text))
assert font.known_name is not None, f"unknown font {font}\nlast char: {char}"
self.extract_instructions()
def extract_text_line(
self, *,
start_char: None | Char = None,
start_min_y: float,
min_x: float,
max_x: float,
fonts: TextLineFonts,
) -> None | ParsedTextLine:
chars: list[Char] = []
if start_char is not None:
chars.append(start_char)
self.unprocessed_chars[start_char.font].remove(start_char)
for x, y, char in self.qt.range(
min_x=min_x,
max_x=max_x,
min_y=start_min_y - fonts.regular.size * 0.5,
max_y=start_min_y + fonts.regular.size * 0.5,
):
if not isinstance(char, Char):
continue
if char not in self.unprocessed_chars[char.font]:
continue
self.unprocessed_chars[char.font].remove(char)
chars.append(char)
if len(chars) == 0:
return None
chars.sort(key=Char.top_down_left_to_right_sort_key)
retval = ParsedTextLine(
element=ElementTree.Element("text-line"),
regular_min_y=chars[0].min_y,
fonts=fonts,
chars=chars,
)
with ElementBodyBuilder(retval.element) as body_builder:
last_max_x = min_x
last_kind = None
for char in chars:
kind = fonts.get_kind(char.font)
if kind is None:
return None
if last_kind is None:
space_kind = kind
elif last_kind != kind:
space_kind = TextLineFontKind.REGULAR
else:
space_kind = kind
space_font = fonts.get_font(space_kind, fonts.regular)
space_width = char.min_x - last_max_x
space_count_f = space_width / space_font.space_width
space_count = round(space_count_f)
if space_count_f > 0.25 and abs(space_count - space_count_f) > 0.15:
print(f"spaces: space_count_f={space_count_f} space_width={space_width}")
if space_count > 0:
body_builder.set_tag_stack(space_kind.text_line_tags)
body_builder.write_text(" " * space_count)
body_builder.set_tag_stack(kind.text_line_tags)
body_builder.write_text(char.text)
last_max_x = char.max_x
last_kind = kind
return retval
def extract_following_text_lines(
self,
first_text_line: ParsedTextLine,
min_x: float,
max_x: float,
) -> list[ParsedTextLine]:
retval: list[ParsedTextLine] = []
line = first_text_line
while line is not None:
retval.append(line)
line = self.extract_text_line(
start_min_y=line.regular_min_y - first_text_line.fonts.regular.line_height,
min_x=min_x,
max_x=max_x,
fonts=first_text_line.fonts,
)
return retval
def extract_instruction(self, header_start_char: Char):
assert header_start_char.font == Font.INSTR_HEADER
if header_start_char.min_x < COLUMN_SPLIT_X:
column_max_x = COLUMN_SPLIT_X
else:
column_max_x = 1000
header_text_line = self.extract_text_line(
start_char=header_start_char,
start_min_y=header_start_char.min_y,
min_x=header_start_char.min_x,
max_x=column_max_x,
fonts=TextLineFonts(regular=Font.INSTR_HEADER),
)
if header_text_line is None:
raise PageParseFailed("can't find header text line")
print(header_text_line)
header_lines = self.extract_following_text_lines(
first_text_line=header_text_line,
min_x=header_start_char.min_x,
max_x=column_max_x,
)
print(*header_lines)
# TODO: finish
def extract_instructions(self):
unprocessed_header_chars = self.unprocessed_chars[Font.INSTR_HEADER]
while len(unprocessed_header_chars) != 0:
self.extract_instruction(next(iter(unprocessed_header_chars)))
def main():
Parser().parse_pdf(sys.argv[1], page_numbers=range(76, 78))

View file

@ -0,0 +1,629 @@
from __future__ import annotations
from typing import Callable, Generic, Iterable, Iterator, TypeVar
from math import frexp, isfinite, isnan, ldexp
import unittest
_V = TypeVar("_V")
class _QuadTreeNode(Generic[_V]):
""" node in a quad-tree extending from `-(2 ** log2_size)` to
`2 ** log2_size` exclusive on both sides in both `x` and `y`
dimensions.
invariants:
* if `self.values is not None` then all of
`nx_ny`, `px_ny`, `nx_py`, and `px_py` are `None`.
* the `log2_size` field of all of `nx_ny`, `px_ny`, `nx_py`, and `px_py`
is always `self.log2_size - 1`.
"""
__slots__ = "log2_size", "nx_ny", "px_ny", "nx_py", "px_py", "values"
log2_size: int
nx_ny: None | _QuadTreeNode[_V]
px_ny: None | _QuadTreeNode[_V]
nx_py: None | _QuadTreeNode[_V]
px_py: None | _QuadTreeNode[_V]
values: None | list[tuple[float, float, _V]]
def __init__(
self, *,
log2_size: int,
nx_ny: None | _QuadTreeNode[_V],
px_ny: None | _QuadTreeNode[_V],
nx_py: None | _QuadTreeNode[_V],
px_py: None | _QuadTreeNode[_V],
values: None | list[tuple[float, float, _V]],
):
self.log2_size = log2_size
self.nx_ny = nx_ny
self.px_ny = px_ny
self.nx_py = nx_py
self.px_py = px_py
self.values = values
@staticmethod
def leaf(
log2_size: int,
values: None | list[tuple[float, float, _V]] = None,
) -> _QuadTreeNode[_V]:
if values is None:
values = []
return _QuadTreeNode(
log2_size=log2_size,
nx_ny=None,
px_ny=None,
nx_py=None,
px_py=None,
values=values,
)
@staticmethod
def interior(
log2_size: int, *,
nx_ny: None | _QuadTreeNode[_V] = None,
px_ny: None | _QuadTreeNode[_V] = None,
nx_py: None | _QuadTreeNode[_V] = None,
px_py: None | _QuadTreeNode[_V] = None,
) -> _QuadTreeNode[_V]:
return _QuadTreeNode(
log2_size=log2_size,
nx_ny=nx_ny,
px_ny=px_ny,
nx_py=nx_py,
px_py=px_py,
values=None,
)
def __iter__(self) -> Iterator[tuple[float, float, _V]]:
if self.nx_ny is not None:
yield from self.nx_ny
if self.px_ny is not None:
yield from self.px_ny
if self.nx_py is not None:
yield from self.nx_py
if self.px_py is not None:
yield from self.px_py
if self.values is not None:
yield from self.values
def __repr__(self, indent=0) -> str:
nl_indent = '\n' + ' ' * (4 * indent)
parts = [f" log2_size={self.log2_size}"]
if self.nx_ny is not None:
parts.append(f" nx_ny={self.nx_ny.__repr__(indent + 1)}")
if self.px_ny is not None:
parts.append(f" px_ny={self.px_ny.__repr__(indent + 1)}")
if self.nx_py is not None:
parts.append(f" nx_py={self.nx_py.__repr__(indent + 1)}")
if self.px_py is not None:
parts.append(f" px_py={self.px_py.__repr__(indent + 1)}")
fn = "interior"
if self.values is not None:
if len(self.values) == 0:
parts.append(" values=[]")
else:
prefix = " values=[" + nl_indent
for i in self.values:
parts.append(f"{prefix} {i}")
prefix = ""
parts.append(" ]")
fn = "leaf"
parts.append(")")
sep = ',' + nl_indent
return f"{fn}({nl_indent}{sep.join(parts)}"
def split(
self,
split_x: float,
split_y: float,
delta: float,
max_leaf_size: int, *,
trace: None | Callable[[str]],
):
assert self.values is not None, "can't split interior node"
if len(self.values) <= max_leaf_size:
return
values = self.values
self.values = None # convert to interior node
for i in values:
x, y, _ = i
if x < split_x:
if y < split_y:
if self.nx_ny is None:
self.nx_ny = _QuadTreeNode.leaf(log2_size=self.log2_size - 1)
assert self.nx_ny.values is not None
self.nx_ny.values.append(i)
else:
if self.nx_py is None:
self.nx_py = _QuadTreeNode.leaf(log2_size=self.log2_size - 1)
assert self.nx_py.values is not None
self.nx_py.values.append(i)
else:
if y < split_y:
if self.px_ny is None:
self.px_ny = _QuadTreeNode.leaf(log2_size=self.log2_size - 1)
assert self.px_ny.values is not None
self.px_ny.values.append(i)
else:
if self.px_py is None:
self.px_py = _QuadTreeNode.leaf(log2_size=self.log2_size - 1)
assert self.px_py.values is not None
self.px_py.values.append(i)
if trace is not None:
trace(f"split leaf split_x={split_x} "
f"split_y={split_y} delta={delta} self={self}")
if self.nx_ny is not None:
self.nx_ny.split(
split_x=split_x - delta,
split_y=split_y - delta,
delta=delta * 0.5,
max_leaf_size=max_leaf_size,
trace=trace,
)
if self.nx_py is not None:
self.nx_py.split(
split_x=split_x - delta,
split_y=split_y + delta,
delta=delta * 0.5,
max_leaf_size=max_leaf_size,
trace=trace,
)
if self.px_ny is not None:
self.px_ny.split(
split_x=split_x + delta,
split_y=split_y - delta,
delta=delta * 0.5,
max_leaf_size=max_leaf_size,
trace=trace,
)
if self.px_py is not None:
self.px_py.split(
split_x=split_x + delta,
split_y=split_y + delta,
delta=delta * 0.5,
max_leaf_size=max_leaf_size,
trace=trace,
)
def range(
self, *,
min_x: float,
max_x: float,
min_y: float,
max_y: float,
node_min_x: float,
node_max_x: float,
node_min_y: float,
node_max_y: float,
trace: None | Callable[[str]],
) -> Iterable[tuple[float, float, _V]]:
if self.values is not None:
# leaf node
if trace is not None:
trace(
f"range(\n"
f" min_x={min_x}, max_x={max_x},\n"
f" min_y={min_y}, max_y={max_y},\n"
f" node_min_x={node_min_x}, node_max_x={node_max_x},\n"
f" node_min_y={node_min_y}, node_max_y={node_max_y},\n"
f"): leaf {self}"
)
for i in self.values:
x, y, _ = i
if min_x <= x <= max_x and min_y <= y <= max_y:
if trace is not None:
trace(f"yielding {i}")
yield i
else:
if trace is not None:
trace(f"skipping {i}")
return
# interior node
split_x = (node_min_x + node_max_x) * 0.5
split_y = (node_min_y + node_max_y) * 0.5
if trace is not None:
trace(
f"range(\n"
f" min_x={min_x}, max_x={max_x},\n"
f" min_y={min_y}, max_y={max_y},\n"
f" node_min_x={node_min_x}, node_max_x={node_max_x},\n"
f" node_min_y={node_min_y}, node_max_y={node_max_y},\n"
f"): interior split_x={split_x} split_y={split_y} {self}"
)
if self.nx_ny is not None and min_x <= split_x and min_y <= split_y:
yield from self.nx_ny.range(
min_x=min_x,
max_x=max_x,
min_y=min_y,
max_y=max_y,
node_min_x=node_min_x,
node_max_x=split_x,
node_min_y=node_min_y,
node_max_y=split_y,
trace=trace,
)
if self.nx_py is not None and min_x <= split_x and max_y >= split_y:
yield from self.nx_py.range(
min_x=min_x,
max_x=max_x,
min_y=min_y,
max_y=max_y,
node_min_x=node_min_x,
node_max_x=split_x,
node_min_y=split_y,
node_max_y=node_max_y,
trace=trace,
)
if self.px_ny is not None and max_x >= split_x and min_y <= split_y:
yield from self.px_ny.range(
min_x=min_x,
max_x=max_x,
min_y=min_y,
max_y=max_y,
node_min_x=split_x,
node_max_x=node_max_x,
node_min_y=node_min_y,
node_max_y=split_y,
trace=trace,
)
if self.px_py is not None and max_x >= split_x and max_y >= split_y:
yield from self.px_py.range(
min_x=min_x,
max_x=max_x,
min_y=min_y,
max_y=max_y,
node_min_x=split_x,
node_max_x=node_max_x,
node_min_y=split_y,
node_max_y=node_max_y,
trace=trace,
)
class QuadTree(Generic[_V]):
root: _QuadTreeNode[_V]
max_leaf_size: int
def __init__(
self,
v: Iterable[tuple[float, float, _V]] = (), *,
max_leaf_size = 16,
):
self.max_leaf_size = max_leaf_size
if isinstance(v, _QuadTreeNode):
self.root = v
else:
self.root = _QuadTreeNode.leaf(log2_size=0)
for x, y, value in v:
self.insert(x, y, value)
def __repr__(self):
return f"QuadTree({self.root})"
def __iter__(self) -> Iterator[tuple[float, float, _V]]:
return self.root.__iter__()
def expand_root(self, target_log2_size: int, *, trace: None | Callable[[str]] = None):
root = self.root
if target_log2_size > root.log2_size and root.values is not None:
# leaf node -- just expand in place
if trace is not None:
trace(f"expand_root({target_log2_size}) leaf")
root.log2_size = target_log2_size
while target_log2_size > root.log2_size:
if trace is not None:
trace(f"expand_root({target_log2_size}) interior: {root}")
if root.nx_ny is not None:
root.nx_ny = _QuadTreeNode.interior(
log2_size=root.log2_size,
px_py=root.nx_ny,
)
if root.px_ny is not None:
root.px_ny = _QuadTreeNode.interior(
log2_size=root.log2_size,
nx_py=root.px_ny,
)
if root.nx_py is not None:
root.nx_py = _QuadTreeNode.interior(
log2_size=root.log2_size,
px_ny=root.nx_py,
)
if root.px_py is not None:
root.px_py = _QuadTreeNode.interior(
log2_size=root.log2_size,
nx_ny=root.px_py,
)
root.log2_size += 1
if trace is not None:
trace(f"expand_root({target_log2_size}) done: {root}")
def range(
self,
min_x: float,
max_x: float,
min_y: float,
max_y: float, *,
trace: None | Callable[[str]] = None,
) -> Iterable[tuple[float, float, _V]]:
assert not (isnan(min_x) or isnan(max_x)
or isnan(min_y) or isnan(max_y))
if min_x > max_x or min_y > max_y:
return ()
size = ldexp(1.0, self.root.log2_size)
return self.root.range(
min_x=min_x,
max_x=max_x,
min_y=min_y,
max_y=max_y,
node_min_x=-size,
node_max_x=size,
node_min_y=-size,
node_max_y=size,
trace=trace,
)
def insert(self, x: float, y: float, value: _V, *, trace: None | Callable[[str]] = None):
assert isfinite(x) and isfinite(y), "invalid coordinates"
if trace is not None:
trace(f"insert({x}, {y}, {value!r})")
abs_x = abs(x)
abs_y = abs(y)
abs_max = max(abs_x, abs_y)
_, log2_abs_max = frexp(abs_max)
if trace is not None:
trace(f"insert({x}, {y}, _): log2_abs_max={log2_abs_max}")
self.expand_root(log2_abs_max, trace=trace)
node = self.root
split_x = 0.0
split_y = 0.0
delta = ldexp(0.5, node.log2_size)
# walk down tree
while node.values is None:
if trace is not None:
trace(f"insert({x}, {y}, _): interior: split_x={split_x} "
f"split_y={split_y} delta={delta} node={node}")
if x < split_x:
if y < split_y:
if node.nx_ny is None:
node.nx_ny = _QuadTreeNode.leaf(node.log2_size - 1, [(x, y, value)])
if trace is not None:
trace(f"insert({x}, {y}, _): insert nx_ny leaf: node={node}")
return
node = node.nx_ny
split_y -= delta
else:
if node.nx_py is None:
node.nx_py = _QuadTreeNode.leaf(node.log2_size - 1, [(x, y, value)])
if trace is not None:
trace(f"insert({x}, {y}, _): insert nx_py leaf: node={node}")
return
node = node.nx_py
split_y += delta
split_x -= delta
else:
if y < split_y:
if node.px_ny is None:
node.px_ny = _QuadTreeNode.leaf(node.log2_size - 1, [(x, y, value)])
if trace is not None:
trace(f"insert({x}, {y}, _): insert px_ny leaf: node={node}")
return
node = node.px_ny
split_y -= delta
else:
if node.px_py is None:
node.px_py = _QuadTreeNode.leaf(node.log2_size - 1, [(x, y, value)])
if trace is not None:
trace(f"insert({x}, {y}, _): insert px_py leaf: node={node}")
return
node = node.px_py
split_y += delta
split_x += delta
delta *= 0.5
# got to a leaf
node.values.append((x, y, value))
if trace is not None:
trace(f"insert({x}, {y}, _): leaf node={node}")
node.split(
split_x=split_x,
split_y=split_y,
delta=delta,
max_leaf_size=self.max_leaf_size,
trace=trace,
)
class _QuadTreeTest(unittest.TestCase):
def test_quad_tree_range(self):
items = []
for x in range(8):
x += 0.5
x /= 4
for y in range(8):
y += 0.5
y /= 4
items.append((x, y, f"{x},{y}"))
items.sort()
qt = QuadTree(items)
def check_range(min_x: float, max_x: float, min_y: float, max_y: float):
with self.subTest(min_x=min_x, max_x=max_x, min_y=min_y, max_y=max_y):
expected = list(filter(
lambda i: min_x <= i[0] <= max_x and min_y <= i[1] <= max_y, items))
range_result = sorted(qt.range(
min_x=min_x, max_x=max_x, min_y=min_y, max_y=max_y, trace=print))
print(f"expected={expected}")
print(f"range_result={range_result}")
self.assertEqual(expected, range_result)
check_range(-1, 1, -1, 1)
check_range(0.125, 0.25, 0.125, 0.25)
check_range(items[0][0], items[0][0], items[0][1], items[0][1])
check_range(1, 0, -1, 1)
def test_quad_tree_insert(self):
qt = QuadTree(max_leaf_size=4)
for x in range(8):
x += 0.5
x /= 4
for y in range(8):
y += 0.5
y /= 4
qt.insert(x, y, f"{x},{y}")
self.assertEqual(repr(qt), """QuadTree(interior(
log2_size=1,
px_py=interior(
log2_size=0,
nx_ny=interior(
log2_size=-1,
nx_ny=leaf(
log2_size=-2,
values=[
(0.125, 0.125, '0.125,0.125'),
(0.125, 0.375, '0.125,0.375'),
(0.375, 0.125, '0.375,0.125'),
(0.375, 0.375, '0.375,0.375'),
],
),
px_ny=leaf(
log2_size=-2,
values=[
(0.625, 0.125, '0.625,0.125'),
(0.625, 0.375, '0.625,0.375'),
(0.875, 0.125, '0.875,0.125'),
(0.875, 0.375, '0.875,0.375'),
],
),
nx_py=leaf(
log2_size=-2,
values=[
(0.125, 0.625, '0.125,0.625'),
(0.125, 0.875, '0.125,0.875'),
(0.375, 0.625, '0.375,0.625'),
(0.375, 0.875, '0.375,0.875'),
],
),
px_py=leaf(
log2_size=-2,
values=[
(0.625, 0.625, '0.625,0.625'),
(0.625, 0.875, '0.625,0.875'),
(0.875, 0.625, '0.875,0.625'),
(0.875, 0.875, '0.875,0.875'),
],
),
),
px_ny=interior(
log2_size=-1,
nx_ny=leaf(
log2_size=-2,
values=[
(1.125, 0.125, '1.125,0.125'),
(1.125, 0.375, '1.125,0.375'),
(1.375, 0.125, '1.375,0.125'),
(1.375, 0.375, '1.375,0.375'),
],
),
px_ny=leaf(
log2_size=-2,
values=[
(1.625, 0.125, '1.625,0.125'),
(1.625, 0.375, '1.625,0.375'),
(1.875, 0.125, '1.875,0.125'),
(1.875, 0.375, '1.875,0.375'),
],
),
nx_py=leaf(
log2_size=-2,
values=[
(1.125, 0.625, '1.125,0.625'),
(1.125, 0.875, '1.125,0.875'),
(1.375, 0.625, '1.375,0.625'),
(1.375, 0.875, '1.375,0.875'),
],
),
px_py=leaf(
log2_size=-2,
values=[
(1.625, 0.625, '1.625,0.625'),
(1.625, 0.875, '1.625,0.875'),
(1.875, 0.625, '1.875,0.625'),
(1.875, 0.875, '1.875,0.875'),
],
),
),
nx_py=interior(
log2_size=-1,
nx_ny=leaf(
log2_size=-2,
values=[
(0.125, 1.125, '0.125,1.125'),
(0.125, 1.375, '0.125,1.375'),
(0.375, 1.125, '0.375,1.125'),
(0.375, 1.375, '0.375,1.375'),
],
),
px_ny=leaf(
log2_size=-2,
values=[
(0.625, 1.125, '0.625,1.125'),
(0.625, 1.375, '0.625,1.375'),
(0.875, 1.125, '0.875,1.125'),
(0.875, 1.375, '0.875,1.375'),
],
),
nx_py=leaf(
log2_size=-2,
values=[
(0.125, 1.625, '0.125,1.625'),
(0.125, 1.875, '0.125,1.875'),
(0.375, 1.625, '0.375,1.625'),
(0.375, 1.875, '0.375,1.875'),
],
),
px_py=leaf(
log2_size=-2,
values=[
(0.625, 1.625, '0.625,1.625'),
(0.625, 1.875, '0.625,1.875'),
(0.875, 1.625, '0.875,1.625'),
(0.875, 1.875, '0.875,1.875'),
],
),
),
px_py=interior(
log2_size=-1,
nx_ny=leaf(
log2_size=-2,
values=[
(1.125, 1.125, '1.125,1.125'),
(1.125, 1.375, '1.125,1.375'),
(1.375, 1.125, '1.375,1.125'),
(1.375, 1.375, '1.375,1.375'),
],
),
px_ny=leaf(
log2_size=-2,
values=[
(1.625, 1.125, '1.625,1.125'),
(1.625, 1.375, '1.625,1.375'),
(1.875, 1.125, '1.875,1.125'),
(1.875, 1.375, '1.875,1.375'),
],
),
nx_py=leaf(
log2_size=-2,
values=[
(1.125, 1.625, '1.125,1.625'),
(1.125, 1.875, '1.125,1.875'),
(1.375, 1.625, '1.375,1.625'),
(1.375, 1.875, '1.375,1.875'),
],
),
px_py=leaf(
log2_size=-2,
values=[
(1.625, 1.625, '1.625,1.625'),
(1.625, 1.875, '1.625,1.875'),
(1.875, 1.625, '1.875,1.625'),
(1.875, 1.875, '1.875,1.875'),
],
),
),
),
))""")

View file

@ -0,0 +1,78 @@
from collections import abc
from typing import Callable, Generic, Iterable, Iterator, Protocol, TypeAlias, TypeVar, overload
_T = TypeVar("_T")
_T_contra = TypeVar("_T_contra", contravariant=True)
class _SupportsLT(Protocol[_T_contra]):
def __lt__(self, other: _T_contra, /) -> bool: ...
class _SupportsGT(Protocol[_T_contra]):
def __gt__(self, other: _T_contra, /) -> bool: ...
_SupportsRichComparison: TypeAlias = _SupportsLT | _SupportsGT
class SetById(abc.MutableSet[_T], Generic[_T]):
__slots__ = "__data",
def __init__(self, items: None | Iterable[_T]=None, /):
if items is None:
items = ()
self.__data = {id(i): i for i in items}
def __contains__(self, x: object) -> bool:
return id(x) in self.__data
def __iter__(self) -> Iterator[_T]:
return iter(self.__data.values())
def __len__(self) -> int:
return len(self.__data)
def add(self, value: _T) -> None:
self.__data[id(value)] = value
def discard(self, value: _T) -> None:
key = id(value)
data = self.__data
if key in data:
del data[key]
def remove(self, value: _T) -> None:
try:
del self.__data[id(value)]
except KeyError:
raise KeyError(value) from None
def pop(self) -> _T:
return self.__data.popitem()[1]
def clear(self) -> None:
return self.__data.clear()
def __repr__(self) -> str:
if len(self.__data) == 0:
return "SetById()"
return f"SetById([{', '.join(map(repr, self))}])"
@overload
def sort(
self, *,
key: None = None,
reverse: bool = False,
): ...
@overload
def sort(
self, *,
key: Callable[[_T], _SupportsRichComparison],
reverse: bool = False,
): ...
def sort(
self, *,
key: None | Callable[[_T], _SupportsRichComparison] = None,
reverse: bool = False,
):
adj_key = None if key is None else lambda kv: key(kv[1])
self.__data = dict(sorted(self.__data.items(), key=adj_key, reverse=reverse))