package utils

import (
	"crypto/rand"
	"crypto/sha256"
	"encoding/base64"
	"encoding/hex"
	"fmt"
	"sync"
	"time"

	"github.com/golang-jwt/jwt/v5"
	"go.uber.org/zap"
)

// JWTClaims is the claim set carried by access tokens. It combines
// application-specific user fields with the standard registered JWT claims.
//
// UserID, Username, Role and BranchID are copied verbatim from the arguments
// passed to GenerateToken and are serialized under the JSON keys shown in the
// struct tags. The embedded jwt.RegisteredClaims supplies the standard claims;
// GenerateToken populates its ExpiresAt, IssuedAt and ID fields.
type JWTClaims struct {
	UserID   uint   `json:"user_id"`
	Username string `json:"username"`
	Role     string `json:"role"`
	BranchID uint   `json:"branch_id"`
	jwt.RegisteredClaims
}

// TokenBlacklist is the storage abstraction used for revoking individual
// tokens. Callers in this file identify a token by the value returned from
// HashToken rather than by the raw token string.
type TokenBlacklist interface {
	// Add records tokenHash as revoked; expiry is the time after which the
	// entry no longer needs to be kept.
	Add(tokenHash string, expiry time.Time) error
	// IsBlacklisted reports whether tokenHash is currently revoked.
	IsBlacklisted(tokenHash string) bool
	// Cleanup discards entries whose expiry has passed.
	Cleanup() error
	// Size returns the number of entries currently stored.
	Size() int
}

// InMemoryTokenBlacklist implements TokenBlacklist using in-memory storage.
//
// tokens maps a token hash to the expiry time of its blacklist entry, and
// mutex guards every access to tokens. Because the struct contains a mutex it
// must not be copied; all methods use a pointer receiver.
type InMemoryTokenBlacklist struct {
	tokens map[string]time.Time
	mutex  sync.RWMutex
}

// NewInMemoryTokenBlacklist returns an empty in-memory blacklist and launches
// the background goroutine (startCleanup) that periodically purges expired
// entries from it.
func NewInMemoryTokenBlacklist() *InMemoryTokenBlacklist {
	b := new(InMemoryTokenBlacklist)
	b.tokens = map[string]time.Time{}

	// Expired entries are removed in the background, independently of lookups.
	go b.startCleanup()

	return b
}

// Add records tokenHash as revoked until expiry. An existing entry for the
// same hash is overwritten. It always returns nil.
func (b *InMemoryTokenBlacklist) Add(tokenHash string, expiry time.Time) error {
	b.mutex.Lock()
	defer b.mutex.Unlock()

	entries := b.tokens
	entries[tokenHash] = expiry

	return nil
}

// IsBlacklisted reports whether tokenHash has a blacklist entry that has not
// yet expired. An entry found to be expired is removed as a side effect and
// reported as not blacklisted.
func (b *InMemoryTokenBlacklist) IsBlacklisted(tokenHash string) bool {
	// The exclusive lock is needed because a lookup may delete an entry.
	b.mutex.Lock()
	defer b.mutex.Unlock()

	expiry, found := b.tokens[tokenHash]
	switch {
	case !found:
		return false
	case time.Now().After(expiry):
		// Stale entry: drop it eagerly instead of waiting for Cleanup.
		delete(b.tokens, tokenHash)
		return false
	default:
		return true
	}
}

// Cleanup deletes every entry whose expiry lies before the moment the sweep
// started. It always returns nil.
func (b *InMemoryTokenBlacklist) Cleanup() error {
	b.mutex.Lock()
	defer b.mutex.Unlock()

	// A single cutoff is used so the whole sweep judges entries consistently.
	cutoff := time.Now()
	for hash, expiresAt := range b.tokens {
		if !cutoff.After(expiresAt) {
			continue
		}
		delete(b.tokens, hash)
	}

	return nil
}

// Size returns how many entries the blacklist currently holds, including any
// expired entries that have not been purged yet.
func (b *InMemoryTokenBlacklist) Size() int {
	b.mutex.RLock()
	count := len(b.tokens)
	b.mutex.RUnlock()

	return count
}

// startCleanup purges expired entries once per hour. It blocks for as long as
// the ticker keeps delivering ticks, so it is meant to run in its own
// goroutine.
func (b *InMemoryTokenBlacklist) startCleanup() {
	ticker := time.NewTicker(time.Hour)
	defer ticker.Stop()

	for range ticker.C {
		// Cleanup's error result is deliberately discarded.
		_ = b.Cleanup()
	}
}

var (
	// blacklistMu guards reads and writes of globalBlacklist.
	blacklistMu sync.RWMutex

	// globalBlacklist is the package-wide blacklist consulted by
	// ValidateToken and written to by RevokeToken. It defaults to an
	// in-memory implementation created at package initialization (which also
	// starts that implementation's cleanup goroutine) and can be replaced
	// with SetTokenBlacklist.
	globalBlacklist TokenBlacklist = NewInMemoryTokenBlacklist()

	// userRevocationsMu guards userRevocations, which maps a user ID to the
	// time at which RevokeAllUserTokens was last called for that user.
	userRevocationsMu sync.RWMutex
	userRevocations   = make(map[uint]time.Time)
)

// getTokenBlacklist returns the blacklist currently installed as the global
// instance, reading it under blacklistMu.
func getTokenBlacklist() TokenBlacklist {
	blacklistMu.RLock()
	current := globalBlacklist
	blacklistMu.RUnlock()

	return current
}

// isUserRevoked reports whether a token issued at issuedAt for userID is
// covered by a user-wide revocation. It returns false when no revocation has
// been recorded for the user.
func isUserRevoked(userID uint, issuedAt time.Time) bool {
	userRevocationsMu.RLock()
	cutoff, revoked := userRevocations[userID]
	userRevocationsMu.RUnlock()

	// Only tokens issued strictly after the cutoff survive a revocation.
	return revoked && !issuedAt.After(cutoff)
}

// SetTokenBlacklist installs blacklist as the global token blacklist used by
// the token helpers in this package. A nil interface value is ignored and the
// current blacklist is kept.
func SetTokenBlacklist(blacklist TokenBlacklist) {
	if blacklist == nil {
		return
	}

	blacklistMu.Lock()
	globalBlacklist = blacklist
	blacklistMu.Unlock()
}

// GetTokenBlacklist returns the blacklist currently installed as the global
// instance. It is the exported counterpart of getTokenBlacklist.
func GetTokenBlacklist() TokenBlacklist {
	current := getTokenBlacklist()
	return current
}

// ResetTokenStateForTests resets global token state so tests remain isolated and deterministic.
func ResetTokenStateForTests() {
	SetTokenBlacklist(NewInMemoryTokenBlacklist())

	userRevocationsMu.Lock()
	defer userRevocationsMu.Unlock()
	userRevocations = make(map[uint]time.Time)
}

// GenerateToken issues a signed access token and a signed refresh token for
// the given user.
//
// The access token carries JWTClaims (user ID, username, role, branch ID) and
// expires after 15 minutes. The refresh token carries only registered claims,
// with the user ID as its subject, and expires after 7 days. Both are signed
// with HS256 using secret, share the same issued-at time, and get their own
// random token ID.
//
// It returns (accessToken, refreshToken, nil) on success. On failure both
// token strings are empty and the error wraps the underlying signing error.
func GenerateToken(userID uint, username string, role string, branchID uint, secret string) (string, string, error) {
	// Take one timestamp so both tokens agree on when they were issued.
	now := time.Now()
	key := []byte(secret)

	// Access token (15 minutes).
	accessClaims := JWTClaims{
		UserID:   userID,
		Username: username,
		Role:     role,
		BranchID: branchID,
		RegisteredClaims: jwt.RegisteredClaims{
			ExpiresAt: jwt.NewNumericDate(now.Add(15 * time.Minute)),
			IssuedAt:  jwt.NewNumericDate(now),
			ID:        generateJTI(),
		},
	}
	at, err := jwt.NewWithClaims(jwt.SigningMethodHS256, accessClaims).SignedString(key)
	if err != nil {
		return "", "", fmt.Errorf("signing access token: %w", err)
	}

	// Refresh token (7 days).
	refreshClaims := jwt.RegisteredClaims{
		ExpiresAt: jwt.NewNumericDate(now.Add(7 * 24 * time.Hour)),
		IssuedAt:  jwt.NewNumericDate(now),
		Subject:   fmt.Sprintf("%d", userID),
		ID:        generateJTI(),
	}
	rt, err := jwt.NewWithClaims(jwt.SigningMethodHS256, refreshClaims).SignedString(key)
	if err != nil {
		// Never hand back a usable access token together with an error.
		return "", "", fmt.Errorf("signing refresh token: %w", err)
	}

	return at, rt, nil
}

// HashToken returns the SHA-256 digest of token as a lowercase hexadecimal
// string (64 characters).
func HashToken(token string) string {
	hasher := sha256.New()
	// hash.Hash.Write is documented to never return an error.
	hasher.Write([]byte(token))

	return hex.EncodeToString(hasher.Sum(nil))
}

// ValidateToken checks tokenString against the blacklist, verifies its HS256
// signature and registered claims using secret, and then applies any
// user-wide revocation recorded for the token's user.
//
// It returns the parsed claims on success. It returns an error when the token
// is blacklisted, fails parsing or verification (the jwt library's error is
// returned unchanged), or was issued at or before the user's revocation time.
// A token without an iat claim is rejected whenever its user has a recorded
// revocation, because its issue time cannot be shown to be later.
//
// NOTE(review): refresh tokens from GenerateToken are signed with the same
// method and secret, so they appear to parse successfully here with zero-valued
// user fields — confirm callers never accept them as access tokens.
func ValidateToken(tokenString string, secret string) (*JWTClaims, error) {
	// Reject individually revoked tokens before doing any signature work.
	if getTokenBlacklist().IsBlacklisted(HashToken(tokenString)) {
		return nil, fmt.Errorf("token has been revoked")
	}

	token, err := jwt.ParseWithClaims(tokenString, &JWTClaims{}, func(_ *jwt.Token) (interface{}, error) {
		return []byte(secret), nil
	}, jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()}))
	if err != nil {
		return nil, err
	}

	claims, ok := token.Claims.(*JWTClaims)
	if !ok || !token.Valid {
		return nil, jwt.ErrSignatureInvalid
	}

	// A missing iat is treated as the zero time so such a token cannot
	// sidestep a user-wide revocation.
	var issuedAt time.Time
	if claims.IssuedAt != nil {
		issuedAt = claims.IssuedAt.Time
	}
	if isUserRevoked(claims.UserID, issuedAt) {
		return nil, fmt.Errorf("token has been revoked")
	}

	return claims, nil
}

// RevokeToken adds a token to the blacklist
func RevokeToken(tokenString string, secret string) error {
	// Parse token to get expiry time
	token, err := jwt.ParseWithClaims(tokenString, &JWTClaims{}, func(token *jwt.Token) (interface{}, error) {
		return []byte(secret), nil
	}, jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()}))

	if err != nil {
		// Even if token is invalid, we should still blacklist it
		// Use a default expiry time
		tokenHash := HashToken(tokenString)
		return getTokenBlacklist().Add(tokenHash, time.Now().Add(24*time.Hour))
	}

	if claims, ok := token.Claims.(*JWTClaims); ok {
		tokenHash := HashToken(tokenString)
		expiry := claims.ExpiresAt.Time
		return getTokenBlacklist().Add(tokenHash, expiry)
	}

	return fmt.Errorf("invalid token claims")
}

// RevokeAllUserTokens records the current time as the revocation point for
// userID. ValidateToken subsequently rejects any of that user's tokens whose
// issued-at time is not strictly after this point. The event is logged when a
// logger is configured. It always returns nil.
func RevokeAllUserTokens(userID uint) error {
	now := time.Now()

	userRevocationsMu.Lock()
	userRevocations[userID] = now
	userRevocationsMu.Unlock()

	// Logging is optional: skip it when no logger has been set up.
	if Log == nil {
		return nil
	}

	Info("All user tokens revoked",
		zap.Uint("user_id", userID),
		zap.Time("revoked_at", now))

	return nil
}

// generateJTI returns a token ID built from 24 random bytes, encoded as
// unpadded URL-safe base64 (32 characters). If the random source reports an
// error it falls back to the current Unix time in nanoseconds as a decimal
// string.
func generateJTI() string {
	var raw [24]byte

	_, err := rand.Read(raw[:])
	if err != nil {
		// Fallback: time-based ID, used only when no randomness is available.
		return fmt.Sprintf("%d", time.Now().UnixNano())
	}

	return base64.RawURLEncoding.EncodeToString(raw[:])
}

// GetBlacklistStats returns a snapshot describing the global token blacklist:
// "size" holds its current entry count and "type" holds the Go type name of
// the installed implementation.
func GetBlacklistStats() map[string]interface{} {
	bl := getTokenBlacklist()

	stats := make(map[string]interface{}, 2)
	stats["size"] = bl.Size()
	stats["type"] = fmt.Sprintf("%T", bl)

	return stats
}
