194 lines
4.6 KiB
Go
194 lines
4.6 KiB
Go
package middleware
|
|
|
|
import (
|
|
"context"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"sync"
|
|
"testing"
|
|
|
|
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
|
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
|
"github.com/gin-gonic/gin"
|
|
)
|
|
|
|
type testLogSink struct {
|
|
mu sync.Mutex
|
|
events []*logger.LogEvent
|
|
}
|
|
|
|
func (s *testLogSink) WriteLogEvent(event *logger.LogEvent) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
s.events = append(s.events, event)
|
|
}
|
|
|
|
func (s *testLogSink) list() []*logger.LogEvent {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
out := make([]*logger.LogEvent, len(s.events))
|
|
copy(out, s.events)
|
|
return out
|
|
}
|
|
|
|
func initMiddlewareTestLogger(t *testing.T) *testLogSink {
|
|
t.Helper()
|
|
if err := logger.Init(logger.InitOptions{
|
|
Level: "debug",
|
|
Format: "json",
|
|
ServiceName: "sub2api",
|
|
Environment: "test",
|
|
Output: logger.OutputOptions{
|
|
ToStdout: false,
|
|
ToFile: false,
|
|
},
|
|
}); err != nil {
|
|
t.Fatalf("init logger: %v", err)
|
|
}
|
|
sink := &testLogSink{}
|
|
logger.SetSink(sink)
|
|
t.Cleanup(func() {
|
|
logger.SetSink(nil)
|
|
})
|
|
return sink
|
|
}
|
|
|
|
func TestRequestLogger_GenerateAndPropagateRequestID(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
r := gin.New()
|
|
r.Use(RequestLogger())
|
|
r.GET("/t", func(c *gin.Context) {
|
|
reqID, ok := c.Request.Context().Value(ctxkey.RequestID).(string)
|
|
if !ok || reqID == "" {
|
|
t.Fatalf("request_id missing in context")
|
|
}
|
|
if got := c.Writer.Header().Get(requestIDHeader); got != reqID {
|
|
t.Fatalf("response header request_id mismatch, header=%q ctx=%q", got, reqID)
|
|
}
|
|
c.Status(http.StatusOK)
|
|
})
|
|
|
|
w := httptest.NewRecorder()
|
|
req := httptest.NewRequest(http.MethodGet, "/t", nil)
|
|
r.ServeHTTP(w, req)
|
|
if w.Code != http.StatusOK {
|
|
t.Fatalf("status=%d", w.Code)
|
|
}
|
|
if w.Header().Get(requestIDHeader) == "" {
|
|
t.Fatalf("X-Request-ID should be set")
|
|
}
|
|
}
|
|
|
|
func TestRequestLogger_KeepIncomingRequestID(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
r := gin.New()
|
|
r.Use(RequestLogger())
|
|
r.GET("/t", func(c *gin.Context) {
|
|
reqID, _ := c.Request.Context().Value(ctxkey.RequestID).(string)
|
|
if reqID != "rid-fixed" {
|
|
t.Fatalf("request_id=%q, want rid-fixed", reqID)
|
|
}
|
|
c.Status(http.StatusOK)
|
|
})
|
|
|
|
w := httptest.NewRecorder()
|
|
req := httptest.NewRequest(http.MethodGet, "/t", nil)
|
|
req.Header.Set(requestIDHeader, "rid-fixed")
|
|
r.ServeHTTP(w, req)
|
|
if w.Code != http.StatusOK {
|
|
t.Fatalf("status=%d", w.Code)
|
|
}
|
|
if got := w.Header().Get(requestIDHeader); got != "rid-fixed" {
|
|
t.Fatalf("header=%q, want rid-fixed", got)
|
|
}
|
|
}
|
|
|
|
func TestLogger_AccessLogIncludesCoreFields(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
sink := initMiddlewareTestLogger(t)
|
|
|
|
r := gin.New()
|
|
r.Use(Logger())
|
|
r.Use(func(c *gin.Context) {
|
|
ctx := c.Request.Context()
|
|
ctx = context.WithValue(ctx, ctxkey.AccountID, int64(101))
|
|
ctx = context.WithValue(ctx, ctxkey.Platform, "openai")
|
|
ctx = context.WithValue(ctx, ctxkey.Model, "gpt-5")
|
|
c.Request = c.Request.WithContext(ctx)
|
|
c.Next()
|
|
})
|
|
r.GET("/api/test", func(c *gin.Context) {
|
|
c.Status(http.StatusCreated)
|
|
})
|
|
|
|
w := httptest.NewRecorder()
|
|
req := httptest.NewRequest(http.MethodGet, "/api/test", nil)
|
|
r.ServeHTTP(w, req)
|
|
if w.Code != http.StatusCreated {
|
|
t.Fatalf("status=%d", w.Code)
|
|
}
|
|
|
|
events := sink.list()
|
|
if len(events) == 0 {
|
|
t.Fatalf("expected at least one log event")
|
|
}
|
|
found := false
|
|
for _, event := range events {
|
|
if event == nil || event.Message != "http request completed" {
|
|
continue
|
|
}
|
|
found = true
|
|
switch v := event.Fields["status_code"].(type) {
|
|
case int:
|
|
if v != http.StatusCreated {
|
|
t.Fatalf("status_code field mismatch: %v", v)
|
|
}
|
|
case int64:
|
|
if v != int64(http.StatusCreated) {
|
|
t.Fatalf("status_code field mismatch: %v", v)
|
|
}
|
|
default:
|
|
t.Fatalf("status_code type mismatch: %T", v)
|
|
}
|
|
switch v := event.Fields["account_id"].(type) {
|
|
case int64:
|
|
if v != 101 {
|
|
t.Fatalf("account_id field mismatch: %v", v)
|
|
}
|
|
case int:
|
|
if v != 101 {
|
|
t.Fatalf("account_id field mismatch: %v", v)
|
|
}
|
|
default:
|
|
t.Fatalf("account_id type mismatch: %T", v)
|
|
}
|
|
if event.Fields["platform"] != "openai" || event.Fields["model"] != "gpt-5" {
|
|
t.Fatalf("platform/model mismatch: %+v", event.Fields)
|
|
}
|
|
}
|
|
if !found {
|
|
t.Fatalf("access log event not found")
|
|
}
|
|
}
|
|
|
|
func TestLogger_HealthPathSkipped(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
sink := initMiddlewareTestLogger(t)
|
|
|
|
r := gin.New()
|
|
r.Use(Logger())
|
|
r.GET("/health", func(c *gin.Context) {
|
|
c.Status(http.StatusOK)
|
|
})
|
|
|
|
w := httptest.NewRecorder()
|
|
req := httptest.NewRequest(http.MethodGet, "/health", nil)
|
|
r.ServeHTTP(w, req)
|
|
if w.Code != http.StatusOK {
|
|
t.Fatalf("status=%d", w.Code)
|
|
}
|
|
if len(sink.list()) != 0 {
|
|
t.Fatalf("health endpoint should not write access log")
|
|
}
|
|
}
|