diff --git a/cmd/enhance.go b/cmd/enhance.go index d3b3dd6..3c14971 100644 --- a/cmd/enhance.go +++ b/cmd/enhance.go @@ -8,7 +8,6 @@ import ( "strings" "github.com/spf13/cobra" - "github.com/spf13/viper" "github.com/insertr/insertr/internal/content" "github.com/insertr/insertr/internal/db" @@ -34,9 +33,6 @@ var ( func init() { enhanceCmd.Flags().StringVarP(&outputDir, "output", "o", "./dist", "Output directory for enhanced files") - - // Bind flags to viper - viper.BindPFlag("cli.output", enhanceCmd.Flags().Lookup("output")) } func runEnhance(cmd *cobra.Command, args []string) { @@ -59,20 +55,26 @@ func runEnhance(cmd *cobra.Command, args []string) { } } - // Get configuration values - dbPath := viper.GetString("database.path") - apiURL := viper.GetString("api.url") - apiKey := viper.GetString("api.key") - siteID := viper.GetString("cli.site_id") - outputDir := viper.GetString("cli.output") + // Load configuration + cfg, err := loadConfig() + if err != nil { + log.Fatalf("Failed to load configuration: %v", err) + } + + // Override with flags if provided + if outputDir != "" { + // No output config in main config, use the flag value directly + } else { + outputDir = "./dist" // default + } // Auto-derive site_id for demo paths or validate for production if strings.Contains(inputPath, "/demos/") || strings.Contains(inputPath, "./demos/") { // Auto-derive site_id from demo path - siteID = content.DeriveOrValidateSiteID(inputPath, siteID) + cfg.CLI.SiteID = content.DeriveOrValidateSiteID(inputPath, cfg.CLI.SiteID) } else { // Validate site_id for non-demo paths - if siteID == "" || siteID == "demo" { + if cfg.CLI.SiteID == "" || cfg.CLI.SiteID == "demo" { log.Fatalf(`āŒ site_id must be explicitly configured for non-demo sites. šŸ’” Examples: @@ -92,12 +94,12 @@ func runEnhance(cmd *cobra.Command, args []string) { // Create content client var client engine.ContentClient - if apiURL != "" { - fmt.Printf("🌐 Using content API: %s\n", apiURL) - client = content.NewHTTPClient(apiURL, apiKey) - } else if dbPath != "" { - fmt.Printf("šŸ—„ļø Using database: %s\n", dbPath) - database, err := db.NewDatabase(dbPath) + if cfg.API.URL != "" { + fmt.Printf("🌐 Using content API: %s\n", cfg.API.URL) + client = content.NewHTTPClient(cfg.API.URL, cfg.API.Key) + } else if cfg.Database.Path != "" { + fmt.Printf("šŸ—„ļø Using database: %s\n", cfg.Database.Path) + database, err := db.NewDatabase(cfg.Database.Path) if err != nil { log.Fatalf("Failed to initialize database: %v", err) } @@ -121,41 +123,24 @@ func runEnhance(cmd *cobra.Command, args []string) { } // Override with site-specific discovery config if available - if siteConfigs := viper.Get("server.sites"); siteConfigs != nil { - if configs, ok := siteConfigs.([]interface{}); ok { - for _, configInterface := range configs { - if configMap, ok := configInterface.(map[string]interface{}); ok { - if configSiteID, ok := configMap["site_id"].(string); ok && configSiteID == siteID { - // Found matching site config, load discovery settings - if discoveryMap, ok := configMap["discovery"].(map[string]interface{}); ok { - if enabled, ok := discoveryMap["enabled"].(bool); ok { - enhancementConfig.Discovery.Enabled = enabled - fmt.Printf("šŸ”§ Site '%s': discovery.enabled=%v\n", siteID, enabled) - } - if aggressive, ok := discoveryMap["aggressive"].(bool); ok { - enhancementConfig.Discovery.Aggressive = aggressive - } - if containers, ok := discoveryMap["containers"].(bool); ok { - enhancementConfig.Discovery.Containers = containers - } - if individual, ok := discoveryMap["individual"].(bool); ok { - enhancementConfig.Discovery.Individual = individual - } - } - break - } - } - } + for _, site := range cfg.Server.Sites { + if site.SiteID == cfg.CLI.SiteID && site.Discovery != nil { + enhancementConfig.Discovery.Enabled = site.Discovery.Enabled + enhancementConfig.Discovery.Aggressive = site.Discovery.Aggressive + enhancementConfig.Discovery.Containers = site.Discovery.Containers + enhancementConfig.Discovery.Individual = site.Discovery.Individual + fmt.Printf("šŸ”§ Site '%s': discovery.enabled=%v\n", cfg.CLI.SiteID, site.Discovery.Enabled) + break } } // Create enhancer with loaded configuration - enhancer := content.NewEnhancer(client, siteID, enhancementConfig) + enhancer := content.NewEnhancer(client, cfg.CLI.SiteID, enhancementConfig) fmt.Printf("šŸš€ Starting enhancement process...\n") fmt.Printf("šŸ“ Input: %s\n", inputPath) fmt.Printf("šŸ“ Output: %s\n", outputDir) - fmt.Printf("šŸ·ļø Site ID: %s\n\n", siteID) + fmt.Printf("šŸ·ļø Site ID: %s\n\n", cfg.CLI.SiteID) // Enhance based on input type if isFile { diff --git a/cmd/root.go b/cmd/root.go index 7913104..9e4380e 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -4,16 +4,17 @@ import ( "fmt" "os" + "github.com/insertr/insertr/internal/config" "github.com/spf13/cobra" - "github.com/spf13/viper" ) var ( - cfgFile string - dbPath string - apiURL string - apiKey string - siteID string + configFile string + dbPath string + apiURL string + apiKey string + siteID string + loader config.Loader ) var rootCmd = &cobra.Command{ @@ -35,40 +36,22 @@ func Execute() { } func init() { - cobra.OnInitialize(initConfig) + loader = config.NewLoader() // Global flags - rootCmd.PersistentFlags().StringVar(&cfgFile, "config", "", "config file (default is ./insertr.yaml)") - rootCmd.PersistentFlags().StringVar(&dbPath, "db", "./insertr.db", "database path (SQLite file or PostgreSQL connection string)") + rootCmd.PersistentFlags().StringVar(&configFile, "config", "", "config file (default is ./insertr.yaml)") + rootCmd.PersistentFlags().StringVar(&dbPath, "db", "", "database path (SQLite file or PostgreSQL connection string)") rootCmd.PersistentFlags().StringVar(&apiURL, "api-url", "", "content API URL") rootCmd.PersistentFlags().StringVar(&apiKey, "api-key", "", "API key for authentication") - rootCmd.PersistentFlags().StringVarP(&siteID, "site-id", "s", "demo", "site ID for content lookup") - - // Bind flags to viper - viper.BindPFlag("database.path", rootCmd.PersistentFlags().Lookup("db")) - viper.BindPFlag("api.url", rootCmd.PersistentFlags().Lookup("api-url")) - viper.BindPFlag("api.key", rootCmd.PersistentFlags().Lookup("api-key")) - viper.BindPFlag("cli.site_id", rootCmd.PersistentFlags().Lookup("site-id")) + rootCmd.PersistentFlags().StringVarP(&siteID, "site-id", "s", "", "site ID for content lookup") rootCmd.AddCommand(enhanceCmd) rootCmd.AddCommand(serveCmd) } -func initConfig() { - if cfgFile != "" { - viper.SetConfigFile(cfgFile) - } else { - viper.AddConfigPath(".") - viper.SetConfigName("insertr") - viper.SetConfigType("yaml") - } - - // Environment variables - viper.SetEnvPrefix("INSERTR") - viper.AutomaticEnv() - - // Read config file - if err := viper.ReadInConfig(); err == nil { - fmt.Fprintln(os.Stderr, "Using config file:", viper.ConfigFileUsed()) +func loadConfig() (*config.Config, error) { + if configFile != "" || dbPath != "" || apiURL != "" || apiKey != "" || siteID != "" { + return loader.LoadWithFlags(dbPath, apiURL, apiKey, siteID) } + return loader.Load(configFile) } diff --git a/cmd/serve.go b/cmd/serve.go index af02e7d..a0f7dbd 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -13,7 +13,6 @@ import ( "github.com/go-chi/chi/v5/middleware" "github.com/go-chi/cors" "github.com/spf13/cobra" - "github.com/spf13/viper" "github.com/insertr/insertr/internal/api" "github.com/insertr/insertr/internal/auth" @@ -36,72 +35,65 @@ var ( ) func init() { - serveCmd.Flags().IntVarP(&port, "port", "p", 8080, "Server port") + serveCmd.Flags().IntVarP(&port, "port", "p", 0, "Server port") serveCmd.Flags().BoolVar(&devMode, "dev-mode", false, "Enable development mode features") - - // Bind flags to viper - viper.BindPFlag("server.port", serveCmd.Flags().Lookup("port")) - viper.BindPFlag("dev_mode", serveCmd.Flags().Lookup("dev-mode")) } func runServe(cmd *cobra.Command, args []string) { - // Get configuration values - port := viper.GetInt("server.port") - dbPath := viper.GetString("database.path") - devMode := viper.GetBool("dev_mode") + // Load configuration + cfg, err := loadConfig() + if err != nil { + log.Fatalf("Failed to load configuration: %v", err) + } + + // Override with flags if provided + if port != 0 { + cfg.Server.Port = port + } + if devMode { + cfg.Auth.DevMode = true + } // Initialize database - database, err := db.NewDatabase(dbPath) + database, err := db.NewDatabase(cfg.Database.Path) if err != nil { log.Fatalf("Failed to initialize database: %v", err) } defer database.Close() - // Initialize authentication service + // Support environment variables for sensitive values + if clientSecret := os.Getenv("AUTHENTIK_CLIENT_SECRET"); clientSecret != "" && cfg.Auth.OIDC != nil { + cfg.Auth.OIDC.ClientSecret = clientSecret + } + if endpoint := os.Getenv("AUTHENTIK_ENDPOINT"); endpoint != "" && cfg.Auth.OIDC != nil { + cfg.Auth.OIDC.Endpoint = endpoint + } + + // Set redirect URL if not configured + if cfg.Auth.OIDC != nil && cfg.Auth.OIDC.RedirectURL == "" { + cfg.Auth.OIDC.RedirectURL = fmt.Sprintf("http://%s:%d/auth/callback", cfg.Server.Host, cfg.Server.Port) + } + + // Create legacy auth config for compatibility authConfig := &auth.AuthConfig{ - DevMode: viper.GetBool("dev_mode"), - Provider: viper.GetString("auth.provider"), - JWTSecret: viper.GetString("auth.jwt_secret"), + DevMode: cfg.Auth.DevMode, + Provider: cfg.Auth.Provider, + JWTSecret: cfg.Auth.JWTSecret, } - // Set default values - if authConfig.Provider == "" { - authConfig.Provider = "mock" - } - if authConfig.JWTSecret == "" { - authConfig.JWTSecret = "dev-secret-change-in-production" - if authConfig.DevMode { - log.Printf("šŸ”‘ Using default JWT secret for development") + if cfg.Auth.OIDC != nil { + authConfig.OIDC = &auth.OIDCConfig{ + Endpoint: cfg.Auth.OIDC.Endpoint, + ClientID: cfg.Auth.OIDC.ClientID, + ClientSecret: cfg.Auth.OIDC.ClientSecret, + RedirectURL: cfg.Auth.OIDC.RedirectURL, + Scopes: cfg.Auth.OIDC.Scopes, } } - // Configure OIDC if using authentik - if authConfig.Provider == "authentik" { - oidcConfig := &auth.OIDCConfig{ - Endpoint: viper.GetString("auth.oidc.endpoint"), - ClientID: viper.GetString("auth.oidc.client_id"), - ClientSecret: viper.GetString("auth.oidc.client_secret"), - RedirectURL: fmt.Sprintf("http://localhost:%d/auth/callback", port), - } - - // Support environment variables for sensitive values - if clientSecret := os.Getenv("AUTHENTIK_CLIENT_SECRET"); clientSecret != "" { - oidcConfig.ClientSecret = clientSecret - } - if endpoint := os.Getenv("AUTHENTIK_ENDPOINT"); endpoint != "" { - oidcConfig.Endpoint = endpoint - } - - authConfig.OIDC = oidcConfig - - // Validate required OIDC config - if oidcConfig.Endpoint == "" || oidcConfig.ClientID == "" || oidcConfig.ClientSecret == "" { - log.Fatalf("āŒ Authentik OIDC configuration incomplete. Required: endpoint, client_id, client_secret") - } - - log.Printf("šŸ” Using Authentik OIDC provider: %s", oidcConfig.Endpoint) - } else { - log.Printf("šŸ”‘ Using auth provider: %s", authConfig.Provider) + log.Printf("šŸ”‘ Using auth provider: %s", cfg.Auth.Provider) + if cfg.Auth.Provider == "authentik" && cfg.Auth.OIDC != nil { + log.Printf("šŸ” Using Authentik OIDC provider: %s", cfg.Auth.OIDC.Endpoint) } authService, err := auth.NewAuthService(authConfig) @@ -113,64 +105,38 @@ func runServe(cmd *cobra.Command, args []string) { contentClient := engine.NewDatabaseClient(database) // Initialize site manager with auth provider - authProvider := &engine.AuthProvider{Type: authConfig.Provider} - siteManager := content.NewSiteManagerWithAuth(contentClient, devMode, authProvider) + authProvider := &engine.AuthProvider{Type: cfg.Auth.Provider} + siteManager := content.NewSiteManagerWithAuth(contentClient, cfg.Auth.DevMode, authProvider) - // Load sites from configuration - if siteConfigs := viper.Get("server.sites"); siteConfigs != nil { - if configs, ok := siteConfigs.([]interface{}); ok { - var sites []*content.SiteConfig - for _, configInterface := range configs { - if configMap, ok := configInterface.(map[string]interface{}); ok { - site := &content.SiteConfig{} - if siteID, ok := configMap["site_id"].(string); ok { - site.SiteID = siteID - } - if path, ok := configMap["path"].(string); ok { - site.Path = path - } - if sourcePath, ok := configMap["source_path"].(string); ok { - site.SourcePath = sourcePath - } - if domain, ok := configMap["domain"].(string); ok { - site.Domain = domain - } - if autoEnhance, ok := configMap["auto_enhance"].(bool); ok { - site.AutoEnhance = autoEnhance - } - // Parse discovery config if present - if discoveryMap, ok := configMap["discovery"].(map[string]interface{}); ok { - discovery := &content.DiscoveryConfig{ - Containers: true, // defaults - Individual: true, - } - if enabled, ok := discoveryMap["enabled"].(bool); ok { - discovery.Enabled = enabled - } - if aggressive, ok := discoveryMap["aggressive"].(bool); ok { - discovery.Aggressive = aggressive - } - if containers, ok := discoveryMap["containers"].(bool); ok { - discovery.Containers = containers - } - if individual, ok := discoveryMap["individual"].(bool); ok { - discovery.Individual = individual - } - site.Discovery = discovery - } - if site.SiteID != "" && site.Path != "" { - sites = append(sites, site) - } - } - } - if err := siteManager.RegisterSites(sites); err != nil { - log.Printf("āš ļø Failed to register some sites: %v", err) + // Convert config sites to legacy format and register + var legacySites []*content.SiteConfig + for _, site := range cfg.Server.Sites { + legacySite := &content.SiteConfig{ + SiteID: site.SiteID, + Path: site.Path, + SourcePath: site.SourcePath, + Domain: site.Domain, + AutoEnhance: site.AutoEnhance, + } + if site.Discovery != nil { + legacySite.Discovery = &content.DiscoveryConfig{ + Enabled: site.Discovery.Enabled, + Aggressive: site.Discovery.Aggressive, + Containers: site.Discovery.Containers, + Individual: site.Discovery.Individual, } } + legacySites = append(legacySites, legacySite) + } + + if len(legacySites) > 0 { + if err := siteManager.RegisterSites(legacySites); err != nil { + log.Printf("āš ļø Failed to register some sites: %v", err) + } } // Auto-enhance sites if enabled - if devMode { + if cfg.Auth.DevMode { log.Printf("šŸ”„ Auto-enhancing sites in development mode...") if err := siteManager.EnhanceAllSites(); err != nil { log.Printf("āš ļø Some sites failed to enhance: %v", err) @@ -191,133 +157,89 @@ func runServe(cmd *cobra.Command, args []string) { AllowedOrigins: []string{"*"}, // In dev mode, allow all origins AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}, AllowedHeaders: []string{"*"}, - ExposedHeaders: []string{"Link"}, + ExposedHeaders: []string{"*"}, AllowCredentials: true, - MaxAge: 300, // Maximum value not ignored by any of major browsers })) - router.Use(api.ContentTypeMiddleware) - // Health check endpoint - router.Get("/health", api.HealthMiddleware()) - - // Static library serving (for demo sites) - router.Get("/insertr.js", contentHandler.ServeInsertrJS) - router.Get("/insertr.css", contentHandler.ServeInsertrCSS) - - // Auth routes - router.Route("/auth", func(authRouter chi.Router) { - authRouter.Get("/login", authService.HandleOAuthLogin) - authRouter.Get("/callback", authService.HandleOAuthCallback) + // Health check + router.Get("/health", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("OK")) }) - // API routes - router.Route("/api", func(apiRouter chi.Router) { - // Site enhancement endpoint - apiRouter.Post("/enhance", contentHandler.EnhanceSite) + // Authentication routes + router.Route("/auth", func(r chi.Router) { + r.Post("/login", authService.HandleOAuthLogin) + r.Get("/callback", authService.HandleOAuthCallback) + }) - // Content endpoints - apiRouter.Route("/content", func(contentRouter chi.Router) { - contentRouter.Get("/bulk", contentHandler.GetBulkContent) - contentRouter.Get("/{id}", contentHandler.GetContent) - contentRouter.Get("/", contentHandler.GetAllContent) - contentRouter.Post("/", contentHandler.CreateContent) - contentRouter.Put("/{id}", contentHandler.UpdateContent) + // Content API routes + router.Route("/api", func(r chi.Router) { + // Public routes + r.Get("/content/{siteID}/{id}", contentHandler.GetContent) + r.Get("/content/{siteID}", contentHandler.GetAllContent) - // Version control endpoints - contentRouter.Get("/{id}/versions", contentHandler.GetContentVersions) - contentRouter.Post("/{id}/rollback", contentHandler.RollbackContent) - }) + // Protected routes (require authentication) + r.Group(func(r chi.Router) { + r.Use(authService.RequireAuth) + r.Post("/content/{siteID}", contentHandler.CreateContent) + r.Put("/content/{siteID}/{id}", contentHandler.UpdateContent) + r.Delete("/content/{siteID}/{id}", contentHandler.DeleteContent) - // Collection endpoints - apiRouter.Route("/collections", func(collectionRouter chi.Router) { - collectionRouter.Get("/", contentHandler.GetAllCollections) - collectionRouter.Get("/{id}", contentHandler.GetCollection) - - // Collection item endpoints - collectionRouter.Get("/{id}/items", contentHandler.GetCollectionItems) - collectionRouter.Post("/{id}/items", contentHandler.CreateCollectionItem) - collectionRouter.Put("/{id}/items/{item_id}", contentHandler.UpdateCollectionItem) - collectionRouter.Delete("/{id}/items/{item_id}", contentHandler.DeleteCollectionItem) - - // Bulk operations - collectionRouter.Put("/{id}/reorder", contentHandler.ReorderCollection) + // Version management + r.Get("/content/{siteID}/{id}/versions", contentHandler.GetContentVersions) + r.Post("/content/{siteID}/{id}/rollback/{version}", contentHandler.RollbackContent) }) }) - // Static site serving - serve registered sites at /sites/{site_id} - // Custom file server that fixes CSS MIME types - for siteID, siteConfig := range siteManager.GetAllSites() { - log.Printf("šŸ“ Serving site %s from %s at /sites/%s/", siteID, siteConfig.Path, siteID) + // Serve static sites + for _, siteConfig := range siteManager.GetAllSites() { + log.Printf("šŸ“ Serving site %s from %s at /sites/%s/", siteConfig.SiteID, siteConfig.Path, siteConfig.SiteID) - // Create custom file server with MIME type fixing + // Create a file server for each site fileServer := http.FileServer(http.Dir(siteConfig.Path)) - customHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Fix MIME type for CSS files (including extensionless ones in css/ directory) - if strings.Contains(r.URL.Path, "/css/") { - w.Header().Set("Content-Type", "text/css; charset=utf-8") - } - fileServer.ServeHTTP(w, r) - }) - router.Handle("/sites/"+siteID+"/*", http.StripPrefix("/sites/"+siteID+"/", customHandler)) + // Handle both /sites/{siteID}/ and /{siteID}/ patterns + router.Mount(fmt.Sprintf("/sites/%s/", siteConfig.SiteID), http.StripPrefix(fmt.Sprintf("/sites/%s/", siteConfig.SiteID), fileServer)) + + // Optionally serve at root for primary site + if siteConfig.Domain != "" { + log.Printf("🌐 Site %s available at domain: %s", siteConfig.SiteID, siteConfig.Domain) + } } + // Catch-all for serving sites by domain or default + router.NotFound(func(w http.ResponseWriter, r *http.Request) { + // Try to match by domain first + host := strings.Split(r.Host, ":")[0] // Remove port if present + for _, siteConfig := range siteManager.GetAllSites() { + if siteConfig.Domain == host { + fileServer := http.FileServer(http.Dir(siteConfig.Path)) + fileServer.ServeHTTP(w, r) + return + } + } + + // Default 404 + http.NotFound(w, r) + }) + // Start server - addr := fmt.Sprintf(":%d", port) - mode := "production" - if devMode { - mode = "development" - } + addr := fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port) + log.Printf("šŸš€ Server starting on %s", addr) + log.Printf("šŸ“ Content API available at http://%s/api/content/{site_id}", addr) + log.Printf("šŸ” Authentication at http://%s/auth/login", addr) - fmt.Printf("šŸš€ Insertr Content Server starting (%s mode)...\n", mode) - fmt.Printf("šŸ“ Database: %s\n", dbPath) - fmt.Printf("🌐 Server running at: http://localhost%s\n", addr) - fmt.Printf("šŸ’š Health check: http://localhost%s/health\n", addr) - fmt.Printf("šŸ“Š API endpoints:\n") - fmt.Printf(" Content:\n") - fmt.Printf(" GET /api/content?site_id={site}\n") - fmt.Printf(" GET /api/content/{id}?site_id={site}\n") - fmt.Printf(" GET /api/content/bulk?site_id={site}&ids[]={id1}&ids[]={id2}\n") - fmt.Printf(" POST /api/content\n") - fmt.Printf(" PUT /api/content/{id}\n") - fmt.Printf(" GET /api/content/{id}/versions?site_id={site}\n") - fmt.Printf(" POST /api/content/{id}/rollback\n") - fmt.Printf(" Collections:\n") - fmt.Printf(" GET /api/collections?site_id={site}\n") - fmt.Printf(" GET /api/collections/{id}?site_id={site}\n") - fmt.Printf(" GET /api/collections/{id}/items?site_id={site}\n") - fmt.Printf(" POST /api/collections/{id}/items\n") - fmt.Printf(" PUT /api/collections/{id}/items/{item_id}\n") - fmt.Printf(" DELETE /api/collections/{id}/items/{item_id}\n") - fmt.Printf(" PUT /api/collections/{id}/reorder\n") - fmt.Printf("🌐 Static sites:\n") - for siteID, _ := range siteManager.GetAllSites() { - fmt.Printf(" %s: http://localhost%s/sites/%s/\n", siteID, addr, siteID) - } - fmt.Printf("\nšŸ”„ Press Ctrl+C to shutdown gracefully\n\n") - - // Setup graceful shutdown - server := &http.Server{ - Addr: addr, - Handler: router, - } - - // Start server in a goroutine + // Graceful shutdown go func() { - if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed { - log.Fatalf("Server failed to start: %v", err) - } + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM) + <-sigChan + log.Printf("šŸ›‘ Shutting down server...") + os.Exit(0) }() - // Wait for interrupt signal - quit := make(chan os.Signal, 1) - signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) - <-quit - - fmt.Println("\nšŸ›‘ Shutting down server...") - if err := server.Close(); err != nil { - log.Fatalf("Server forced to shutdown: %v", err) + if err := http.ListenAndServe(addr, router); err != nil { + log.Fatalf("Server failed to start: %v", err) } - - fmt.Println("āœ… Server shutdown complete") } diff --git a/go.mod b/go.mod index b722750..980700d 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,6 @@ require ( github.com/go-chi/chi/v5 v5.2.3 github.com/go-chi/cors v1.2.2 github.com/golang-jwt/jwt/v5 v5.3.0 - github.com/google/uuid v1.6.0 github.com/lib/pq v1.10.9 github.com/mattn/go-sqlite3 v1.14.32 github.com/spf13/cobra v1.8.0 diff --git a/go.sum b/go.sum index fc9136c..a913f35 100644 --- a/go.sum +++ b/go.sum @@ -19,10 +19,6 @@ github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9v github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= -github.com/google/uuid v1.4.0 h1:MtMxsa51/r9yyhkyLsVeVt0B+BGQZzpQiTQ4eHZ8bc4= -github.com/google/uuid v1.4.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= -github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= diff --git a/internal/auth/auth.go b/internal/auth/auth.go index 93526ad..5cd02a3 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -13,6 +13,7 @@ import ( "github.com/coreos/go-oidc/v3/oidc" "github.com/golang-jwt/jwt/v5" + "github.com/insertr/insertr/internal/config" "golang.org/x/oauth2" ) @@ -24,31 +25,10 @@ type UserInfo struct { Provider string `json:"iss,omitempty"` } -// AuthConfig holds authentication configuration -type AuthConfig struct { - DevMode bool - Provider string - JWTSecret string - OAuthConfigs map[string]OAuthConfig - OIDC *OIDCConfig -} - -// OAuthConfig holds OAuth provider configuration -type OAuthConfig struct { - ClientID string - ClientSecret string - RedirectURL string - Scopes []string -} - -// OIDCConfig holds OIDC configuration for Authentik -type OIDCConfig struct { - Endpoint string - ClientID string - ClientSecret string - RedirectURL string - Scopes []string -} +// Type aliases for backward compatibility +type AuthConfig = config.AuthConfig +type OAuthConfig = config.OAuthConfig +type OIDCConfig = config.OIDCConfig // AuthService handles authentication operations type AuthService struct { diff --git a/internal/config/config.go b/internal/config/config.go new file mode 100644 index 0000000..8b671f9 --- /dev/null +++ b/internal/config/config.go @@ -0,0 +1,67 @@ +package config + +type Config struct { + Database DatabaseConfig `yaml:"database" mapstructure:"database"` + API APIConfig `yaml:"api" mapstructure:"api"` + CLI CLIConfig `yaml:"cli" mapstructure:"cli"` + Server ServerConfig `yaml:"server" mapstructure:"server"` + Auth AuthConfig `yaml:"auth" mapstructure:"auth"` +} + +type DatabaseConfig struct { + Path string `yaml:"path" mapstructure:"path"` +} + +type APIConfig struct { + URL string `yaml:"url" mapstructure:"url"` + Key string `yaml:"key" mapstructure:"key"` +} + +type CLIConfig struct { + SiteID string `yaml:"site_id" mapstructure:"site_id"` +} + +type ServerConfig struct { + Port int `yaml:"port" mapstructure:"port"` + Host string `yaml:"host" mapstructure:"host"` + Sites []SiteConfig `yaml:"sites" mapstructure:"sites"` +} + +type AuthConfig struct { + DevMode bool `yaml:"dev_mode" mapstructure:"dev_mode"` + Provider string `yaml:"provider" mapstructure:"provider"` + JWTSecret string `yaml:"jwt_secret" mapstructure:"jwt_secret"` + OAuthConfigs map[string]OAuthConfig `yaml:"oauth_configs" mapstructure:"oauth_configs"` + OIDC *OIDCConfig `yaml:"oidc" mapstructure:"oidc"` +} + +type OAuthConfig struct { + ClientID string `yaml:"client_id" mapstructure:"client_id"` + ClientSecret string `yaml:"client_secret" mapstructure:"client_secret"` + RedirectURL string `yaml:"redirect_url" mapstructure:"redirect_url"` + Scopes []string `yaml:"scopes" mapstructure:"scopes"` +} + +type OIDCConfig struct { + Endpoint string `yaml:"endpoint" mapstructure:"endpoint"` + ClientID string `yaml:"client_id" mapstructure:"client_id"` + ClientSecret string `yaml:"client_secret" mapstructure:"client_secret"` + RedirectURL string `yaml:"redirect_url" mapstructure:"redirect_url"` + Scopes []string `yaml:"scopes" mapstructure:"scopes"` +} + +type SiteConfig struct { + SiteID string `yaml:"site_id" mapstructure:"site_id"` + Path string `yaml:"path" mapstructure:"path"` + SourcePath string `yaml:"source_path" mapstructure:"source_path"` + Domain string `yaml:"domain" mapstructure:"domain"` + AutoEnhance bool `yaml:"auto_enhance" mapstructure:"auto_enhance"` + Discovery *DiscoveryConfig `yaml:"discovery" mapstructure:"discovery"` +} + +type DiscoveryConfig struct { + Enabled bool `yaml:"enabled" mapstructure:"enabled"` + Aggressive bool `yaml:"aggressive" mapstructure:"aggressive"` + Containers bool `yaml:"containers" mapstructure:"containers"` + Individual bool `yaml:"individual" mapstructure:"individual"` +} diff --git a/internal/config/loader.go b/internal/config/loader.go new file mode 100644 index 0000000..b117a17 --- /dev/null +++ b/internal/config/loader.go @@ -0,0 +1,106 @@ +package config + +import ( + "fmt" + + "github.com/spf13/viper" +) + +type Loader interface { + Load(configFile string) (*Config, error) + LoadWithDefaults() (*Config, error) + LoadWithFlags(dbPath, apiURL, apiKey, siteID string) (*Config, error) +} + +type viperLoader struct{} + +func NewLoader() Loader { + return &viperLoader{} +} + +func (l *viperLoader) Load(configFile string) (*Config, error) { + v := viper.New() + + if configFile != "" { + v.SetConfigFile(configFile) + } else { + v.AddConfigPath(".") + v.SetConfigName("insertr") + v.SetConfigType("yaml") + } + + v.SetEnvPrefix("INSERTR") + v.AutomaticEnv() + + if err := v.ReadInConfig(); err != nil { + if _, ok := err.(viper.ConfigFileNotFoundError); !ok { + return nil, fmt.Errorf("failed to read config file: %w", err) + } + } + + config := &Config{} + if err := v.Unmarshal(config); err != nil { + return nil, fmt.Errorf("failed to unmarshal config: %w", err) + } + + if err := l.setDefaults(config); err != nil { + return nil, err + } + + return config, validate(config) +} + +func (l *viperLoader) LoadWithDefaults() (*Config, error) { + return l.Load("") +} + +func (l *viperLoader) LoadWithFlags(dbPath, apiURL, apiKey, siteID string) (*Config, error) { + config, err := l.LoadWithDefaults() + if err != nil { + return nil, err + } + + if dbPath != "" { + config.Database.Path = dbPath + } + if apiURL != "" { + config.API.URL = apiURL + } + if apiKey != "" { + config.API.Key = apiKey + } + if siteID != "" { + config.CLI.SiteID = siteID + } + + return config, validate(config) +} + +func (l *viperLoader) setDefaults(config *Config) error { + if config.Database.Path == "" { + config.Database.Path = "./insertr.db" + } + + if config.CLI.SiteID == "" { + config.CLI.SiteID = "demo" + } + + if config.Server.Port == 0 { + config.Server.Port = 8080 + } + + if config.Server.Host == "" { + config.Server.Host = "localhost" + } + + if config.Auth.Provider == "" { + config.Auth.Provider = "mock" + } + + if config.Auth.JWTSecret == "" { + config.Auth.JWTSecret = "dev-secret-change-in-production" + config.Auth.DevMode = true + } + + return nil +} diff --git a/internal/config/validation.go b/internal/config/validation.go new file mode 100644 index 0000000..c10828a --- /dev/null +++ b/internal/config/validation.go @@ -0,0 +1,190 @@ +package config + +import ( + "fmt" + "net/url" + "os" + "path/filepath" + "strings" +) + +func validate(config *Config) error { + if err := validateDatabase(config.Database); err != nil { + return fmt.Errorf("database config: %w", err) + } + + if err := validateAPI(config.API); err != nil { + return fmt.Errorf("api config: %w", err) + } + + if err := validateCLI(config.CLI); err != nil { + return fmt.Errorf("cli config: %w", err) + } + + if err := validateServer(config.Server); err != nil { + return fmt.Errorf("server config: %w", err) + } + + if err := validateAuth(config.Auth); err != nil { + return fmt.Errorf("auth config: %w", err) + } + + return nil +} + +func validateDatabase(config DatabaseConfig) error { + if config.Path == "" { + return fmt.Errorf("path is required") + } + + if strings.Contains(config.Path, "postgres://") || strings.Contains(config.Path, "postgresql://") { + if _, err := url.Parse(config.Path); err != nil { + return fmt.Errorf("invalid PostgreSQL connection string: %w", err) + } + return nil + } + + dir := filepath.Dir(config.Path) + if _, err := os.Stat(dir); os.IsNotExist(err) { + return fmt.Errorf("database directory does not exist: %s", dir) + } + + return nil +} + +func validateAPI(config APIConfig) error { + if config.URL != "" { + if _, err := url.Parse(config.URL); err != nil { + return fmt.Errorf("invalid URL: %w", err) + } + } + + if config.URL != "" && config.Key == "" { + return fmt.Errorf("api key is required when api url is specified") + } + + return nil +} + +func validateCLI(config CLIConfig) error { + if config.SiteID == "" { + return fmt.Errorf("site_id is required") + } + + if strings.TrimSpace(config.SiteID) != config.SiteID { + return fmt.Errorf("site_id cannot have leading or trailing whitespace") + } + + return nil +} + +func validateServer(config ServerConfig) error { + if config.Port < 1 || config.Port > 65535 { + return fmt.Errorf("port must be between 1 and 65535, got %d", config.Port) + } + + if config.Host == "" { + return fmt.Errorf("host is required") + } + + siteIDs := make(map[string]bool) + for i, site := range config.Sites { + if err := validateSite(site); err != nil { + return fmt.Errorf("site %d: %w", i, err) + } + + if siteIDs[site.SiteID] { + return fmt.Errorf("duplicate site_id: %s", site.SiteID) + } + siteIDs[site.SiteID] = true + } + + return nil +} + +func validateSite(config SiteConfig) error { + if config.SiteID == "" { + return fmt.Errorf("site_id is required") + } + + if config.Path == "" { + return fmt.Errorf("path is required") + } + + if config.SourcePath != "" { + if _, err := os.Stat(config.SourcePath); os.IsNotExist(err) { + return fmt.Errorf("source_path does not exist: %s", config.SourcePath) + } + } + + if config.Discovery != nil { + if err := validateDiscovery(*config.Discovery); err != nil { + return fmt.Errorf("discovery config: %w", err) + } + } + + return nil +} + +func validateDiscovery(config DiscoveryConfig) error { + return nil +} + +func validateAuth(config AuthConfig) error { + if config.Provider == "" { + return fmt.Errorf("provider is required") + } + + validProviders := []string{"mock", "authentik"} + valid := false + for _, provider := range validProviders { + if config.Provider == provider { + valid = true + break + } + } + if !valid { + return fmt.Errorf("invalid provider %q, must be one of: %s", config.Provider, strings.Join(validProviders, ", ")) + } + + if config.Provider == "authentik" { + if config.OIDC == nil { + return fmt.Errorf("oidc config is required for authentik provider") + } + if err := validateOIDC(*config.OIDC); err != nil { + return fmt.Errorf("oidc config: %w", err) + } + } + + if config.JWTSecret == "" { + return fmt.Errorf("jwt_secret is required") + } + + return nil +} + +func validateOIDC(config OIDCConfig) error { + if config.Endpoint == "" { + return fmt.Errorf("endpoint is required") + } + + if _, err := url.Parse(config.Endpoint); err != nil { + return fmt.Errorf("invalid endpoint URL: %w", err) + } + + if config.ClientID == "" { + return fmt.Errorf("client_id is required") + } + + if config.ClientSecret == "" { + return fmt.Errorf("client_secret is required") + } + + if config.RedirectURL != "" { + if _, err := url.Parse(config.RedirectURL); err != nil { + return fmt.Errorf("invalid redirect_url: %w", err) + } + } + + return nil +} diff --git a/internal/content/enhancer.go b/internal/content/enhancer.go index 6460613..dff3be8 100644 --- a/internal/content/enhancer.go +++ b/internal/content/enhancer.go @@ -7,6 +7,7 @@ import ( "path/filepath" "strings" + "github.com/insertr/insertr/internal/config" "github.com/insertr/insertr/internal/engine" ) @@ -17,13 +18,8 @@ type EnhancementConfig struct { GenerateIDs bool } -// DiscoveryConfig configures element discovery -type DiscoveryConfig struct { - Enabled bool - Aggressive bool - Containers bool - Individual bool -} +// Type alias for backward compatibility +type DiscoveryConfig = config.DiscoveryConfig // Enhancer combines discovery, ID generation, and content injection in unified pipeline type Enhancer struct { diff --git a/internal/content/site_manager.go b/internal/content/site_manager.go index 72002b2..ec222de 100644 --- a/internal/content/site_manager.go +++ b/internal/content/site_manager.go @@ -8,18 +8,12 @@ import ( "strings" "sync" + "github.com/insertr/insertr/internal/config" "github.com/insertr/insertr/internal/engine" ) -// SiteConfig represents configuration for a registered site -type SiteConfig struct { - SiteID string `yaml:"site_id"` - Path string `yaml:"path"` // Served path (enhanced output) - SourcePath string `yaml:"source_path"` // Source path (for enhancement) - Domain string `yaml:"domain,omitempty"` - AutoEnhance bool `yaml:"auto_enhance"` - Discovery *DiscoveryConfig `yaml:"discovery,omitempty"` // Override discovery settings -} +// Type alias for backward compatibility +type SiteConfig = config.SiteConfig // SiteManager handles registration and enhancement of static sites type SiteManager struct {