package engine

import (
	"fmt"
	"strconv"
	"strings"
	"time"

	"github.com/google/uuid"
)

// undoStackLimit is the maximum number of rows kept in undo_stack; RecordUndo
// evicts everything older than the newest undoStackLimit entries.
const undoStackLimit = 10

// RecordUndo pushes a completed CLI operation onto the undo stack.
//
// It must be called after the mutation has been written: the new undo_stack
// row is linked to the newest change_log row for taskUUID, which is expected
// to be the row the mutation just produced. After inserting, it trims the
// stack down to the newest undoStackLimit entries.
func RecordUndo(opType string, taskUUID uuid.UUID) error {
	db := GetDB()
	if db == nil {
		return fmt.Errorf("database not initialized")
	}
	taskKey := taskUUID.String()

	// MAX over zero rows yields NULL, which cannot be scanned into an int64,
	// so a task with no change_log history surfaces as an error here.
	var latestChangeID int64
	lookup := db.QueryRow("SELECT MAX(id) FROM change_log WHERE task_uuid = ?", taskKey)
	if scanErr := lookup.Scan(&latestChangeID); scanErr != nil {
		return fmt.Errorf("failed to find change_log entry: %w", scanErr)
	}

	createdAt := timeNow().Unix()
	if _, insertErr := db.Exec(
		"INSERT INTO undo_stack (created_at, op_type, task_uuid, change_log_id) VALUES (?, ?, ?, ?)",
		createdAt, opType, taskKey, latestChangeID,
	); insertErr != nil {
		return fmt.Errorf("failed to record undo: %w", insertErr)
	}

	// Keep only the newest undoStackLimit entries (highest ids).
	if _, evictErr := db.Exec(
		"DELETE FROM undo_stack WHERE id NOT IN (SELECT id FROM undo_stack ORDER BY id DESC LIMIT ?)",
		undoStackLimit,
	); evictErr != nil {
		return fmt.Errorf("failed to evict old undo entries: %w", evictErr)
	}

	return nil
}

// PopUndo removes the newest entry from the undo stack and reverts the task it
// refers to, returning a human-readable description of what was undone.
func PopUndo() (string, error) { db := GetDB() if db == nil { return "", fmt.Errorf("database not initialized") } // Get the most recent undo entry var ( undoID int64 opType string taskUUIDStr string changeLogID int64 ) err := db.QueryRow( "SELECT id, op_type, task_uuid, change_log_id FROM undo_stack ORDER BY id DESC LIMIT 1", ).Scan(&undoID, &opType, &taskUUIDStr, &changeLogID) if err != nil { return "", fmt.Errorf("nothing to undo") } taskUUID, err := uuid.Parse(taskUUIDStr) if err != nil { return "", fmt.Errorf("invalid task UUID in undo stack: %w", err) } // Remove the entry from the stack _, err = db.Exec("DELETE FROM undo_stack WHERE id = ?", undoID) if err != nil { return "", fmt.Errorf("failed to pop undo entry: %w", err) } // Perform the revert based on op type switch opType { case "add": return undoAdd(taskUUID) case "done", "delete", "modify", "start", "stop": return undoRestore(opType, taskUUID, changeLogID) default: return "", fmt.Errorf("unknown undo operation: %s", opType) } } // undoAdd reverts an add by hard-deleting the task. // For recurring tasks, also deletes the template. func undoAdd(taskUUID uuid.UUID) (string, error) { db := GetDB() task, err := GetTask(taskUUID) if err != nil { return "", fmt.Errorf("failed to load task for undo: %w", err) } desc := task.Description // If this is a recurring instance, also delete the template if task.ParentUUID != nil { _, err = db.Exec("DELETE FROM tasks WHERE uuid = ?", task.ParentUUID.String()) if err != nil { return "", fmt.Errorf("failed to delete recurring template: %w", err) } } // Hard delete the task _, err = db.Exec("DELETE FROM tasks WHERE uuid = ?", taskUUID.String()) if err != nil { return "", fmt.Errorf("failed to delete task: %w", err) } return fmt.Sprintf("Undid add: removed \"%s\"", desc), nil } // undoRestore reverts a mutation by restoring the prior state from change_log. 
func undoRestore(opType string, taskUUID uuid.UUID, changeLogID int64) (string, error) { db := GetDB() // Find the change_log entry BEFORE this one for the same task var priorData string err := db.QueryRow( "SELECT data FROM change_log WHERE task_uuid = ? AND id < ? ORDER BY id DESC LIMIT 1", taskUUID.String(), changeLogID, ).Scan(&priorData) if err != nil { return "", fmt.Errorf("no prior state found in change_log (cannot undo)") } // Parse the prior state task, err := GetTask(taskUUID) if err != nil { return "", fmt.Errorf("failed to load task: %w", err) } // Apply the prior state from change_log data if err := applyChangeLogData(task, priorData); err != nil { return "", fmt.Errorf("failed to restore prior state: %w", err) } // Save the restored task if err := task.Save(); err != nil { return "", fmt.Errorf("failed to save restored task: %w", err) } // Reconcile tags if err := reconcileTagsFromChangeLog(task, priorData); err != nil { return "", fmt.Errorf("failed to reconcile tags: %w", err) } return fmt.Sprintf("Undid %s: restored \"%s\"", opType, task.Description), nil } // applyChangeLogData parses change_log data and applies it to a task. // The data format is "key: value\n" lines (same format used by sync). 
func applyChangeLogData(task *Task, data string) error { lines := strings.Split(data, "\n") for _, line := range lines { line = strings.TrimSpace(line) if line == "" { continue } parts := strings.SplitN(line, ": ", 2) if len(parts) != 2 { continue } key := parts[0] value := parts[1] switch key { case "description": task.Description = value case "status": switch value { case "pending": task.Status = StatusPending case "completed": task.Status = StatusCompleted case "deleted": task.Status = StatusDeleted case "recurring": task.Status = StatusRecurring } case "priority": switch value { case "H": task.Priority = PriorityHigh case "M": task.Priority = PriorityMedium case "L": task.Priority = PriorityLow default: task.Priority = PriorityDefault } case "project": task.Project = &value case "created": if ts, err := strconv.ParseInt(value, 10, 64); err == nil { task.Created = time.Unix(ts, 0) } case "modified": // Don't restore modified — it'll be set by Save() case "start": if ts, err := strconv.ParseInt(value, 10, 64); err == nil { t := time.Unix(ts, 0) task.Start = &t } case "end": if ts, err := strconv.ParseInt(value, 10, 64); err == nil { t := time.Unix(ts, 0) task.End = &t } case "due": if ts, err := strconv.ParseInt(value, 10, 64); err == nil { t := time.Unix(ts, 0) task.Due = &t } case "scheduled": if ts, err := strconv.ParseInt(value, 10, 64); err == nil { t := time.Unix(ts, 0) task.Scheduled = &t } case "wait": if ts, err := strconv.ParseInt(value, 10, 64); err == nil { t := time.Unix(ts, 0) task.Wait = &t } case "until": if ts, err := strconv.ParseInt(value, 10, 64); err == nil { t := time.Unix(ts, 0) task.Until = &t } case "recurrence": if ns, err := strconv.ParseInt(value, 10, 64); err == nil { d := time.Duration(ns) task.RecurrenceDuration = &d } case "parent_uuid": if u, err := uuid.Parse(value); err == nil { task.ParentUUID = &u } case "annotations": // Annotations are stored as JSON in the change_log task.Annotations = sqlToAnnotations(value) case "tags": 
// Tags are handled separately by reconcileTagsFromChangeLog } } // Clear fields that aren't present in the change_log data (they were NULL) fieldPresent := make(map[string]bool) for _, line := range lines { parts := strings.SplitN(strings.TrimSpace(line), ": ", 2) if len(parts) == 2 { fieldPresent[parts[0]] = true } } if !fieldPresent["project"] { task.Project = nil } if !fieldPresent["start"] { task.Start = nil } if !fieldPresent["end"] { task.End = nil } if !fieldPresent["due"] { task.Due = nil } if !fieldPresent["scheduled"] { task.Scheduled = nil } if !fieldPresent["wait"] { task.Wait = nil } if !fieldPresent["until"] { task.Until = nil } if !fieldPresent["recurrence"] { task.RecurrenceDuration = nil } if !fieldPresent["parent_uuid"] { task.ParentUUID = nil } if !fieldPresent["annotations"] { task.Annotations = nil } return nil } // reconcileTagsFromChangeLog restores tags from change_log data. func reconcileTagsFromChangeLog(task *Task, data string) error { // Parse desired tags from change_log var desiredTags []string for _, line := range strings.Split(data, "\n") { line = strings.TrimSpace(line) if strings.HasPrefix(line, "tags: ") { tagStr := strings.TrimPrefix(line, "tags: ") for _, tag := range strings.Split(tagStr, ",") { tag = strings.TrimSpace(tag) if tag != "" { desiredTags = append(desiredTags, tag) } } } } // Get current tags currentTags, _ := task.GetTags() // Remove tags not in desired set desired := make(map[string]bool) for _, t := range desiredTags { desired[t] = true } for _, tag := range currentTags { if !desired[tag] { task.RemoveTag(tag) } } // Add missing tags current := make(map[string]bool) for _, t := range currentTags { current[t] = true } for _, tag := range desiredTags { if !current[tag] { task.AddTag(tag) } } return nil }