package handlers

import (
	"context"
	"crypto/rand"
	"encoding/base64"
	"encoding/json"
	"net/http"
	"time"

	"git.jnss.me/joakim/opal/internal/auth"
	"git.jnss.me/joakim/opal/internal/engine"
)

// maxTokenRequestBytes caps the JSON body size accepted by RefreshToken and
// Logout. Those bodies carry a single token, so a larger body is malformed or
// abusive and must not be read into memory unbounded.
const maxTokenRequestBytes = 1 << 20

// oauthExchangeTimeout bounds the outbound calls to the OAuth provider
// (code exchange and user-info lookup) made while handling a callback.
const oauthExchangeTimeout = 10 * time.Second

var authConfig = auth.LoadConfig()

// oauthClient is only initialised when OAuth is enabled in authConfig; it is
// nil otherwise, so every handler that uses it must check first.
var oauthClient *auth.OAuthClient

func init() {
	if authConfig.OAuthEnabled {
		oauthClient = auth.NewOAuthClient(authConfig)
	}
}

// GetLoginURL responds with the OAuth authorization URL and the random state
// value embedded in it, or 501 when OAuth is not enabled.
//
// NOTE(review): the state is neither stored nor checked server-side
// (OAuthCallback ignores it), so CSRF protection relies on the client
// comparing it against the state echoed to the redirect URI — confirm.
func GetLoginURL(w http.ResponseWriter, r *http.Request) {
	if !authConfig.OAuthEnabled || oauthClient == nil {
		errorResponse(w, http.StatusNotImplemented, "OAuth not enabled")
		return
	}

	state := generateState()
	authURL := oauthClient.GetAuthURL(state)

	jsonResponse(w, http.StatusOK, map[string]string{
		"url":   authURL,
		"state": state,
	})
}

// OAuthCallback completes the OAuth login: it exchanges the "code" query
// parameter for a provider token, fetches the provider's user info, finds or
// creates the local user, and responds with a new JWT access token plus a
// stored refresh token.
//
// Responds 501 when OAuth is disabled, 400 when "code" is missing, and 500
// when any step of the exchange or token issuance fails.
func OAuthCallback(w http.ResponseWriter, r *http.Request) {
	// Without this guard a callback arriving while OAuth is disabled would
	// dereference the nil oauthClient and panic.
	if !authConfig.OAuthEnabled || oauthClient == nil {
		errorResponse(w, http.StatusNotImplemented, "OAuth not enabled")
		return
	}

	code := r.URL.Query().Get("code")
	if code == "" {
		errorResponse(w, http.StatusBadRequest, "missing code parameter")
		return
	}

	// Derive from the request context so provider calls are abandoned if the
	// client goes away, while still bounding them with a timeout.
	ctx, cancel := context.WithTimeout(r.Context(), oauthExchangeTimeout)
	defer cancel()

	oauthToken, err := oauthClient.ExchangeCode(ctx, code)
	if err != nil {
		errorResponse(w, http.StatusInternalServerError, "failed to exchange code: "+err.Error())
		return
	}

	userInfo, err := oauthClient.GetUserInfo(ctx, oauthToken.AccessToken)
	if err != nil {
		errorResponse(w, http.StatusInternalServerError, "failed to get user info: "+err.Error())
		return
	}

	user, err := engine.FindOrCreateOAuthUser(userInfo.Sub, userInfo.Username, userInfo.Email)
	if err != nil {
		errorResponse(w, http.StatusInternalServerError, "failed to create user: "+err.Error())
		return
	}

	accessToken, expiresAt, err := auth.GenerateJWT(user.ID, user.Username, user.Email, authConfig)
	if err != nil {
		errorResponse(w, http.StatusInternalServerError, "failed to generate token: "+err.Error())
		return
	}

	refreshToken, err := auth.GenerateRefreshToken()
	if err != nil {
		errorResponse(w, http.StatusInternalServerError, "failed to generate refresh token: "+err.Error())
		return
	}

	if err := auth.StoreRefreshToken(user.ID, refreshToken, authConfig.RefreshTokenExpiry); err != nil {
		errorResponse(w, http.StatusInternalServerError, "failed to store refresh token: "+err.Error())
		return
	}

	jsonResponse(w, http.StatusOK, map[string]interface{}{
		"access_token":  accessToken,
		"refresh_token": refreshToken,
		"expires_at":    expiresAt,
		"token_type":    "Bearer",
		"user": map[string]interface{}{
			"id":       user.ID,
			"username": user.Username,
			"email":    user.Email,
		},
	})
}

// RefreshToken issues a new JWT access token in exchange for a valid refresh
// token supplied as {"refresh_token": "..."} in the request body.
//
// Responds 400 for a malformed, oversized or empty request, 401 when the
// refresh token does not validate, 404 when its user cannot be loaded, and
// 500 when the access token cannot be generated. The refresh token itself is
// not rotated.
func RefreshToken(w http.ResponseWriter, r *http.Request) {
	token, ok := readRefreshToken(w, r)
	if !ok {
		return
	}

	userID, err := auth.ValidateRefreshToken(token)
	if err != nil {
		errorResponse(w, http.StatusUnauthorized, "invalid refresh token: "+err.Error())
		return
	}

	user, err := engine.GetUser(userID)
	if err != nil {
		errorResponse(w, http.StatusNotFound, "user not found")
		return
	}

	accessToken, expiresAt, err := auth.GenerateJWT(user.ID, user.Username, user.Email, authConfig)
	if err != nil {
		errorResponse(w, http.StatusInternalServerError, "failed to generate token")
		return
	}

	jsonResponse(w, http.StatusOK, map[string]interface{}{
		"access_token": accessToken,
		"expires_at":   expiresAt,
		"token_type":   "Bearer",
	})
}

// Logout revokes the refresh token supplied as {"refresh_token": "..."} in
// the request body.
//
// Responds 400 for a malformed, oversized or empty request and 500 when
// revocation fails.
func Logout(w http.ResponseWriter, r *http.Request) {
	token, ok := readRefreshToken(w, r)
	if !ok {
		return
	}

	if err := auth.RevokeRefreshToken(token); err != nil {
		errorResponse(w, http.StatusInternalServerError, "failed to revoke token")
		return
	}

	jsonResponse(w, http.StatusOK, map[string]string{"message": "logged out"})
}

// readRefreshToken decodes a {"refresh_token": "..."} JSON body, reading at
// most maxTokenRequestBytes. When the body is malformed, too large, or the
// token is empty, it writes a 400 "invalid request" response and returns
// ok == false; the caller must then return without writing anything else.
func readRefreshToken(w http.ResponseWriter, r *http.Request) (token string, ok bool) {
	var req struct {
		RefreshToken string `json:"refresh_token"`
	}
	body := http.MaxBytesReader(w, r.Body, maxTokenRequestBytes)
	if err := json.NewDecoder(body).Decode(&req); err != nil || req.RefreshToken == "" {
		errorResponse(w, http.StatusBadRequest, "invalid request")
		return "", false
	}
	return req.RefreshToken, true
}

// generateState returns 16 bytes from crypto/rand encoded with URL-safe
// base64, for use as an OAuth state value.
//
// It panics if the system CSPRNG fails. There is no safe fallback: a failed
// read can leave the buffer partly or wholly zero, and issuing that as a
// state value would make it guessable.
func generateState() string {
	b := make([]byte, 16)
	if _, err := rand.Read(b); err != nil {
		panic("handlers: crypto/rand failed: " + err.Error())
	}
	return base64.URLEncoding.EncodeToString(b)
}