package engine

import (
	"database/sql"
	"fmt"
	"strings"

	_ "github.com/mattn/go-sqlite3"
)

// db is the package-wide SQLite handle. InitDB publishes it only once the
// connection has been opened, configured and migrated; CloseDB resets it to
// nil.
var db *sql.DB

// InitDB opens the SQLite database at the path returned by GetDBPath, enables
// foreign-key enforcement and applies any pending schema migrations.
//
// Calling InitDB again after a successful call is a no-op. If any step fails,
// the connection is closed and db is left nil, so a later call retries from
// scratch instead of reusing a half-initialized handle.
func InitDB() error {
	if db != nil {
		return nil // Already initialized
	}

	dbPath, err := GetDBPath()
	if err != nil {
		return fmt.Errorf("failed to get database path: %w", err)
	}

	// PRAGMA foreign_keys is a per-connection setting, but database/sql keeps
	// a pool of connections. Passing it in the DSN makes the go-sqlite3
	// driver apply it to every connection the pool opens, not just the one
	// that happens to run the Exec below.
	database, err := sql.Open("sqlite3", withForeignKeysDSN(dbPath))
	if err != nil {
		return fmt.Errorf("failed to open database: %w", err)
	}

	// Kept for drivers that ignore the DSN parameter; it also forces a first
	// connection to open, surfacing path/permission errors here.
	if _, err := database.Exec("PRAGMA foreign_keys = ON"); err != nil {
		_ = database.Close() // best-effort cleanup; the PRAGMA error is what matters
		return fmt.Errorf("failed to enable foreign keys: %w", err)
	}

	// runMigrations operates on the package-level handle, so publish it first
	// and withdraw it again if migrating fails.
	db = database
	if err := runMigrations(); err != nil {
		db = nil
		_ = database.Close() // best-effort cleanup; the migration error is what matters
		return fmt.Errorf("failed to run migrations: %w", err)
	}

	return nil
}

// withForeignKeysDSN appends go-sqlite3's "_foreign_keys=1" connection
// parameter to dsn. It uses "&" when dsn already carries a query string. An
// empty dsn is returned unchanged so that no file named after the parameter
// is created.
func withForeignKeysDSN(dsn string) string {
	if dsn == "" {
		return dsn
	}
	sep := "?"
	if strings.Contains(dsn, "?") {
		sep = "&"
	}
	return dsn + sep + "_foreign_keys=1"
}

// GetDB returns the database connection, or nil if InitDB has not succeeded
// (or CloseDB has since been called).
func GetDB() *sql.DB {
	return db
}

// CloseDB closes the database connection, if one is open, and resets the
// package state so that InitDB can open it again. It is a no-op when no
// connection is open.
func CloseDB() error {
	if db == nil {
		return nil
	}
	err := db.Close()
	db = nil
	return err
}

// runMigrations ensures the schema_version table exists, then applies, in
// order, every migration whose version is greater than the highest version
// recorded there. Each migration runs in its own transaction.
func runMigrations() error {
	if _, err := db.Exec(`
		CREATE TABLE IF NOT EXISTS schema_version (
			version INTEGER PRIMARY KEY,
			applied_at INTEGER NOT NULL
		)
	`); err != nil {
		return fmt.Errorf("failed to create schema_version table: %w", err)
	}

	// Highest applied version; 0 when the table is empty.
	var currentVersion int
	if err := db.QueryRow("SELECT COALESCE(MAX(version), 0) FROM schema_version").Scan(&currentVersion); err != nil {
		return fmt.Errorf("failed to get current schema version: %w", err)
	}

	// Migrations must be listed in ascending version order. Never edit a
	// migration that has shipped (databases that already applied it would
	// diverge); add a new version instead. Note: idx_tasks_uuid duplicates the
	// implicit index SQLite creates for the UNIQUE constraint on uuid, but it
	// is kept so version 1 stays identical for existing databases.
	migrations := []struct {
		version int
		sql     string
	}{
		{
			version: 1,
			sql: `
				CREATE TABLE tasks (
					id INTEGER PRIMARY KEY AUTOINCREMENT,
					uuid TEXT UNIQUE NOT NULL,
					status INTEGER NOT NULL DEFAULT 80,
					description TEXT NOT NULL,
					project TEXT,
					priority INTEGER DEFAULT 1,
					created INTEGER NOT NULL,
					modified INTEGER NOT NULL,
					start INTEGER,
					end INTEGER,
					due INTEGER,
					scheduled INTEGER,
					wait INTEGER,
					until_date INTEGER,
					recurrence_duration INTEGER,
					parent_uuid TEXT,
					FOREIGN KEY (parent_uuid) REFERENCES tasks(uuid) ON DELETE CASCADE
				);

				CREATE INDEX idx_tasks_status ON tasks(status);
				CREATE INDEX idx_tasks_uuid ON tasks(uuid);
				CREATE INDEX idx_tasks_parent ON tasks(parent_uuid);
				CREATE INDEX idx_tasks_due ON tasks(due);
				CREATE INDEX idx_tasks_project ON tasks(project);

				CREATE TABLE tags (
					task_id INTEGER NOT NULL,
					tag TEXT NOT NULL,
					FOREIGN KEY (task_id) REFERENCES tasks(id) ON DELETE CASCADE,
					PRIMARY KEY (task_id, tag)
				);

				CREATE INDEX idx_tags_tag ON tags(tag);

				CREATE TABLE working_set (
					display_id INTEGER PRIMARY KEY,
					task_uuid TEXT NOT NULL,
					FOREIGN KEY (task_uuid) REFERENCES tasks(uuid) ON DELETE CASCADE
				);
			`,
		},
	}

	for _, m := range migrations {
		if m.version <= currentVersion {
			continue // already applied
		}
		if err := applySchemaMigration(m.version, m.sql); err != nil {
			return err
		}
	}

	return nil
}

// applySchemaMigration executes one migration's SQL and records its version in
// schema_version within a single transaction. On failure, neither the schema
// change nor the version row is kept.
func applySchemaMigration(version int, stmts string) error {
	tx, err := db.Begin()
	if err != nil {
		return fmt.Errorf("failed to begin transaction for migration %d: %w", version, err)
	}
	// Undo everything on any early return; after a successful Commit this is
	// a no-op that returns sql.ErrTxDone, which is deliberately ignored.
	defer func() { _ = tx.Rollback() }()

	if _, err := tx.Exec(stmts); err != nil {
		return fmt.Errorf("failed to execute migration %d: %w", version, err)
	}

	if _, err := tx.Exec(
		"INSERT INTO schema_version (version, applied_at) VALUES (?, ?)",
		version, getCurrentTimestamp(),
	); err != nil {
		return fmt.Errorf("failed to record migration %d: %w", version, err)
	}

	if err := tx.Commit(); err != nil {
		return fmt.Errorf("failed to commit migration %d: %w", version, err)
	}
	return nil
}

// getCurrentTimestamp returns the current Unix timestamp in seconds, taken
// from timeNow (defined outside this file).
func getCurrentTimestamp() int64 {
	return timeNow().Unix()
}