package middleware

import (
	"context"
	"net/http"
	"net/http/httptest"
	"sync"
	"testing"

	"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
	"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
	"github.com/gin-gonic/gin"
)

// testLogSink is an in-memory log sink that records every event it receives
// so tests can inspect what the middleware logged.
type testLogSink struct {
	mu     sync.Mutex
	events []*logger.LogEvent
}

// WriteLogEvent appends event to the recorded list under the sink's mutex.
func (s *testLogSink) WriteLogEvent(event *logger.LogEvent) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.events = append(s.events, event)
}

// list returns a snapshot copy of the events recorded so far, so callers can
// iterate without holding the lock.
func (s *testLogSink) list() []*logger.LogEvent {
	s.mu.Lock()
	defer s.mu.Unlock()
	out := make([]*logger.LogEvent, len(s.events))
	copy(out, s.events)
	return out
}

// initMiddlewareTestLogger initializes the package logger at debug level with
// JSON format and both stdout and file output disabled, then installs a fresh
// testLogSink. The sink is detached again (SetSink(nil)) when the test ends.
//
// NOTE(review): logger.Init / logger.SetSink appear to mutate process-global
// state, so tests using this helper should not call t.Parallel — confirm
// against the logger package.
func initMiddlewareTestLogger(t *testing.T) *testLogSink {
	t.Helper()
	if err := logger.Init(logger.InitOptions{
		Level:       "debug",
		Format:      "json",
		ServiceName: "sub2api",
		Environment: "test",
		Output: logger.OutputOptions{
			ToStdout: false,
			ToFile:   false,
		},
	}); err != nil {
		t.Fatalf("init logger: %v", err)
	}
	sink := &testLogSink{}
	logger.SetSink(sink)
	t.Cleanup(func() {
		logger.SetSink(nil)
	})
	return sink
}

// requireIntLogValue fails the test unless got is an int or int64 equal to
// want. Both integer kinds are accepted because the log event may carry the
// field boxed as either type.
func requireIntLogValue(t *testing.T, key string, got any, want int64) {
	t.Helper()
	switch v := got.(type) {
	case int:
		if int64(v) != want {
			t.Fatalf("%s field mismatch: %v", key, v)
		}
	case int64:
		if v != want {
			t.Fatalf("%s field mismatch: %v", key, v)
		}
	default:
		t.Fatalf("%s type mismatch: %T", key, v)
	}
}

// TestRequestLogger_GenerateAndPropagateRequestID checks that, when the client
// sends no request ID, RequestLogger generates one, stores it in the request
// context under ctxkey.RequestID, and echoes it in the response header.
func TestRequestLogger_GenerateAndPropagateRequestID(t *testing.T) {
	gin.SetMode(gin.TestMode)
	r := gin.New()
	r.Use(RequestLogger())
	r.GET("/t", func(c *gin.Context) {
		reqID, ok := c.Request.Context().Value(ctxkey.RequestID).(string)
		if !ok || reqID == "" {
			t.Fatalf("request_id missing in context")
		}
		if got := c.Writer.Header().Get(requestIDHeader); got != reqID {
			t.Fatalf("response header request_id mismatch, header=%q ctx=%q", got, reqID)
		}
		c.Status(http.StatusOK)
	})

	w := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/t", nil)
	r.ServeHTTP(w, req)

	if w.Code != http.StatusOK {
		t.Fatalf("status=%d", w.Code)
	}
	if w.Header().Get(requestIDHeader) == "" {
		t.Fatalf("X-Request-ID should be set")
	}
}

// TestRequestLogger_KeepIncomingRequestID checks that a client-supplied
// request ID is preserved in both the request context and the response header.
func TestRequestLogger_KeepIncomingRequestID(t *testing.T) {
	gin.SetMode(gin.TestMode)
	r := gin.New()
	r.Use(RequestLogger())
	r.GET("/t", func(c *gin.Context) {
		reqID, _ := c.Request.Context().Value(ctxkey.RequestID).(string)
		if reqID != "rid-fixed" {
			t.Fatalf("request_id=%q, want rid-fixed", reqID)
		}
		c.Status(http.StatusOK)
	})

	w := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/t", nil)
	req.Header.Set(requestIDHeader, "rid-fixed")
	r.ServeHTTP(w, req)

	if w.Code != http.StatusOK {
		t.Fatalf("status=%d", w.Code)
	}
	if got := w.Header().Get(requestIDHeader); got != "rid-fixed" {
		t.Fatalf("header=%q, want rid-fixed", got)
	}
}

// TestLogger_AccessLogIncludesCoreFields checks that the access log entry
// ("http request completed") carries the response status code plus the
// account, platform, and model values placed in the request context by an
// earlier middleware.
func TestLogger_AccessLogIncludesCoreFields(t *testing.T) {
	gin.SetMode(gin.TestMode)
	sink := initMiddlewareTestLogger(t)

	r := gin.New()
	r.Use(Logger())
	// Inject the context values the access log is expected to surface.
	r.Use(func(c *gin.Context) {
		ctx := c.Request.Context()
		ctx = context.WithValue(ctx, ctxkey.AccountID, int64(101))
		ctx = context.WithValue(ctx, ctxkey.Platform, "openai")
		ctx = context.WithValue(ctx, ctxkey.Model, "gpt-5")
		c.Request = c.Request.WithContext(ctx)
		c.Next()
	})
	r.GET("/api/test", func(c *gin.Context) {
		c.Status(http.StatusCreated)
	})

	w := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/api/test", nil)
	r.ServeHTTP(w, req)

	if w.Code != http.StatusCreated {
		t.Fatalf("status=%d", w.Code)
	}

	found := false
	for _, event := range sink.list() {
		if event == nil || event.Message != "http request completed" {
			continue
		}
		found = true
		requireIntLogValue(t, "status_code", event.Fields["status_code"], int64(http.StatusCreated))
		requireIntLogValue(t, "account_id", event.Fields["account_id"], 101)
		if event.Fields["platform"] != "openai" || event.Fields["model"] != "gpt-5" {
			t.Fatalf("platform/model mismatch: %+v", event.Fields)
		}
	}
	if !found {
		t.Fatalf("access log event not found")
	}
}

// TestLogger_HealthPathSkipped checks that requests to /health produce no log
// events at all through the Logger middleware.
func TestLogger_HealthPathSkipped(t *testing.T) {
	gin.SetMode(gin.TestMode)
	sink := initMiddlewareTestLogger(t)

	r := gin.New()
	r.Use(Logger())
	r.GET("/health", func(c *gin.Context) {
		c.Status(http.StatusOK)
	})

	w := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/health", nil)
	r.ServeHTTP(w, req)

	if w.Code != http.StatusOK {
		t.Fatalf("status=%d", w.Code)
	}
	if events := sink.list(); len(events) != 0 {
		t.Fatalf("health endpoint should not write access log, got %d event(s)", len(events))
	}
}