diff --git a/sbysrc/sby_status.py b/sbysrc/sby_status.py index 3d5c295..fb3fc82 100644 --- a/sbysrc/sby_status.py +++ b/sbysrc/sby_status.py @@ -17,61 +17,52 @@ Fn = TypeVar("Fn", bound=Callable[..., Any]) def transaction(method: Fn) -> Fn: @wraps(method) def wrapper(self: SbyStatusDb, *args: Any, **kwargs: Any) -> Any: - if self._transaction_active: + if self.con.in_transaction: return method(self, *args, **kwargs) try: - self.log_debug(f"begin {method.__name__!r} transaction") - self.db.execute("begin") - self._transaction_active = True - result = method(self, *args, **kwargs) - self.db.execute("commit") - self._transaction_active = False - self.log_debug(f"comitted {method.__name__!r} transaction") - return result - except sqlite3.OperationalError as err: - self.log_debug(f"failed {method.__name__!r} transaction {err}") - self.db.rollback() - self._transaction_active = False + with self.con: + self.log_debug(f"begin {method.__name__!r} transaction") + self.db.execute("begin") + result = method(self, *args, **kwargs) except Exception as err: self.log_debug(f"failed {method.__name__!r} transaction {err}") - self.db.rollback() - self._transaction_active = False - raise + if not isinstance(err, sqlite3.OperationalError): + raise + else: + self.log_debug(f"comitted {method.__name__!r} transaction") + return result + try: - self.log_debug( - f"retrying {method.__name__!r} transaction once in immediate mode" - ) - self.db.execute("begin immediate") - self._transaction_active = True - result = method(self, *args, **kwargs) - self.db.execute("commit") - self._transaction_active = False - self.log_debug(f"comitted {method.__name__!r} transaction") - return result + with self.con: + self.log_debug( + f"retrying {method.__name__!r} transaction once in immediate mode" + ) + self.db.execute("begin immediate") + result = method(self, *args, **kwargs) except Exception as err: self.log_debug(f"failed {method.__name__!r} transaction {err}") - self.db.rollback() - self._transaction_active = False raise + else: + self.log_debug(f"comitted {method.__name__!r} transaction") + return result return wrapper # type: ignore class SbyStatusDb: def __init__(self, path: Path, task, timeout: float = 5.0): - self.debug = False + self.debug = True self.task = task - self._transaction_active = False setup = not os.path.exists(path) - self.db = sqlite3.connect(path, isolation_level=None, timeout=timeout) + self.con = sqlite3.connect(path, isolation_level=None, timeout=timeout) + self.db = self.con.cursor() self.db.row_factory = sqlite3.Row - cur = self.db.cursor() - cur.execute("PRAGMA journal_mode=WAL") - cur.execute("PRAGMA synchronous=0") - self.db.commit() + with self.con: + self.db.execute("PRAGMA journal_mode=WAL") + self.db.execute("PRAGMA synchronous=0") if setup: self._setup() @@ -245,7 +236,6 @@ class SbyStatusDb: ), ) - @transaction def all_tasks(self): rows = self.db.execute( """ @@ -255,7 +245,6 @@ class SbyStatusDb: return {row["id"]: dict(row) for row in rows} - @transaction def all_task_properties(self): rows = self.db.execute( """ @@ -271,7 +260,6 @@ class SbyStatusDb: return {row["id"]: get_result(row) for row in rows} - @transaction def all_task_property_statuses(self): rows = self.db.execute( """ @@ -287,7 +275,6 @@ class SbyStatusDb: return {row["id"]: get_result(row) for row in rows} - @transaction def all_status_data(self): return ( self.all_tasks(),