diff --git a/sbysrc/sby.py b/sbysrc/sby.py index 31835be..dc02971 100644 --- a/sbysrc/sby.py +++ b/sbysrc/sby.py @@ -85,12 +85,14 @@ if status_show or status_reset: status_db = SbyStatusDb(status_path, task=None) - if status_show: - status_db.print_status_summary() - sys.exit(0) - if status_reset: status_db.reset() + elif status_db.test_schema(): + print(f"ERROR: Status database does not match expected formatted. Use --statusreset to reset.") + sys.exit(1) + + if status_show: + status_db.print_status_summary() status_db.db.close() sys.exit(0) diff --git a/sbysrc/sby_status.py b/sbysrc/sby_status.py index e154205..cb91174 100644 --- a/sbysrc/sby_status.py +++ b/sbysrc/sby_status.py @@ -4,6 +4,7 @@ import sqlite3 import os import time import json +import re from collections import defaultdict from functools import wraps from pathlib import Path @@ -13,6 +14,45 @@ from sby_design import SbyProperty, pretty_path Fn = TypeVar("Fn", bound=Callable[..., Any]) +SQLSCRIPT = """\ +CREATE TABLE task ( + id INTEGER PRIMARY KEY, + workdir TEXT, + mode TEXT, + created REAL +); +CREATE TABLE task_status ( + id INTEGER PRIMARY KEY, + task INTEGER, + status TEXT, + data TEXT, + created REAL, + FOREIGN KEY(task) REFERENCES task(id) +); +CREATE TABLE task_property ( + id INTEGER PRIMARY KEY, + task INTEGER, + src TEXT, + name TEXT, + created REAL, + FOREIGN KEY(task) REFERENCES task(id) +); +CREATE TABLE task_property_status ( + id INTEGER PRIMARY KEY, + task_property INTEGER, + status TEXT, + data TEXT, + created REAL, + FOREIGN KEY(task_property) REFERENCES task_property(id) +); +CREATE TABLE task_property_data ( + id INTEGER PRIMARY KEY, + task_property INTEGER, + kind TEXT, + data TEXT, + created REAL, + FOREIGN KEY(task_property) REFERENCES task_property(id) +);""" def transaction(method: Fn) -> Fn: @wraps(method) @@ -79,50 +119,17 @@ class SbyStatusDb: @transaction def _setup(self): - script = """ - CREATE TABLE task ( - id INTEGER PRIMARY KEY, - workdir TEXT, - mode TEXT, - created REAL - ); - CREATE TABLE task_status ( - id INTEGER PRIMARY KEY, - task INTEGER, - status TEXT, - data TEXT, - created REAL, - FOREIGN KEY(task) REFERENCES task(id) - ); - CREATE TABLE task_property ( - id INTEGER PRIMARY KEY, - task INTEGER, - src TEXT, - name TEXT, - created REAL, - FOREIGN KEY(task) REFERENCES task(id) - ); - CREATE TABLE task_property_status ( - id INTEGER PRIMARY KEY, - task_property INTEGER, - status TEXT, - data TEXT, - created REAL, - FOREIGN KEY(task_property) REFERENCES task_property(id) - ); - CREATE TABLE task_property_data ( - id INTEGER PRIMARY KEY, - task_property INTEGER, - kind TEXT, - data TEXT, - created REAL, - FOREIGN KEY(task_property) REFERENCES task_property(id) - ); - """ - for statement in script.split(";\n"): + for statement in SQLSCRIPT.split(";\n"): statement = statement.strip() if statement: self.db.execute(statement) + self.db.execute("""PRAGMA foreign_keys = ON;""") + + def test_schema(self) -> bool: + schema = self.db.execute("SELECT sql FROM sqlite_master;").fetchall() + schema_script = '\n'.join(str(sql[0] + ';') for sql in schema) + self._tables = re.findall(r"CREATE TABLE (\w+) \(", schema_script) + return schema_script != SQLSCRIPT @transaction def create_task(self, workdir: str, mode: str) -> int: @@ -284,11 +291,18 @@ class SbyStatusDb: @transaction def reset(self): - self.db.execute("""DELETE FROM task_property_status""") - self.db.execute("""DELETE FROM task_property_data""") - self.db.execute("""DELETE FROM task_property""") - self.db.execute("""DELETE FROM task_status""") - self.db.execute("""DELETE FROM task""") + hard_reset = self.test_schema() + # table names can't be parameters, so we need to use f-strings + # but it is safe to use here because it comes from the regex "\w+" + for table in self._tables: + if hard_reset: + self.log_debug(f"dropping {table}") + self.db.execute(f"DROP TABLE {table}") + else: + self.log_debug(f"clearing {table}") + self.db.execute(f"DELETE FROM {table}") + if hard_reset: + self._setup() def print_status_summary(self): tasks, task_properties, task_property_statuses = self.all_status_data()