package middleware

import (
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"

	"system-altrak/internal/config"
	"system-altrak/pkg/utils"

	"github.com/gofiber/fiber/v2"
)

// getCookieValue returns the value of the first cookie named cookieName that
// resp sets via its Set-Cookie headers, or "" when no such cookie is present.
func getCookieValue(resp *http.Response, cookieName string) string {
	cookies := resp.Cookies()
	for i := range cookies {
		if c := cookies[i]; c.Name == cookieName {
			return c.Value
		}
	}
	return ""
}

// TestRoleRequiredMissingRole checks that RoleRequired("admin") answers 401
// when no "role" local has been set on the request context.
func TestRoleRequiredMissingRole(t *testing.T) {
	ok := func(c *fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }

	app := fiber.New()
	app.Get("/secure", RoleRequired("admin"), ok)

	resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/secure", nil))
	if err != nil {
		t.Fatalf("request failed: %v", err)
	}

	if got, want := resp.StatusCode, fiber.StatusUnauthorized; got != want {
		t.Fatalf("expected status %d, got %d", want, got)
	}
}

// TestRoleRequiredAllowedRole checks that RoleRequired("admin") lets the
// request through when an earlier middleware has stored role "admin" in the
// context locals.
func TestRoleRequiredAllowedRole(t *testing.T) {
	withAdminRole := func(c *fiber.Ctx) error {
		c.Locals("role", "admin")
		return c.Next()
	}
	ok := func(c *fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }

	app := fiber.New()
	app.Use(withAdminRole)
	app.Get("/secure", RoleRequired("admin"), ok)

	resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/secure", nil))
	if err != nil {
		t.Fatalf("request failed: %v", err)
	}

	if got, want := resp.StatusCode, fiber.StatusOK; got != want {
		t.Fatalf("expected status %d, got %d", want, got)
	}
}

// TestOptionalJWTMiddlewareRejectsInvalidBearerToken checks that
// OptionalJWTMiddleware answers 401 when the Authorization header carries a
// bearer token that does not validate, instead of letting the request through.
func TestOptionalJWTMiddlewareRejectsInvalidBearerToken(t *testing.T) {
	cfg := &config.Config{JWTSecret: "test-secret-with-minimum-32-characters-value"}

	app := fiber.New()
	app.Get("/csrf-token", OptionalJWTMiddleware(cfg), func(c *fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})

	req := httptest.NewRequest(http.MethodGet, "/csrf-token", nil)
	req.Header.Set("Authorization", "Bearer invalid.token.value")

	resp, err := app.Test(req)
	switch {
	case err != nil:
		t.Fatalf("request failed: %v", err)
	case resp.StatusCode != fiber.StatusUnauthorized:
		t.Fatalf("expected status %d, got %d", fiber.StatusUnauthorized, resp.StatusCode)
	}
}

// TestOptionalJWTMiddlewareAllowsAnonymousRequest checks that
// OptionalJWTMiddleware lets a request with no Authorization header reach the
// handler.
func TestOptionalJWTMiddlewareAllowsAnonymousRequest(t *testing.T) {
	cfg := &config.Config{JWTSecret: "test-secret-with-minimum-32-characters-value"}

	app := fiber.New()
	app.Get("/csrf-token", OptionalJWTMiddleware(cfg), func(c *fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})

	resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/csrf-token", nil))
	switch {
	case err != nil:
		t.Fatalf("request failed: %v", err)
	case resp.StatusCode != fiber.StatusOK:
		t.Fatalf("expected status %d, got %d", fiber.StatusOK, resp.StatusCode)
	}
}

// TestOptionalJWTMiddlewareAnonymousAllowsInvalidBearerToken checks that,
// unlike OptionalJWTMiddleware, OptionalJWTMiddlewareAnonymous still lets the
// request reach the handler when the bearer token does not validate.
func TestOptionalJWTMiddlewareAnonymousAllowsInvalidBearerToken(t *testing.T) {
	cfg := &config.Config{JWTSecret: "test-secret-with-minimum-32-characters-value"}

	app := fiber.New()
	app.Get("/csrf-token", OptionalJWTMiddlewareAnonymous(cfg), func(c *fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})

	req := httptest.NewRequest(http.MethodGet, "/csrf-token", nil)
	req.Header.Set("Authorization", "Bearer invalid.token.value")

	resp, err := app.Test(req)
	switch {
	case err != nil:
		t.Fatalf("request failed: %v", err)
	case resp.StatusCode != fiber.StatusOK:
		t.Fatalf("expected status %d, got %d", fiber.StatusOK, resp.StatusCode)
	}
}

// TestJWTMiddlewareRejectsQueryTokenOnAPIRoute verifies that JWTMiddleware does
// not accept an access token supplied through the ?token= query parameter on a
// regular API route.
//
// A positive control first sends the same token in the Authorization header and
// expects success. Without it, the final 401 could be caused by an unusable
// token rather than by the query-parameter location being refused, and the test
// would pass without proving anything.
func TestJWTMiddlewareRejectsQueryTokenOnAPIRoute(t *testing.T) {
	secret := "test-secret-with-minimum-32-characters-value"
	cfg := &config.Config{JWTSecret: secret}
	accessToken, _, err := utils.GenerateToken(3001, "admin", "admin", 1, secret)
	if err != nil {
		t.Fatalf("failed to generate token: %v", err)
	}

	app := fiber.New()
	app.Use(JWTMiddleware(cfg))
	app.Get("/resource", func(c *fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})

	// Positive control: the token itself must be valid via the header.
	headerReq := httptest.NewRequest(http.MethodGet, "/resource", nil)
	headerReq.Header.Set("Authorization", "Bearer "+accessToken)
	headerResp, err := app.Test(headerReq)
	if err != nil {
		t.Fatalf("header request failed: %v", err)
	}
	if headerResp.StatusCode != fiber.StatusOK {
		t.Fatalf("expected status %d for Authorization header token, got %d", fiber.StatusOK, headerResp.StatusCode)
	}

	// The same valid token in the query string must be refused.
	req := httptest.NewRequest(http.MethodGet, "/resource?token="+accessToken, nil)
	resp, err := app.Test(req)
	if err != nil {
		t.Fatalf("request failed: %v", err)
	}

	if resp.StatusCode != fiber.StatusUnauthorized {
		t.Fatalf("expected status %d, got %d", fiber.StatusUnauthorized, resp.StatusCode)
	}
}

func TestJWTMiddlewareRejectsQueryTokenForWebSocketPath(t *testing.T) {
	secret := "test-secret-with-minimum-32-characters-value"
	cfg := &config.Config{JWTSecret: secret}
	accessToken, _, err := utils.GenerateToken(3002, "admin", "admin", 1, secret)
	if err != nil {
		t.Fatalf("failed to generate token: %v", err)
	}

	app := fiber.New()
	app.Use(JWTMiddleware(cfg))
	app.Get("/ws", func(c *fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})

	req := httptest.NewRequest(http.MethodGet, "/ws?token="+accessToken, nil)
	resp, err := app.Test(req)
	if err != nil {
		t.Fatalf("request failed: %v", err)
	}

	if resp.StatusCode != fiber.StatusUnauthorized {
		t.Fatalf("expected status %d, got %d", fiber.StatusUnauthorized, resp.StatusCode)
	}
}

// TestCreateEditModuleRequiredAllowsAdminForConfiguredModule checks that an
// "admin" role may POST to a route guarded by
// CreateEditModuleRequired("customer-profile").
func TestCreateEditModuleRequiredAllowsAdminForConfiguredModule(t *testing.T) {
	withAdminRole := func(c *fiber.Ctx) error {
		c.Locals("role", "admin")
		return c.Next()
	}
	ok := func(c *fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }

	app := fiber.New()
	app.Use(withAdminRole)
	app.Post("/customer-profile", CreateEditModuleRequired("customer-profile"), ok)

	resp, err := app.Test(httptest.NewRequest(http.MethodPost, "/customer-profile", nil))
	if err != nil {
		t.Fatalf("request failed: %v", err)
	}

	if got, want := resp.StatusCode, fiber.StatusOK; got != want {
		t.Fatalf("expected status %d, got %d", want, got)
	}
}

// TestCreateEditModuleRequiredDeniesAdminForNonConfiguredModule checks that an
// "admin" role is refused with 403 on a route guarded by
// CreateEditModuleRequired("settings").
func TestCreateEditModuleRequiredDeniesAdminForNonConfiguredModule(t *testing.T) {
	withAdminRole := func(c *fiber.Ctx) error {
		c.Locals("role", "admin")
		return c.Next()
	}
	ok := func(c *fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }

	app := fiber.New()
	app.Use(withAdminRole)
	app.Post("/settings", CreateEditModuleRequired("settings"), ok)

	resp, err := app.Test(httptest.NewRequest(http.MethodPost, "/settings", nil))
	if err != nil {
		t.Fatalf("request failed: %v", err)
	}

	if got, want := resp.StatusCode, fiber.StatusForbidden; got != want {
		t.Fatalf("expected status %d, got %d", want, got)
	}
}

// TestCSRFMiddlewareTokenRotation verifies the CSRF token lifecycle for an
// authenticated user:
//  1. a GET issues an X-CSRF-Token header plus a csrf_token cookie;
//  2. a POST presenting that pair succeeds and returns a different token
//     (rotation) together with a new cookie;
//  3. a follow-up POST presenting the freshly rotated pair also succeeds.
//
// Step 3 uses the rotated credentials, not the original ones, so it does not
// exercise replay of a consumed token.
func TestCSRFMiddlewareTokenRotation(t *testing.T) {
	secret := "test-secret-with-minimum-32-characters-value"
	cfg := &config.Config{JWTSecret: secret}
	SetCSRFSecret(secret)
	accessToken, _, err := utils.GenerateToken(1001, "admin", "admin", 1, secret)
	if err != nil {
		t.Fatalf("failed to generate token: %v", err)
	}

	app := fiber.New()
	app.Use(JWTMiddleware(cfg))
	app.Use(CSRFMiddleware())
	app.Get("/resource", func(c *fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) })
	app.Post("/resource", func(c *fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) })

	getReq := httptest.NewRequest(http.MethodGet, "/resource", nil)
	getReq.Header.Set("Authorization", "Bearer "+accessToken)
	getResp, err := app.Test(getReq)
	if err != nil {
		t.Fatalf("GET request failed: %v", err)
	}

	csrfToken := getResp.Header.Get("X-CSRF-Token")
	csrfCookie := getCookieValue(getResp, "csrf_token")
	if csrfToken == "" || csrfCookie == "" {
		t.Fatal("expected CSRF token and cookie to be issued")
	}

	postReq := httptest.NewRequest(http.MethodPost, "/resource", nil)
	postReq.Header.Set("Authorization", "Bearer "+accessToken)
	postReq.Header.Set("X-CSRF-Token", csrfToken)
	postReq.Header.Set("Cookie", "csrf_token="+csrfCookie)
	postResp, err := app.Test(postReq)
	if err != nil {
		t.Fatalf("POST request failed: %v", err)
	}

	if postResp.StatusCode != fiber.StatusOK {
		t.Fatalf("expected status %d, got %d", fiber.StatusOK, postResp.StatusCode)
	}

	rotatedToken := postResp.Header.Get("X-CSRF-Token")
	rotatedCookie := getCookieValue(postResp, "csrf_token")
	if rotatedToken == "" || rotatedCookie == "" {
		t.Fatal("expected rotated CSRF token and cookie after state-changing request")
	}
	if rotatedToken == csrfToken {
		t.Fatal("expected CSRF token to rotate after state-changing request")
	}

	// Follow-up request with the newly rotated pair: rotation must hand out
	// credentials that are immediately usable.
	followUpReq := httptest.NewRequest(http.MethodPost, "/resource", nil)
	followUpReq.Header.Set("Authorization", "Bearer "+accessToken)
	followUpReq.Header.Set("X-CSRF-Token", rotatedToken)
	followUpReq.Header.Set("Cookie", "csrf_token="+rotatedCookie)
	followUpResp, err := app.Test(followUpReq)
	if err != nil {
		t.Fatalf("follow-up POST request with rotated token failed: %v", err)
	}

	if followUpResp.StatusCode != fiber.StatusOK {
		t.Fatalf("expected follow-up status %d with rotated token, got %d", fiber.StatusOK, followUpResp.StatusCode)
	}
}

// TestCSRFMiddlewareOwnerBinding verifies that a CSRF token/cookie pair issued
// to one authenticated user (ID 2001) is rejected with 403 when presented by a
// different authenticated user (ID 2002).
func TestCSRFMiddlewareOwnerBinding(t *testing.T) {
	secret := "test-secret-with-minimum-32-characters-value"
	cfg := &config.Config{JWTSecret: secret}
	SetCSRFSecret(secret)
	accessTokenUser1, _, err := utils.GenerateToken(2001, "admin", "admin", 1, secret)
	if err != nil {
		t.Fatalf("failed to generate user1 token: %v", err)
	}
	accessTokenUser2, _, err := utils.GenerateToken(2002, "admin", "admin", 1, secret)
	if err != nil {
		t.Fatalf("failed to generate user2 token: %v", err)
	}

	app := fiber.New()
	app.Use(JWTMiddleware(cfg))
	app.Use(CSRFMiddleware())
	app.Get("/resource", func(c *fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) })
	app.Post("/resource", func(c *fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) })

	getReq := httptest.NewRequest(http.MethodGet, "/resource", nil)
	getReq.Header.Set("Authorization", "Bearer "+accessTokenUser1)
	getResp, err := app.Test(getReq)
	if err != nil {
		t.Fatalf("GET request failed: %v", err)
	}

	csrfToken := getResp.Header.Get("X-CSRF-Token")
	csrfCookie := getCookieValue(getResp, "csrf_token")
	// Without an issued pair, the POST below would be rejected for a missing
	// token regardless of owner binding, so the 403 check would pass vacuously.
	if csrfToken == "" || csrfCookie == "" {
		t.Fatal("expected CSRF token and cookie to be issued to user1")
	}

	postReq := httptest.NewRequest(http.MethodPost, "/resource", nil)
	postReq.Header.Set("Authorization", "Bearer "+accessTokenUser2)
	postReq.Header.Set("X-CSRF-Token", csrfToken)
	postReq.Header.Set("Cookie", "csrf_token="+csrfCookie)
	postResp, err := app.Test(postReq)
	if err != nil {
		t.Fatalf("POST request failed: %v", err)
	}

	if postResp.StatusCode != fiber.StatusForbidden {
		t.Fatalf("expected status %d, got %d", fiber.StatusForbidden, postResp.StatusCode)
	}
}

// TestSecurityHeadersMiddlewareSetsCSPHeaders checks the headers emitted by
// SecurityHeadersMiddleware on a plain-HTTP request:
//   - an enforced Content-Security-Policy with object-src 'none' and
//     'unsafe-inline', and without the Tailwind CDN;
//   - a Content-Security-Policy-Report-Only policy with the same constraints
//     plus a report-uri pointing at /api/security/csp-report;
//   - no Strict-Transport-Security header.
func TestSecurityHeadersMiddlewareSetsCSPHeaders(t *testing.T) {
	app := fiber.New()
	app.Use(SecurityHeadersMiddleware())
	app.Get("/", func(c *fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) })

	resp, err := app.Test(httptest.NewRequest(http.MethodGet, "http://example.test/", nil))
	if err != nil {
		t.Fatalf("request failed: %v", err)
	}

	enforced := resp.Header.Get("Content-Security-Policy")
	switch {
	case enforced == "":
		t.Fatal("expected Content-Security-Policy header")
	case !strings.Contains(enforced, "object-src 'none'"):
		t.Fatalf("expected enforced CSP to contain object-src none, got: %s", enforced)
	case !strings.Contains(enforced, "'unsafe-inline'"):
		// Next.js static export requires 'unsafe-inline' for hydration scripts/styles.
		t.Fatalf("expected enforced CSP to include unsafe-inline for Next.js compatibility, got: %s", enforced)
	case strings.Contains(enforced, "cdn.tailwindcss.com"):
		t.Fatalf("expected enforced CSP without tailwind CDN, got: %s", enforced)
	}

	reportOnly := resp.Header.Get("Content-Security-Policy-Report-Only")
	switch {
	case reportOnly == "":
		t.Fatal("expected Content-Security-Policy-Report-Only header")
	case !strings.Contains(reportOnly, "'unsafe-inline'"):
		t.Fatalf("expected report-only CSP to include unsafe-inline, got: %s", reportOnly)
	case !strings.Contains(reportOnly, "object-src 'none'"):
		t.Fatalf("expected report-only CSP to contain object-src none, got: %s", reportOnly)
	case strings.Contains(reportOnly, "cdn.tailwindcss.com"):
		t.Fatalf("expected report-only CSP without tailwind CDN, got: %s", reportOnly)
	case !strings.Contains(reportOnly, "report-uri /api/security/csp-report"):
		t.Fatalf("expected report-only CSP report-uri endpoint, got: %s", reportOnly)
	}

	if sts := resp.Header.Get("Strict-Transport-Security"); sts != "" {
		t.Fatalf("expected no HSTS header on HTTP request, got: %s", sts)
	}
}

// TestSecurityHeadersMiddlewareSetsHSTSForHTTPS checks that
// SecurityHeadersMiddleware emits a Strict-Transport-Security header when the
// request is marked as HTTPS through the X-Forwarded-Proto header.
func TestSecurityHeadersMiddlewareSetsHSTSForHTTPS(t *testing.T) {
	// The request URL itself is plain HTTP; only X-Forwarded-Proto signals HTTPS.
	req := httptest.NewRequest(http.MethodGet, "http://example.test/", nil)
	req.Header.Set("X-Forwarded-Proto", "https")

	app := fiber.New()
	app.Use(SecurityHeadersMiddleware())
	app.Get("/", func(c *fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) })

	resp, err := app.Test(req)
	if err != nil {
		t.Fatalf("request failed: %v", err)
	}

	if resp.Header.Get("Strict-Transport-Security") == "" {
		t.Fatal("expected Strict-Transport-Security header on HTTPS request")
	}
}
