- Create internal/config package with unified config structs and validation
- Abstract viper dependency behind config.Loader interface for better testability
- Replace manual config parsing and type assertions with type-safe loading
- Consolidate AuthConfig, SiteConfig, and DiscoveryConfig into single package
- Add comprehensive validation with clear error messages
- Remove ~200 lines of duplicate config handling code
- Maintain backward compatibility with existing config files
191 lines
4.1 KiB
Go
191 lines
4.1 KiB
Go
package config
|
|
|
|
import (
|
|
"fmt"
|
|
"net/url"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
)
|
|
|
|
func validate(config *Config) error {
|
|
if err := validateDatabase(config.Database); err != nil {
|
|
return fmt.Errorf("database config: %w", err)
|
|
}
|
|
|
|
if err := validateAPI(config.API); err != nil {
|
|
return fmt.Errorf("api config: %w", err)
|
|
}
|
|
|
|
if err := validateCLI(config.CLI); err != nil {
|
|
return fmt.Errorf("cli config: %w", err)
|
|
}
|
|
|
|
if err := validateServer(config.Server); err != nil {
|
|
return fmt.Errorf("server config: %w", err)
|
|
}
|
|
|
|
if err := validateAuth(config.Auth); err != nil {
|
|
return fmt.Errorf("auth config: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func validateDatabase(config DatabaseConfig) error {
|
|
if config.Path == "" {
|
|
return fmt.Errorf("path is required")
|
|
}
|
|
|
|
if strings.Contains(config.Path, "postgres://") || strings.Contains(config.Path, "postgresql://") {
|
|
if _, err := url.Parse(config.Path); err != nil {
|
|
return fmt.Errorf("invalid PostgreSQL connection string: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
dir := filepath.Dir(config.Path)
|
|
if _, err := os.Stat(dir); os.IsNotExist(err) {
|
|
return fmt.Errorf("database directory does not exist: %s", dir)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func validateAPI(config APIConfig) error {
|
|
if config.URL != "" {
|
|
if _, err := url.Parse(config.URL); err != nil {
|
|
return fmt.Errorf("invalid URL: %w", err)
|
|
}
|
|
}
|
|
|
|
if config.URL != "" && config.Key == "" {
|
|
return fmt.Errorf("api key is required when api url is specified")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func validateCLI(config CLIConfig) error {
|
|
if config.SiteID == "" {
|
|
return fmt.Errorf("site_id is required")
|
|
}
|
|
|
|
if strings.TrimSpace(config.SiteID) != config.SiteID {
|
|
return fmt.Errorf("site_id cannot have leading or trailing whitespace")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func validateServer(config ServerConfig) error {
|
|
if config.Port < 1 || config.Port > 65535 {
|
|
return fmt.Errorf("port must be between 1 and 65535, got %d", config.Port)
|
|
}
|
|
|
|
if config.Host == "" {
|
|
return fmt.Errorf("host is required")
|
|
}
|
|
|
|
siteIDs := make(map[string]bool)
|
|
for i, site := range config.Sites {
|
|
if err := validateSite(site); err != nil {
|
|
return fmt.Errorf("site %d: %w", i, err)
|
|
}
|
|
|
|
if siteIDs[site.SiteID] {
|
|
return fmt.Errorf("duplicate site_id: %s", site.SiteID)
|
|
}
|
|
siteIDs[site.SiteID] = true
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func validateSite(config SiteConfig) error {
|
|
if config.SiteID == "" {
|
|
return fmt.Errorf("site_id is required")
|
|
}
|
|
|
|
if config.Path == "" {
|
|
return fmt.Errorf("path is required")
|
|
}
|
|
|
|
if config.SourcePath != "" {
|
|
if _, err := os.Stat(config.SourcePath); os.IsNotExist(err) {
|
|
return fmt.Errorf("source_path does not exist: %s", config.SourcePath)
|
|
}
|
|
}
|
|
|
|
if config.Discovery != nil {
|
|
if err := validateDiscovery(*config.Discovery); err != nil {
|
|
return fmt.Errorf("discovery config: %w", err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func validateDiscovery(config DiscoveryConfig) error {
|
|
return nil
|
|
}
|
|
|
|
func validateAuth(config AuthConfig) error {
|
|
if config.Provider == "" {
|
|
return fmt.Errorf("provider is required")
|
|
}
|
|
|
|
validProviders := []string{"mock", "authentik"}
|
|
valid := false
|
|
for _, provider := range validProviders {
|
|
if config.Provider == provider {
|
|
valid = true
|
|
break
|
|
}
|
|
}
|
|
if !valid {
|
|
return fmt.Errorf("invalid provider %q, must be one of: %s", config.Provider, strings.Join(validProviders, ", "))
|
|
}
|
|
|
|
if config.Provider == "authentik" {
|
|
if config.OIDC == nil {
|
|
return fmt.Errorf("oidc config is required for authentik provider")
|
|
}
|
|
if err := validateOIDC(*config.OIDC); err != nil {
|
|
return fmt.Errorf("oidc config: %w", err)
|
|
}
|
|
}
|
|
|
|
if config.JWTSecret == "" {
|
|
return fmt.Errorf("jwt_secret is required")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func validateOIDC(config OIDCConfig) error {
|
|
if config.Endpoint == "" {
|
|
return fmt.Errorf("endpoint is required")
|
|
}
|
|
|
|
if _, err := url.Parse(config.Endpoint); err != nil {
|
|
return fmt.Errorf("invalid endpoint URL: %w", err)
|
|
}
|
|
|
|
if config.ClientID == "" {
|
|
return fmt.Errorf("client_id is required")
|
|
}
|
|
|
|
if config.ClientSecret == "" {
|
|
return fmt.Errorf("client_secret is required")
|
|
}
|
|
|
|
if config.RedirectURL != "" {
|
|
if _, err := url.Parse(config.RedirectURL); err != nil {
|
|
return fmt.Errorf("invalid redirect_url: %w", err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|