Files

75 lines
1.7 KiB
Go

package middleware
import (
"context"
"net/http"
"maintainarr/internal/models"
"maintainarr/internal/services"
)
// contextKey is the unexported type used for values this package stores in a
// request context. Because the type is private to the package, its keys
// cannot collide with keys set by other packages, even if the underlying
// string is the same.
type contextKey string

const (
	// userKey is the context key under which RequireAuth stores the
	// authenticated user; CurrentUser reads it back.
	userKey contextKey = "user"
)
// RequireAuth returns middleware that admits a request only when it carries a
// session whose "user_id" value is an int64 that repo.GetUserByID resolves to
// a user. The resolved user is attached to the request context under userKey
// (see CurrentUser) before the next handler runs.
//
// Every failure (sessions.Get error, missing or non-int64 "user_id", or a
// GetUserByID error) redirects to /login with 303 See Other.
//
// NOTE(review): all GetUserByID errors are treated as "not logged in",
// including what may be transient storage failures; confirm against the
// Repository's error contract whether those should surface as 5xx instead.
func RequireAuth(sessions *services.SessionService, repo *services.Repository) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			toLogin := func() { http.Redirect(w, r, "/login", http.StatusSeeOther) }

			sess, err := sessions.Get(r)
			if err != nil {
				toLogin()
				return
			}

			id, isInt64 := sess.Values["user_id"].(int64)
			if !isInt64 {
				toLogin()
				return
			}

			u, err := repo.GetUserByID(r.Context(), id)
			if err != nil {
				toLogin()
				return
			}

			next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), userKey, u)))
		})
	}
}
// RequireRole returns middleware that admits the current user (as set by
// RequireAuth) only if their role satisfies role:
//
//   - models.RoleAdmin users are always admitted.
//   - role == models.RoleEditor admits editors.
//   - role == models.RoleViewer admits viewers and editors.
//
// Any other combination is rejected with 403 Forbidden. A request with no
// user in its context is redirected to /login with 303 See Other.
func RequireRole(role models.Role) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			u := CurrentUser(r)
			if u == nil {
				http.Redirect(w, r, "/login", http.StatusSeeOther)
				return
			}

			allowed := false
			switch {
			case u.Role == models.RoleAdmin:
				allowed = true
			case role == models.RoleEditor:
				allowed = u.Role == models.RoleEditor
			case role == models.RoleViewer:
				allowed = u.Role == models.RoleViewer || u.Role == models.RoleEditor
			}

			if !allowed {
				http.Error(w, "forbidden", http.StatusForbidden)
				return
			}
			next.ServeHTTP(w, r)
		})
	}
}
// CurrentUser returns the user stored in r's context under userKey (as done
// by RequireAuth). It returns nil when no value is stored or when the stored
// value is not a *models.User.
func CurrentUser(r *http.Request) *models.User {
	if u, ok := r.Context().Value(userKey).(*models.User); ok {
		return u
	}
	return nil
}