package middleware

import (
	"crypto/hmac"
	"crypto/rand"
	"crypto/sha256"
	"crypto/subtle"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/gofiber/fiber/v2"
)

const csrfTokenTTL = 1 * time.Hour

var (
	csrfSecretMu sync.RWMutex
	csrfSecret   []byte

	errCSRFTokenMissing       = errors.New("csrf token missing")
	errCSRFTokenMalformed     = errors.New("csrf token malformed")
	errCSRFTokenInvalid       = errors.New("csrf token invalid")
	errCSRFTokenExpired       = errors.New("csrf token expired")
	errCSRFTokenOwnerMismatch = errors.New("csrf token owner mismatch")
)

type csrfPayload struct {
	OwnerKey  string `json:"o"`
	ExpiresAt int64  `json:"e"`
	Nonce     string `json:"n"`
}

// SetCSRFSecret sets a shared secret so tokens stay valid across multiple instances.
func SetCSRFSecret(secret string) {
	trimmed := strings.TrimSpace(secret)
	if trimmed == "" {
		return
	}

	csrfSecretMu.Lock()
	csrfSecret = []byte(trimmed)
	csrfSecretMu.Unlock()
}

func getCSRFSecret() ([]byte, error) {
	csrfSecretMu.RLock()
	defer csrfSecretMu.RUnlock()

	if len(csrfSecret) == 0 {
		return nil, errors.New("csrf secret is not configured")
	}

	secretCopy := make([]byte, len(csrfSecret))
	copy(secretCopy, csrfSecret)
	return secretCopy, nil
}

// generateToken creates a cryptographically secure random token
func generateToken() (string, error) {
	b := make([]byte, 32)
	if _, err := rand.Read(b); err != nil {
		return "", err
	}
	return base64.RawURLEncoding.EncodeToString(b), nil
}

// CSRFMiddleware provides CSRF protection
func CSRFMiddleware() fiber.Handler {
	return func(c *fiber.Ctx) error {
		ownerKey := getCSRFOwnerKey(c)

		if isSafeMethod(c.Method()) {
			token := c.Cookies("csrf_token")
			if token == "" || validateCSRFToken(token, ownerKey) != nil {
				if err := issueAndAttachCSRFToken(c, ownerKey); err != nil {
					return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
						"success": false,
						"message": "Failed to generate CSRF token",
					})
				}
			} else {
				c.Set("X-CSRF-Token", token)
			}

			return c.Next()
		}

		headerToken := c.Get("X-CSRF-Token")
		if headerToken == "" {
			headerToken = c.FormValue("csrf_token")
		}
		cookieToken := c.Cookies("csrf_token")

		if headerToken == "" || cookieToken == "" {
			_ = issueAndAttachCSRFToken(c, ownerKey)
			return c.Status(fiber.StatusForbidden).JSON(fiber.Map{
				"success": false,
				"message": "CSRF token missing",
			})
		}

		if subtle.ConstantTimeCompare([]byte(headerToken), []byte(cookieToken)) != 1 {
			_ = issueAndAttachCSRFToken(c, ownerKey)
			return c.Status(fiber.StatusForbidden).JSON(fiber.Map{
				"success": false,
				"message": "CSRF token mismatch",
			})
		}

		if err := validateCSRFToken(headerToken, ownerKey); err != nil {
			_ = issueAndAttachCSRFToken(c, ownerKey)
			return csrfValidationError(c, err)
		}

		// Rotate token for every state-changing request.
		if err := issueAndAttachCSRFToken(c, ownerKey); err != nil {
			return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
				"success": false,
				"message": "Failed to generate CSRF token",
			})
		}

		return c.Next()
	}
}

func isSafeMethod(method string) bool {
	switch method {
	case fiber.MethodGet, fiber.MethodHead, fiber.MethodOptions:
		return true
	default:
		return false
	}
}

func getCSRFOwnerKey(c *fiber.Ctx) string {
	if userID := extractUserID(c); userID > 0 {
		return fmt.Sprintf("user:%d", userID)
	}

	return fmt.Sprintf("anon:%s:%s", c.IP(), c.Get("User-Agent"))
}

func extractUserID(c *fiber.Ctx) uint {
	switch v := c.Locals("user_id").(type) {
	case uint:
		return v
	case int:
		if v > 0 {
			return uint(v)
		}
	case int64:
		if v > 0 {
			return uint(v)
		}
	case float64:
		if v > 0 {
			return uint(v)
		}
	case string:
		parsed, err := strconv.ParseUint(v, 10, 64)
		if err == nil && parsed > 0 {
			return uint(parsed)
		}
	}

	return 0
}

func issueCSRFToken(ownerKey string) (string, error) {
	secret, err := getCSRFSecret()
	if err != nil {
		return "", err
	}

	nonce, err := generateToken()
	if err != nil {
		return "", err
	}

	payload := csrfPayload{
		OwnerKey:  ownerKey,
		ExpiresAt: time.Now().Add(csrfTokenTTL).Unix(),
		Nonce:     nonce,
	}
	payloadJSON, err := json.Marshal(payload)
	if err != nil {
		return "", err
	}

	payloadPart := base64.RawURLEncoding.EncodeToString(payloadJSON)

	mac := hmac.New(sha256.New, secret)
	if _, err := mac.Write([]byte(payloadPart)); err != nil {
		return "", err
	}
	sigPart := base64.RawURLEncoding.EncodeToString(mac.Sum(nil))

	return payloadPart + "." + sigPart, nil
}

func issueAndAttachCSRFToken(c *fiber.Ctx, ownerKey string) error {
	token, err := issueCSRFToken(ownerKey)
	if err != nil {
		return err
	}

	setCSRFCookie(c, token)
	c.Set("X-CSRF-Token", token)
	return nil
}

func validateCSRFToken(token, ownerKey string) error {
	if token == "" {
		return errCSRFTokenMissing
	}

	secret, err := getCSRFSecret()
	if err != nil {
		return err
	}

	payloadPart, sigPart, ok := strings.Cut(token, ".")
	if !ok || payloadPart == "" || sigPart == "" {
		return errCSRFTokenMalformed
	}

	providedSig, err := base64.RawURLEncoding.DecodeString(sigPart)
	if err != nil {
		return errCSRFTokenMalformed
	}

	mac := hmac.New(sha256.New, secret)
	if _, err := mac.Write([]byte(payloadPart)); err != nil {
		return errCSRFTokenInvalid
	}
	expectedSig := mac.Sum(nil)

	if subtle.ConstantTimeCompare(providedSig, expectedSig) != 1 {
		return errCSRFTokenInvalid
	}

	payloadBytes, err := base64.RawURLEncoding.DecodeString(payloadPart)
	if err != nil {
		return errCSRFTokenMalformed
	}

	var payload csrfPayload
	if err := json.Unmarshal(payloadBytes, &payload); err != nil {
		return errCSRFTokenMalformed
	}

	if payload.OwnerKey == "" || payload.ExpiresAt == 0 || payload.Nonce == "" {
		return errCSRFTokenMalformed
	}

	if time.Now().Unix() > payload.ExpiresAt {
		return errCSRFTokenExpired
	}

	if subtle.ConstantTimeCompare([]byte(payload.OwnerKey), []byte(ownerKey)) != 1 {
		return errCSRFTokenOwnerMismatch
	}

	return nil
}

func csrfValidationError(c *fiber.Ctx, err error) error {
	message := "Invalid CSRF token"

	switch {
	case errors.Is(err, errCSRFTokenMissing):
		message = "CSRF token missing"
	case errors.Is(err, errCSRFTokenExpired):
		message = "CSRF token expired"
	case errors.Is(err, errCSRFTokenOwnerMismatch):
		message = "CSRF token owner mismatch"
	}

	return c.Status(fiber.StatusForbidden).JSON(fiber.Map{
		"success": false,
		"message": message,
	})
}

func setCSRFCookie(c *fiber.Ctx, token string) {
	c.Cookie(&fiber.Cookie{
		Name:     "csrf_token",
		Value:    token,
		Path:     "/",
		Expires:  time.Now().Add(csrfTokenTTL),
		HTTPOnly: true,
		Secure:   c.Secure(),
		SameSite: "Strict",
	})
}

// GetCSRFToken returns a new CSRF token (for API endpoints)
func GetCSRFToken(c *fiber.Ctx) error {
	ownerKey := getCSRFOwnerKey(c)
	token, err := issueCSRFToken(ownerKey)
	if err != nil {
		return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
			"success": false,
			"message": "Failed to generate CSRF token",
		})
	}

	setCSRFCookie(c, token)
	c.Set("X-CSRF-Token", token)

	return c.JSON(fiber.Map{
		"success": true,
		"token":   token,
	})
}
