diff --git a/backend/internal/repository/claude_oauth_service_test.go b/backend/internal/repository/claude_oauth_service_test.go index 3295c222..a7f76056 100644 --- a/backend/internal/repository/claude_oauth_service_test.go +++ b/backend/internal/repository/claude_oauth_service_test.go @@ -5,28 +5,20 @@ import ( "encoding/json" "io" "net/http" - "net/http/httptest" "strings" "testing" "github.com/Wei-Shaw/sub2api/internal/pkg/oauth" + "github.com/imroc/req/v3" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" ) type ClaudeOAuthServiceSuite struct { suite.Suite - srv *httptest.Server client *claudeOAuthService } -func (s *ClaudeOAuthServiceSuite) TearDownTest() { - if s.srv != nil { - s.srv.Close() - s.srv = nil - } -} - // requestCapture holds captured request data for assertions in the main goroutine. type requestCapture struct { path string @@ -37,6 +29,12 @@ type requestCapture struct { contentType string } +func newTestReqClient(rt http.RoundTripper) *req.Client { + c := req.C() + c.GetClient().Transport = rt + return c +} + func (s *ClaudeOAuthServiceSuite) TestGetOrganizationUUID() { tests := []struct { name string @@ -83,17 +81,17 @@ func (s *ClaudeOAuthServiceSuite) TestGetOrganizationUUID() { s.Run(tt.name, func() { var captured requestCapture - s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + rt := newInProcessTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { captured.path = r.URL.Path captured.cookies = r.Cookies() tt.handler(w, r) - })) - defer s.srv.Close() + }), nil) client, ok := NewClaudeOAuthClient().(*claudeOAuthService) require.True(s.T(), ok, "type assertion failed") s.client = client - s.client.baseURL = s.srv.URL + s.client.baseURL = "http://in-process" + s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) } got, err := s.client.GetOrganizationUUID(context.Background(), "sess", "") @@ -158,20 +156,20 @@ func (s *ClaudeOAuthServiceSuite) TestGetAuthorizationCode() { s.Run(tt.name, func() { var captured requestCapture - s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + rt := newInProcessTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { captured.path = r.URL.Path captured.method = r.Method captured.cookies = r.Cookies() captured.body, _ = io.ReadAll(r.Body) _ = json.Unmarshal(captured.body, &captured.bodyJSON) tt.handler(w, r) - })) - defer s.srv.Close() + }), nil) client, ok := NewClaudeOAuthClient().(*claudeOAuthService) require.True(s.T(), ok, "type assertion failed") s.client = client - s.client.baseURL = s.srv.URL + s.client.baseURL = "http://in-process" + s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) } code, err := s.client.GetAuthorizationCode(context.Background(), "sess", "org-1", oauth.ScopeProfile, "cc", "st", "") @@ -266,19 +264,19 @@ func (s *ClaudeOAuthServiceSuite) TestExchangeCodeForToken() { s.Run(tt.name, func() { var captured requestCapture - s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + rt := newInProcessTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { captured.method = r.Method captured.contentType = r.Header.Get("Content-Type") captured.body, _ = io.ReadAll(r.Body) _ = json.Unmarshal(captured.body, &captured.bodyJSON) tt.handler(w, r) - })) - defer s.srv.Close() + }), nil) client, ok := NewClaudeOAuthClient().(*claudeOAuthService) require.True(s.T(), ok, "type assertion failed") s.client = client - s.client.tokenURL = s.srv.URL + s.client.tokenURL = "http://in-process/token" + s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) } resp, err := s.client.ExchangeCodeForToken(context.Background(), tt.code, "ver", "", "", tt.isSetupToken) @@ -362,19 +360,19 @@ func (s *ClaudeOAuthServiceSuite) TestRefreshToken() { s.Run(tt.name, func() { var captured requestCapture - s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + rt := newInProcessTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { captured.method = r.Method captured.contentType = r.Header.Get("Content-Type") captured.body, _ = io.ReadAll(r.Body) _ = json.Unmarshal(captured.body, &captured.bodyJSON) tt.handler(w, r) - })) - defer s.srv.Close() + }), nil) client, ok := NewClaudeOAuthClient().(*claudeOAuthService) require.True(s.T(), ok, "type assertion failed") s.client = client - s.client.tokenURL = s.srv.URL + s.client.tokenURL = "http://in-process/token" + s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) } resp, err := s.client.RefreshToken(context.Background(), "rt", "") diff --git a/backend/internal/repository/claude_usage_service_test.go b/backend/internal/repository/claude_usage_service_test.go index 11097b67..c3570076 100644 --- a/backend/internal/repository/claude_usage_service_test.go +++ b/backend/internal/repository/claude_usage_service_test.go @@ -33,7 +33,7 @@ type usageRequestCapture struct { func (s *ClaudeUsageServiceSuite) TestFetchUsage_Success() { var captured usageRequestCapture - s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { captured.authorization = r.Header.Get("Authorization") captured.anthropicBeta = r.Header.Get("anthropic-beta") @@ -59,7 +59,7 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_Success() { } func (s *ClaudeUsageServiceSuite) TestFetchUsage_NonOK() { - s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusUnauthorized) _, _ = io.WriteString(w, "nope") })) @@ -73,7 +73,7 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_NonOK() { } func (s *ClaudeUsageServiceSuite) TestFetchUsage_BadJSON() { - s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") _, _ = io.WriteString(w, "not-json") })) @@ -86,7 +86,7 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_BadJSON() { } func (s *ClaudeUsageServiceSuite) TestFetchUsage_ContextCancel() { - s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Never respond - simulate slow server <-r.Context().Done() })) diff --git a/backend/internal/repository/github_release_service_test.go b/backend/internal/repository/github_release_service_test.go index bf2efd8d..ea849d46 100644 --- a/backend/internal/repository/github_release_service_test.go +++ b/backend/internal/repository/github_release_service_test.go @@ -49,7 +49,7 @@ func (s *GitHubReleaseServiceSuite) TearDownTest() { } func (s *GitHubReleaseServiceSuite) TestDownloadFile_EnforcesMaxSize_ContentLength() { - s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Length", "100") w.WriteHeader(http.StatusOK) _, _ = w.Write(bytes.Repeat([]byte("a"), 100)) @@ -68,7 +68,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_EnforcesMaxSize_ContentLeng } func (s *GitHubReleaseServiceSuite) TestDownloadFile_EnforcesMaxSize_Chunked() { - s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Force chunked encoding (unknown Content-Length) by flushing headers before writing. w.WriteHeader(http.StatusOK) if fl, ok := w.(http.Flusher); ok { @@ -95,7 +95,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_EnforcesMaxSize_Chunked() { } func (s *GitHubReleaseServiceSuite) TestDownloadFile_Success() { - s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) if fl, ok := w.(http.Flusher); ok { fl.Flush() @@ -123,7 +123,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_Success() { } func (s *GitHubReleaseServiceSuite) TestDownloadFile_404() { - s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNotFound) })) @@ -140,7 +140,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_404() { } func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_Success() { - s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) _, _ = w.Write([]byte("sum")) })) @@ -155,7 +155,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_Success() { } func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_Non200() { - s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusInternalServerError) })) @@ -168,7 +168,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_Non200() { } func (s *GitHubReleaseServiceSuite) TestDownloadFile_ContextCancel() { - s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { <-r.Context().Done() })) @@ -195,7 +195,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_InvalidURL() { } func (s *GitHubReleaseServiceSuite) TestDownloadFile_InvalidDestPath() { - s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) _, _ = w.Write([]byte("content")) })) @@ -233,7 +233,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Success() { ] }` - s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { require.Equal(s.T(), "/repos/test/repo/releases/latest", r.URL.Path) require.Equal(s.T(), "application/vnd.github.v3+json", r.Header.Get("Accept")) require.Equal(s.T(), "Sub2API-Updater", r.Header.Get("User-Agent")) @@ -258,7 +258,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Success() { } func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Non200() { - s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNotFound) })) @@ -274,7 +274,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Non200() { } func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_InvalidJSON() { - s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) _, _ = w.Write([]byte("not valid json")) })) @@ -290,7 +290,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_InvalidJSON() { } func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_ContextCancel() { - s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { <-r.Context().Done() })) @@ -308,7 +308,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_ContextCancel() { } func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_ContextCancel() { - s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { <-r.Context().Done() })) diff --git a/backend/internal/repository/http_upstream_test.go b/backend/internal/repository/http_upstream_test.go index 70676b7a..241b490f 100644 --- a/backend/internal/repository/http_upstream_test.go +++ b/backend/internal/repository/http_upstream_test.go @@ -3,7 +3,6 @@ package repository import ( "io" "net/http" - "net/http/httptest" "sync/atomic" "testing" "time" @@ -93,7 +92,7 @@ func (s *HTTPUpstreamSuite) TestAcquireClient_OverLimitReturnsError() { // 验证空代理 URL 时请求直接发送到目标服务器 func (s *HTTPUpstreamSuite) TestDo_WithoutProxy_GoesDirect() { // 创建模拟上游服务器 - upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + upstream := newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { _, _ = io.WriteString(w, "direct") })) s.T().Cleanup(upstream.Close) @@ -115,7 +114,7 @@ func (s *HTTPUpstreamSuite) TestDo_WithHTTPProxy_UsesProxy() { // 用于接收代理请求的通道 seen := make(chan string, 1) // 创建模拟代理服务器 - proxySrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + proxySrv := newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { seen <- r.RequestURI // 记录请求 URI _, _ = io.WriteString(w, "proxied") })) @@ -145,7 +144,7 @@ func (s *HTTPUpstreamSuite) TestDo_WithHTTPProxy_UsesProxy() { // TestDo_EmptyProxy_UsesDirect 测试空代理字符串 // 验证空字符串代理等同于直连 func (s *HTTPUpstreamSuite) TestDo_EmptyProxy_UsesDirect() { - upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + upstream := newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { _, _ = io.WriteString(w, "direct-empty") })) s.T().Cleanup(upstream.Close) diff --git a/backend/internal/repository/inprocess_transport_test.go b/backend/internal/repository/inprocess_transport_test.go new file mode 100644 index 00000000..fbdf2c81 --- /dev/null +++ b/backend/internal/repository/inprocess_transport_test.go @@ -0,0 +1,63 @@ +package repository + +import ( + "bytes" + "io" + "net" + "net/http" + "net/http/httptest" + "sync" + "testing" +) + +type roundTripFunc func(*http.Request) (*http.Response, error) + +func (f roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) { return f(r) } + +// newInProcessTransport adapts an http.HandlerFunc into an http.RoundTripper without opening sockets. +// It captures the request body (if any) and then rewinds it before invoking the handler. +func newInProcessTransport(handler http.HandlerFunc, capture func(r *http.Request, body []byte)) http.RoundTripper { + return roundTripFunc(func(r *http.Request) (*http.Response, error) { + var body []byte + if r.Body != nil { + body, _ = io.ReadAll(r.Body) + _ = r.Body.Close() + r.Body = io.NopCloser(bytes.NewReader(body)) + } + if capture != nil { + capture(r, body) + } + + rec := httptest.NewRecorder() + handler(rec, r) + return rec.Result(), nil + }) +} + +var ( + canListenOnce sync.Once + canListen bool + canListenErr error +) + +func localListenerAvailable() bool { + canListenOnce.Do(func() { + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + canListenErr = err + canListen = false + return + } + _ = ln.Close() + canListen = true + }) + return canListen +} + +func newLocalTestServer(tb testing.TB, handler http.Handler) *httptest.Server { + tb.Helper() + if !localListenerAvailable() { + tb.Skipf("local listeners are not permitted in this environment: %v", canListenErr) + } + return httptest.NewServer(handler) +} diff --git a/backend/internal/repository/openai_oauth_service_test.go b/backend/internal/repository/openai_oauth_service_test.go index 0a5322d7..51142306 100644 --- a/backend/internal/repository/openai_oauth_service_test.go +++ b/backend/internal/repository/openai_oauth_service_test.go @@ -34,7 +34,7 @@ func (s *OpenAIOAuthServiceSuite) TearDownTest() { } func (s *OpenAIOAuthServiceSuite) setupServer(handler http.HandlerFunc) { - s.srv = httptest.NewServer(handler) + s.srv = newLocalTestServer(s.T(), handler) s.svc = &openaiOAuthService{tokenURL: s.srv.URL} } diff --git a/backend/internal/repository/pricing_service_test.go b/backend/internal/repository/pricing_service_test.go index c51317a4..112c7eaa 100644 --- a/backend/internal/repository/pricing_service_test.go +++ b/backend/internal/repository/pricing_service_test.go @@ -32,7 +32,7 @@ func (s *PricingServiceSuite) TearDownTest() { } func (s *PricingServiceSuite) setupServer(handler http.HandlerFunc) { - s.srv = httptest.NewServer(handler) + s.srv = newLocalTestServer(s.T(), handler) } func (s *PricingServiceSuite) TestFetchPricingJSON_Success() { diff --git a/backend/internal/repository/proxy_probe_service_test.go b/backend/internal/repository/proxy_probe_service_test.go index 74d99c6d..e7270324 100644 --- a/backend/internal/repository/proxy_probe_service_test.go +++ b/backend/internal/repository/proxy_probe_service_test.go @@ -31,7 +31,7 @@ func (s *ProxyProbeServiceSuite) TearDownTest() { } func (s *ProxyProbeServiceSuite) setupProxyServer(handler http.HandlerFunc) { - s.proxySrv = httptest.NewServer(handler) + s.proxySrv = newLocalTestServer(s.T(), handler) } func (s *ProxyProbeServiceSuite) TestProbeProxy_InvalidProxyURL() { diff --git a/backend/internal/repository/turnstile_service_test.go b/backend/internal/repository/turnstile_service_test.go index 3876a007..83e0839a 100644 --- a/backend/internal/repository/turnstile_service_test.go +++ b/backend/internal/repository/turnstile_service_test.go @@ -3,9 +3,9 @@ package repository import ( "context" "encoding/json" + "errors" "io" "net/http" - "net/http/httptest" "net/url" "strings" "testing" @@ -18,7 +18,6 @@ import ( type TurnstileServiceSuite struct { suite.Suite ctx context.Context - srv *httptest.Server verifier *turnstileVerifier received chan url.Values } @@ -31,20 +30,15 @@ func (s *TurnstileServiceSuite) SetupTest() { s.verifier = verifier } -func (s *TurnstileServiceSuite) TearDownTest() { - if s.srv != nil { - s.srv.Close() - s.srv = nil +func (s *TurnstileServiceSuite) setupTransport(handler http.HandlerFunc) { + s.verifier.verifyURL = "http://in-process/turnstile" + s.verifier.httpClient = &http.Client{ + Transport: newInProcessTransport(handler, nil), } } -func (s *TurnstileServiceSuite) setupServer(handler http.HandlerFunc) { - s.srv = httptest.NewServer(handler) - s.verifier.verifyURL = s.srv.URL -} - func (s *TurnstileServiceSuite) TestVerifyToken_SendsFormAndDecodesJSON() { - s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + s.setupTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Capture form data in main goroutine context later body, _ := io.ReadAll(r.Body) values, _ := url.ParseQuery(string(body)) @@ -72,7 +66,7 @@ func (s *TurnstileServiceSuite) TestVerifyToken_SendsFormAndDecodesJSON() { func (s *TurnstileServiceSuite) TestVerifyToken_ContentType() { var contentType string - s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + s.setupTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { contentType = r.Header.Get("Content-Type") w.Header().Set("Content-Type", "application/json") _ = json.NewEncoder(w).Encode(service.TurnstileVerifyResponse{Success: true}) @@ -84,7 +78,7 @@ func (s *TurnstileServiceSuite) TestVerifyToken_ContentType() { } func (s *TurnstileServiceSuite) TestVerifyToken_EmptyRemoteIP_NotSent() { - s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + s.setupTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { body, _ := io.ReadAll(r.Body) values, _ := url.ParseQuery(string(body)) s.received <- values @@ -105,15 +99,19 @@ func (s *TurnstileServiceSuite) TestVerifyToken_EmptyRemoteIP_NotSent() { } func (s *TurnstileServiceSuite) TestVerifyToken_RequestError() { - s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) - s.srv.Close() + s.verifier.verifyURL = "http://in-process/turnstile" + s.verifier.httpClient = &http.Client{ + Transport: roundTripFunc(func(*http.Request) (*http.Response, error) { + return nil, errors.New("dial failed") + }), + } _, err := s.verifier.VerifyToken(s.ctx, "sk", "token", "1.1.1.1") require.Error(s.T(), err, "expected error when server is closed") } func (s *TurnstileServiceSuite) TestVerifyToken_InvalidJSON() { - s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + s.setupTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") _, _ = io.WriteString(w, "not-valid-json") })) @@ -123,7 +121,7 @@ func (s *TurnstileServiceSuite) TestVerifyToken_InvalidJSON() { } func (s *TurnstileServiceSuite) TestVerifyToken_SuccessFalse() { - s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + s.setupTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") _ = json.NewEncoder(w).Encode(service.TurnstileVerifyResponse{ Success: false, diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go index 0371ad0d..51abdbb0 100644 --- a/backend/internal/repository/usage_log_repo.go +++ b/backend/internal/repository/usage_log_repo.go @@ -3,6 +3,7 @@ package repository import ( "context" "database/sql" + "errors" "fmt" "os" "strings" @@ -70,6 +71,9 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) createdAt = time.Now() } + requestID := strings.TrimSpace(log.RequestID) + log.RequestID = requestID + rateMultiplier := log.RateMultiplier query := ` @@ -107,6 +111,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25 ) + ON CONFLICT (request_id, api_key_id) DO NOTHING RETURNING id, created_at ` @@ -115,11 +120,16 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) duration := nullInt(log.DurationMs) firstToken := nullInt(log.FirstTokenMs) + var requestIDArg any + if requestID != "" { + requestIDArg = requestID + } + args := []any{ log.UserID, - log.APIKeyID, + log.ApiKeyID, log.AccountID, - log.RequestID, + requestIDArg, log.Model, groupID, subscriptionID, @@ -143,7 +153,14 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) createdAt, } if err := scanSingleRow(ctx, r.sql, query, args, &log.ID, &log.CreatedAt); err != nil { - return err + if errors.Is(err, sql.ErrNoRows) && requestID != "" { + selectQuery := "SELECT id, created_at FROM usage_logs WHERE request_id = $1 AND api_key_id = $2" + if err := scanSingleRow(ctx, r.sql, selectQuery, []any{requestID, log.ApiKeyID}, &log.ID, &log.CreatedAt); err != nil { + return err + } + } else { + return err + } } log.RateMultiplier = rateMultiplier return nil @@ -183,7 +200,7 @@ func (r *usageLogRepository) ListByUser(ctx context.Context, userID int64, param return r.listUsageLogsWithPagination(ctx, "WHERE user_id = $1", []any{userID}, params) } -func (r *usageLogRepository) ListByAPIKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) { +func (r *usageLogRepository) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) { return r.listUsageLogsWithPagination(ctx, "WHERE api_key_id = $1", []any{apiKeyID}, params) } @@ -270,8 +287,8 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS r.sql, apiKeyStatsQuery, []any{service.StatusActive}, - &stats.TotalAPIKeys, - &stats.ActiveAPIKeys, + &stats.TotalApiKeys, + &stats.ActiveApiKeys, ); err != nil { return nil, err } @@ -418,8 +435,8 @@ func (r *usageLogRepository) GetUserStatsAggregated(ctx context.Context, userID return &stats, nil } -// GetAPIKeyStatsAggregated returns aggregated usage statistics for an API key using database-level aggregation -func (r *usageLogRepository) GetAPIKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) { +// GetApiKeyStatsAggregated returns aggregated usage statistics for an API key using database-level aggregation +func (r *usageLogRepository) GetApiKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) { query := ` SELECT COUNT(*) as total_requests, @@ -623,7 +640,7 @@ func resolveUsageStatsTimezone() string { return "UTC" } -func (r *usageLogRepository) ListByAPIKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) { +func (r *usageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) { query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE api_key_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC" logs, err := r.queryUsageLogs(ctx, query, apiKeyID, startTime, endTime) return logs, nil, err @@ -709,11 +726,11 @@ type ModelStat = usagestats.ModelStat // UserUsageTrendPoint represents user usage trend data point type UserUsageTrendPoint = usagestats.UserUsageTrendPoint -// APIKeyUsageTrendPoint represents API key usage trend data point -type APIKeyUsageTrendPoint = usagestats.APIKeyUsageTrendPoint +// ApiKeyUsageTrendPoint represents API key usage trend data point +type ApiKeyUsageTrendPoint = usagestats.ApiKeyUsageTrendPoint -// GetAPIKeyUsageTrend returns usage trend data grouped by API key and date -func (r *usageLogRepository) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) (results []APIKeyUsageTrendPoint, err error) { +// GetApiKeyUsageTrend returns usage trend data grouped by API key and date +func (r *usageLogRepository) GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) (results []ApiKeyUsageTrendPoint, err error) { dateFormat := "YYYY-MM-DD" if granularity == "hour" { dateFormat = "YYYY-MM-DD HH24:00" @@ -755,10 +772,10 @@ func (r *usageLogRepository) GetAPIKeyUsageTrend(ctx context.Context, startTime, } }() - results = make([]APIKeyUsageTrendPoint, 0) + results = make([]ApiKeyUsageTrendPoint, 0) for rows.Next() { - var row APIKeyUsageTrendPoint - if err = rows.Scan(&row.Date, &row.APIKeyID, &row.KeyName, &row.Requests, &row.Tokens); err != nil { + var row ApiKeyUsageTrendPoint + if err = rows.Scan(&row.Date, &row.ApiKeyID, &row.KeyName, &row.Requests, &row.Tokens); err != nil { return nil, err } results = append(results, row) @@ -844,7 +861,7 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i r.sql, "SELECT COUNT(*) FROM api_keys WHERE user_id = $1 AND deleted_at IS NULL", []any{userID}, - &stats.TotalAPIKeys, + &stats.TotalApiKeys, ); err != nil { return nil, err } @@ -853,7 +870,7 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i r.sql, "SELECT COUNT(*) FROM api_keys WHERE user_id = $1 AND status = $2 AND deleted_at IS NULL", []any{userID, service.StatusActive}, - &stats.ActiveAPIKeys, + &stats.ActiveApiKeys, ); err != nil { return nil, err } @@ -1023,9 +1040,9 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat conditions = append(conditions, fmt.Sprintf("user_id = $%d", len(args)+1)) args = append(args, filters.UserID) } - if filters.APIKeyID > 0 { + if filters.ApiKeyID > 0 { conditions = append(conditions, fmt.Sprintf("api_key_id = $%d", len(args)+1)) - args = append(args, filters.APIKeyID) + args = append(args, filters.ApiKeyID) } if filters.AccountID > 0 { conditions = append(conditions, fmt.Sprintf("account_id = $%d", len(args)+1)) @@ -1145,18 +1162,18 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs return result, nil } -// BatchAPIKeyUsageStats represents usage stats for a single API key -type BatchAPIKeyUsageStats = usagestats.BatchAPIKeyUsageStats +// BatchApiKeyUsageStats represents usage stats for a single API key +type BatchApiKeyUsageStats = usagestats.BatchApiKeyUsageStats -// GetBatchAPIKeyUsageStats gets today and total actual_cost for multiple API keys -func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*BatchAPIKeyUsageStats, error) { - result := make(map[int64]*BatchAPIKeyUsageStats) +// GetBatchApiKeyUsageStats gets today and total actual_cost for multiple API keys +func (r *usageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*BatchApiKeyUsageStats, error) { + result := make(map[int64]*BatchApiKeyUsageStats) if len(apiKeyIDs) == 0 { return result, nil } for _, id := range apiKeyIDs { - result[id] = &BatchAPIKeyUsageStats{APIKeyID: id} + result[id] = &BatchApiKeyUsageStats{ApiKeyID: id} } query := ` @@ -1582,7 +1599,7 @@ func (r *usageLogRepository) hydrateUsageLogAssociations(ctx context.Context, lo if err != nil { return err } - apiKeys, err := r.loadAPIKeys(ctx, ids.apiKeyIDs) + apiKeys, err := r.loadApiKeys(ctx, ids.apiKeyIDs) if err != nil { return err } @@ -1603,8 +1620,8 @@ func (r *usageLogRepository) hydrateUsageLogAssociations(ctx context.Context, lo if user, ok := users[logs[i].UserID]; ok { logs[i].User = user } - if key, ok := apiKeys[logs[i].APIKeyID]; ok { - logs[i].APIKey = key + if key, ok := apiKeys[logs[i].ApiKeyID]; ok { + logs[i].ApiKey = key } if acc, ok := accounts[logs[i].AccountID]; ok { logs[i].Account = acc @@ -1642,7 +1659,7 @@ func collectUsageLogIDs(logs []service.UsageLog) usageLogIDs { for i := range logs { userIDs[logs[i].UserID] = struct{}{} - apiKeyIDs[logs[i].APIKeyID] = struct{}{} + apiKeyIDs[logs[i].ApiKeyID] = struct{}{} accountIDs[logs[i].AccountID] = struct{}{} if logs[i].GroupID != nil { groupIDs[*logs[i].GroupID] = struct{}{} @@ -1676,12 +1693,12 @@ func (r *usageLogRepository) loadUsers(ctx context.Context, ids []int64) (map[in return out, nil } -func (r *usageLogRepository) loadAPIKeys(ctx context.Context, ids []int64) (map[int64]*service.APIKey, error) { - out := make(map[int64]*service.APIKey) +func (r *usageLogRepository) loadApiKeys(ctx context.Context, ids []int64) (map[int64]*service.ApiKey, error) { + out := make(map[int64]*service.ApiKey) if len(ids) == 0 { return out, nil } - models, err := r.client.APIKey.Query().Where(dbapikey.IDIn(ids...)).All(ctx) + models, err := r.client.ApiKey.Query().Where(dbapikey.IDIn(ids...)).All(ctx) if err != nil { return nil, err } @@ -1800,7 +1817,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e log := &service.UsageLog{ ID: id, UserID: userID, - APIKeyID: apiKeyID, + ApiKeyID: apiKeyID, AccountID: accountID, Model: model, InputTokens: inputTokens, diff --git a/backend/internal/repository/usage_log_repo_integration_test.go b/backend/internal/repository/usage_log_repo_integration_test.go index 694b23a4..c0b98e10 100644 --- a/backend/internal/repository/usage_log_repo_integration_test.go +++ b/backend/internal/repository/usage_log_repo_integration_test.go @@ -7,6 +7,8 @@ import ( "testing" "time" + "github.com/google/uuid" + dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone" @@ -35,11 +37,12 @@ func TestUsageLogRepoSuite(t *testing.T) { suite.Run(t, new(UsageLogRepoSuite)) } -func (s *UsageLogRepoSuite) createUsageLog(user *service.User, apiKey *service.APIKey, account *service.Account, inputTokens, outputTokens int, cost float64, createdAt time.Time) *service.UsageLog { +func (s *UsageLogRepoSuite) createUsageLog(user *service.User, apiKey *service.ApiKey, account *service.Account, inputTokens, outputTokens int, cost float64, createdAt time.Time) *service.UsageLog { log := &service.UsageLog{ UserID: user.ID, - APIKeyID: apiKey.ID, + ApiKeyID: apiKey.ID, AccountID: account.ID, + RequestID: uuid.New().String(), // Generate unique RequestID for each log Model: "claude-3", InputTokens: inputTokens, OutputTokens: outputTokens, @@ -55,12 +58,12 @@ func (s *UsageLogRepoSuite) createUsageLog(user *service.User, apiKey *service.A func (s *UsageLogRepoSuite) TestCreate() { user := mustCreateUser(s.T(), s.client, &service.User{Email: "create@test.com"}) - apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-create", Name: "k"}) + apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-create", Name: "k"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-create"}) log := &service.UsageLog{ UserID: user.ID, - APIKeyID: apiKey.ID, + ApiKeyID: apiKey.ID, AccountID: account.ID, Model: "claude-3", InputTokens: 10, @@ -76,7 +79,7 @@ func (s *UsageLogRepoSuite) TestCreate() { func (s *UsageLogRepoSuite) TestGetByID() { user := mustCreateUser(s.T(), s.client, &service.User{Email: "getbyid@test.com"}) - apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-getbyid", Name: "k"}) + apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-getbyid", Name: "k"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-getbyid"}) log := s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now()) @@ -96,7 +99,7 @@ func (s *UsageLogRepoSuite) TestGetByID_NotFound() { func (s *UsageLogRepoSuite) TestDelete() { user := mustCreateUser(s.T(), s.client, &service.User{Email: "delete@test.com"}) - apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-delete", Name: "k"}) + apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-delete", Name: "k"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-delete"}) log := s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now()) @@ -112,7 +115,7 @@ func (s *UsageLogRepoSuite) TestDelete() { func (s *UsageLogRepoSuite) TestListByUser() { user := mustCreateUser(s.T(), s.client, &service.User{Email: "listbyuser@test.com"}) - apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-listbyuser", Name: "k"}) + apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-listbyuser", Name: "k"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-listbyuser"}) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now()) @@ -124,18 +127,18 @@ func (s *UsageLogRepoSuite) TestListByUser() { s.Require().Equal(int64(2), page.Total) } -// --- ListByAPIKey --- +// --- ListByApiKey --- -func (s *UsageLogRepoSuite) TestListByAPIKey() { +func (s *UsageLogRepoSuite) TestListByApiKey() { user := mustCreateUser(s.T(), s.client, &service.User{Email: "listbyapikey@test.com"}) - apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-listbyapikey", Name: "k"}) + apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-listbyapikey", Name: "k"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-listbyapikey"}) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now()) s.createUsageLog(user, apiKey, account, 15, 25, 0.6, time.Now()) - logs, page, err := s.repo.ListByAPIKey(s.ctx, apiKey.ID, pagination.PaginationParams{Page: 1, PageSize: 10}) - s.Require().NoError(err, "ListByAPIKey") + logs, page, err := s.repo.ListByApiKey(s.ctx, apiKey.ID, pagination.PaginationParams{Page: 1, PageSize: 10}) + s.Require().NoError(err, "ListByApiKey") s.Require().Len(logs, 2) s.Require().Equal(int64(2), page.Total) } @@ -144,7 +147,7 @@ func (s *UsageLogRepoSuite) TestListByAPIKey() { func (s *UsageLogRepoSuite) TestListByAccount() { user := mustCreateUser(s.T(), s.client, &service.User{Email: "listbyaccount@test.com"}) - apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-listbyaccount", Name: "k"}) + apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-listbyaccount", Name: "k"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-listbyaccount"}) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now()) @@ -159,7 +162,7 @@ func (s *UsageLogRepoSuite) TestListByAccount() { func (s *UsageLogRepoSuite) TestGetUserStats() { user := mustCreateUser(s.T(), s.client, &service.User{Email: "userstats@test.com"}) - apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-userstats", Name: "k"}) + apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-userstats", Name: "k"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-userstats"}) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) @@ -179,7 +182,7 @@ func (s *UsageLogRepoSuite) TestGetUserStats() { func (s *UsageLogRepoSuite) TestListWithFilters() { user := mustCreateUser(s.T(), s.client, &service.User{Email: "filters@test.com"}) - apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-filters", Name: "k"}) + apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-filters", Name: "k"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-filters"}) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now()) @@ -211,8 +214,8 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() { }) group := mustCreateGroup(s.T(), s.client, &service.Group{Name: "g-ul"}) - apiKey1 := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: userToday.ID, Key: "sk-ul-1", Name: "ul1"}) - mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: userOld.ID, Key: "sk-ul-2", Name: "ul2", Status: service.StatusDisabled}) + apiKey1 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: userToday.ID, Key: "sk-ul-1", Name: "ul1"}) + mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: userOld.ID, Key: "sk-ul-2", Name: "ul2", Status: service.StatusDisabled}) resetAt := now.Add(10 * time.Minute) accNormal := mustCreateAccount(s.T(), s.client, &service.Account{Name: "a-normal", Schedulable: true}) @@ -223,7 +226,7 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() { d1, d2, d3 := 100, 200, 300 logToday := &service.UsageLog{ UserID: userToday.ID, - APIKeyID: apiKey1.ID, + ApiKeyID: apiKey1.ID, AccountID: accNormal.ID, Model: "claude-3", GroupID: &group.ID, @@ -240,7 +243,7 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() { logOld := &service.UsageLog{ UserID: userOld.ID, - APIKeyID: apiKey1.ID, + ApiKeyID: apiKey1.ID, AccountID: accNormal.ID, Model: "claude-3", InputTokens: 5, @@ -254,7 +257,7 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() { logPerf := &service.UsageLog{ UserID: userToday.ID, - APIKeyID: apiKey1.ID, + ApiKeyID: apiKey1.ID, AccountID: accNormal.ID, Model: "claude-3", InputTokens: 1, @@ -272,8 +275,8 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() { s.Require().Equal(baseStats.TotalUsers+2, stats.TotalUsers, "TotalUsers mismatch") s.Require().Equal(baseStats.TodayNewUsers+1, stats.TodayNewUsers, "TodayNewUsers mismatch") s.Require().Equal(baseStats.ActiveUsers+1, stats.ActiveUsers, "ActiveUsers mismatch") - s.Require().Equal(baseStats.TotalAPIKeys+2, stats.TotalAPIKeys, "TotalAPIKeys mismatch") - s.Require().Equal(baseStats.ActiveAPIKeys+1, stats.ActiveAPIKeys, "ActiveAPIKeys mismatch") + s.Require().Equal(baseStats.TotalApiKeys+2, stats.TotalApiKeys, "TotalApiKeys mismatch") + s.Require().Equal(baseStats.ActiveApiKeys+1, stats.ActiveApiKeys, "ActiveApiKeys mismatch") s.Require().Equal(baseStats.TotalAccounts+4, stats.TotalAccounts, "TotalAccounts mismatch") s.Require().Equal(baseStats.ErrorAccounts+1, stats.ErrorAccounts, "ErrorAccounts mismatch") s.Require().Equal(baseStats.RateLimitAccounts+1, stats.RateLimitAccounts, "RateLimitAccounts mismatch") @@ -300,14 +303,14 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() { func (s *UsageLogRepoSuite) TestGetUserDashboardStats() { user := mustCreateUser(s.T(), s.client, &service.User{Email: "userdash@test.com"}) - apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-userdash", Name: "k"}) + apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-userdash", Name: "k"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-userdash"}) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now()) stats, err := s.repo.GetUserDashboardStats(s.ctx, user.ID) s.Require().NoError(err, "GetUserDashboardStats") - s.Require().Equal(int64(1), stats.TotalAPIKeys) + s.Require().Equal(int64(1), stats.TotalApiKeys) s.Require().Equal(int64(1), stats.TotalRequests) } @@ -315,7 +318,7 @@ func (s *UsageLogRepoSuite) TestGetUserDashboardStats() { func (s *UsageLogRepoSuite) TestGetAccountTodayStats() { user := mustCreateUser(s.T(), s.client, &service.User{Email: "acctoday@test.com"}) - apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-acctoday", Name: "k"}) + apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-acctoday", Name: "k"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-today"}) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now()) @@ -331,8 +334,8 @@ func (s *UsageLogRepoSuite) TestGetAccountTodayStats() { func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats() { user1 := mustCreateUser(s.T(), s.client, &service.User{Email: "batch1@test.com"}) user2 := mustCreateUser(s.T(), s.client, &service.User{Email: "batch2@test.com"}) - apiKey1 := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user1.ID, Key: "sk-batch1", Name: "k"}) - apiKey2 := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user2.ID, Key: "sk-batch2", Name: "k"}) + apiKey1 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user1.ID, Key: "sk-batch1", Name: "k"}) + apiKey2 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user2.ID, Key: "sk-batch2", Name: "k"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-batch"}) s.createUsageLog(user1, apiKey1, account, 10, 20, 0.5, time.Now()) @@ -351,24 +354,24 @@ func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats_Empty() { s.Require().Empty(stats) } -// --- GetBatchAPIKeyUsageStats --- +// --- GetBatchApiKeyUsageStats --- -func (s *UsageLogRepoSuite) TestGetBatchAPIKeyUsageStats() { +func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats() { user := mustCreateUser(s.T(), s.client, &service.User{Email: "batchkey@test.com"}) - apiKey1 := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-batchkey1", Name: "k1"}) - apiKey2 := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-batchkey2", Name: "k2"}) + apiKey1 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-batchkey1", Name: "k1"}) + apiKey2 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-batchkey2", Name: "k2"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-batchkey"}) s.createUsageLog(user, apiKey1, account, 10, 20, 0.5, time.Now()) s.createUsageLog(user, apiKey2, account, 15, 25, 0.6, time.Now()) - stats, err := s.repo.GetBatchAPIKeyUsageStats(s.ctx, []int64{apiKey1.ID, apiKey2.ID}) - s.Require().NoError(err, "GetBatchAPIKeyUsageStats") + stats, err := s.repo.GetBatchApiKeyUsageStats(s.ctx, []int64{apiKey1.ID, apiKey2.ID}) + s.Require().NoError(err, "GetBatchApiKeyUsageStats") s.Require().Len(stats, 2) } -func (s *UsageLogRepoSuite) TestGetBatchAPIKeyUsageStats_Empty() { - stats, err := s.repo.GetBatchAPIKeyUsageStats(s.ctx, []int64{}) +func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats_Empty() { + stats, err := s.repo.GetBatchApiKeyUsageStats(s.ctx, []int64{}) s.Require().NoError(err) s.Require().Empty(stats) } @@ -377,7 +380,7 @@ func (s *UsageLogRepoSuite) TestGetBatchAPIKeyUsageStats_Empty() { func (s *UsageLogRepoSuite) TestGetGlobalStats() { user := mustCreateUser(s.T(), s.client, &service.User{Email: "global@test.com"}) - apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-global", Name: "k"}) + apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-global", Name: "k"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-global"}) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) @@ -402,7 +405,7 @@ func maxTime(a, b time.Time) time.Time { func (s *UsageLogRepoSuite) TestListByUserAndTimeRange() { user := mustCreateUser(s.T(), s.client, &service.User{Email: "timerange@test.com"}) - apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-timerange", Name: "k"}) + apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-timerange", Name: "k"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-timerange"}) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) @@ -417,11 +420,11 @@ func (s *UsageLogRepoSuite) TestListByUserAndTimeRange() { s.Require().Len(logs, 2) } -// --- ListByAPIKeyAndTimeRange --- +// --- ListByApiKeyAndTimeRange --- -func (s *UsageLogRepoSuite) TestListByAPIKeyAndTimeRange() { +func (s *UsageLogRepoSuite) TestListByApiKeyAndTimeRange() { user := mustCreateUser(s.T(), s.client, &service.User{Email: "keytimerange@test.com"}) - apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-keytimerange", Name: "k"}) + apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-keytimerange", Name: "k"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-keytimerange"}) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) @@ -431,8 +434,8 @@ func (s *UsageLogRepoSuite) TestListByAPIKeyAndTimeRange() { startTime := base.Add(-1 * time.Hour) endTime := base.Add(2 * time.Hour) - logs, _, err := s.repo.ListByAPIKeyAndTimeRange(s.ctx, apiKey.ID, startTime, endTime) - s.Require().NoError(err, "ListByAPIKeyAndTimeRange") + logs, _, err := s.repo.ListByApiKeyAndTimeRange(s.ctx, apiKey.ID, startTime, endTime) + s.Require().NoError(err, "ListByApiKeyAndTimeRange") s.Require().Len(logs, 2) } @@ -440,7 +443,7 @@ func (s *UsageLogRepoSuite) TestListByAPIKeyAndTimeRange() { func (s *UsageLogRepoSuite) TestListByAccountAndTimeRange() { user := mustCreateUser(s.T(), s.client, &service.User{Email: "acctimerange@test.com"}) - apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-acctimerange", Name: "k"}) + apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-acctimerange", Name: "k"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-acctimerange"}) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) @@ -459,7 +462,7 @@ func (s *UsageLogRepoSuite) TestListByAccountAndTimeRange() { func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() { user := mustCreateUser(s.T(), s.client, &service.User{Email: "modeltimerange@test.com"}) - apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-modeltimerange", Name: "k"}) + apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-modeltimerange", Name: "k"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-modeltimerange"}) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) @@ -467,7 +470,7 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() { // Create logs with different models log1 := &service.UsageLog{ UserID: user.ID, - APIKeyID: apiKey.ID, + ApiKeyID: apiKey.ID, AccountID: account.ID, Model: "claude-3-opus", InputTokens: 10, @@ -480,7 +483,7 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() { log2 := &service.UsageLog{ UserID: user.ID, - APIKeyID: apiKey.ID, + ApiKeyID: apiKey.ID, AccountID: account.ID, Model: "claude-3-opus", InputTokens: 15, @@ -493,7 +496,7 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() { log3 := &service.UsageLog{ UserID: user.ID, - APIKeyID: apiKey.ID, + ApiKeyID: apiKey.ID, AccountID: account.ID, Model: "claude-3-sonnet", InputTokens: 20, @@ -515,7 +518,7 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() { func (s *UsageLogRepoSuite) TestGetAccountWindowStats() { user := mustCreateUser(s.T(), s.client, &service.User{Email: "windowstats@test.com"}) - apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-windowstats", Name: "k"}) + apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-windowstats", Name: "k"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-windowstats"}) now := time.Now() @@ -535,7 +538,7 @@ func (s *UsageLogRepoSuite) TestGetAccountWindowStats() { func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID() { user := mustCreateUser(s.T(), s.client, &service.User{Email: "usertrend@test.com"}) - apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-usertrend", Name: "k"}) + apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-usertrend", Name: "k"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-usertrend"}) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) @@ -552,7 +555,7 @@ func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID() { func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID_HourlyGranularity() { user := mustCreateUser(s.T(), s.client, &service.User{Email: "usertrendhourly@test.com"}) - apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-usertrendhourly", Name: "k"}) + apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-usertrendhourly", Name: "k"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-usertrendhourly"}) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) @@ -571,7 +574,7 @@ func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID_HourlyGranularity() { func (s *UsageLogRepoSuite) TestGetUserModelStats() { user := mustCreateUser(s.T(), s.client, &service.User{Email: "modelstats@test.com"}) - apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-modelstats", Name: "k"}) + apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-modelstats", Name: "k"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-modelstats"}) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) @@ -579,7 +582,7 @@ func (s *UsageLogRepoSuite) TestGetUserModelStats() { // Create logs with different models log1 := &service.UsageLog{ UserID: user.ID, - APIKeyID: apiKey.ID, + ApiKeyID: apiKey.ID, AccountID: account.ID, Model: "claude-3-opus", InputTokens: 100, @@ -592,7 +595,7 @@ func (s *UsageLogRepoSuite) TestGetUserModelStats() { log2 := &service.UsageLog{ UserID: user.ID, - APIKeyID: apiKey.ID, + ApiKeyID: apiKey.ID, AccountID: account.ID, Model: "claude-3-sonnet", InputTokens: 50, @@ -618,7 +621,7 @@ func (s *UsageLogRepoSuite) TestGetUserModelStats() { func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() { user := mustCreateUser(s.T(), s.client, &service.User{Email: "trendfilters@test.com"}) - apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-trendfilters", Name: "k"}) + apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-trendfilters", Name: "k"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-trendfilters"}) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) @@ -646,7 +649,7 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() { func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() { user := mustCreateUser(s.T(), s.client, &service.User{Email: "trendfilters-h@test.com"}) - apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-trendfilters-h", Name: "k"}) + apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-trendfilters-h", Name: "k"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-trendfilters-h"}) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) @@ -665,14 +668,14 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() { func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() { user := mustCreateUser(s.T(), s.client, &service.User{Email: "modelfilters@test.com"}) - apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-modelfilters", Name: "k"}) + apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-modelfilters", Name: "k"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-modelfilters"}) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) log1 := &service.UsageLog{ UserID: user.ID, - APIKeyID: apiKey.ID, + ApiKeyID: apiKey.ID, AccountID: account.ID, Model: "claude-3-opus", InputTokens: 100, @@ -685,7 +688,7 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() { log2 := &service.UsageLog{ UserID: user.ID, - APIKeyID: apiKey.ID, + ApiKeyID: apiKey.ID, AccountID: account.ID, Model: "claude-3-sonnet", InputTokens: 50, @@ -719,7 +722,7 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() { func (s *UsageLogRepoSuite) TestGetAccountUsageStats() { user := mustCreateUser(s.T(), s.client, &service.User{Email: "accstats@test.com"}) - apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-accstats", Name: "k"}) + apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-accstats", Name: "k"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-accstats"}) base := time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC) @@ -727,7 +730,7 @@ func (s *UsageLogRepoSuite) TestGetAccountUsageStats() { // Create logs on different days log1 := &service.UsageLog{ UserID: user.ID, - APIKeyID: apiKey.ID, + ApiKeyID: apiKey.ID, AccountID: account.ID, Model: "claude-3-opus", InputTokens: 100, @@ -740,7 +743,7 @@ func (s *UsageLogRepoSuite) TestGetAccountUsageStats() { log2 := &service.UsageLog{ UserID: user.ID, - APIKeyID: apiKey.ID, + ApiKeyID: apiKey.ID, AccountID: account.ID, Model: "claude-3-sonnet", InputTokens: 50, @@ -782,8 +785,8 @@ func (s *UsageLogRepoSuite) TestGetAccountUsageStats_EmptyRange() { func (s *UsageLogRepoSuite) TestGetUserUsageTrend() { user1 := mustCreateUser(s.T(), s.client, &service.User{Email: "usertrend1@test.com"}) user2 := mustCreateUser(s.T(), s.client, &service.User{Email: "usertrend2@test.com"}) - apiKey1 := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user1.ID, Key: "sk-usertrend1", Name: "k1"}) - apiKey2 := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user2.ID, Key: "sk-usertrend2", Name: "k2"}) + apiKey1 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user1.ID, Key: "sk-usertrend1", Name: "k1"}) + apiKey2 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user2.ID, Key: "sk-usertrend2", Name: "k2"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-usertrends"}) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) @@ -799,12 +802,12 @@ func (s *UsageLogRepoSuite) TestGetUserUsageTrend() { s.Require().GreaterOrEqual(len(trend), 2) } -// --- GetAPIKeyUsageTrend --- +// --- GetApiKeyUsageTrend --- -func (s *UsageLogRepoSuite) TestGetAPIKeyUsageTrend() { +func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend() { user := mustCreateUser(s.T(), s.client, &service.User{Email: "keytrend@test.com"}) - apiKey1 := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-keytrend1", Name: "k1"}) - apiKey2 := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-keytrend2", Name: "k2"}) + apiKey1 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-keytrend1", Name: "k1"}) + apiKey2 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-keytrend2", Name: "k2"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-keytrends"}) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) @@ -815,14 +818,14 @@ func (s *UsageLogRepoSuite) TestGetAPIKeyUsageTrend() { startTime := base.Add(-1 * time.Hour) endTime := base.Add(48 * time.Hour) - trend, err := s.repo.GetAPIKeyUsageTrend(s.ctx, startTime, endTime, "day", 10) - s.Require().NoError(err, "GetAPIKeyUsageTrend") + trend, err := s.repo.GetApiKeyUsageTrend(s.ctx, startTime, endTime, "day", 10) + s.Require().NoError(err, "GetApiKeyUsageTrend") s.Require().GreaterOrEqual(len(trend), 2) } -func (s *UsageLogRepoSuite) TestGetAPIKeyUsageTrend_HourlyGranularity() { +func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend_HourlyGranularity() { user := mustCreateUser(s.T(), s.client, &service.User{Email: "keytrendh@test.com"}) - apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-keytrendh", Name: "k"}) + apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-keytrendh", Name: "k"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-keytrendh"}) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) @@ -832,21 +835,21 @@ func (s *UsageLogRepoSuite) TestGetAPIKeyUsageTrend_HourlyGranularity() { startTime := base.Add(-1 * time.Hour) endTime := base.Add(3 * time.Hour) - trend, err := s.repo.GetAPIKeyUsageTrend(s.ctx, startTime, endTime, "hour", 10) - s.Require().NoError(err, "GetAPIKeyUsageTrend hourly") + trend, err := s.repo.GetApiKeyUsageTrend(s.ctx, startTime, endTime, "hour", 10) + s.Require().NoError(err, "GetApiKeyUsageTrend hourly") s.Require().Len(trend, 2) } // --- ListWithFilters (additional filter tests) --- -func (s *UsageLogRepoSuite) TestListWithFilters_APIKeyFilter() { +func (s *UsageLogRepoSuite) TestListWithFilters_ApiKeyFilter() { user := mustCreateUser(s.T(), s.client, &service.User{Email: "filterskey@test.com"}) - apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-filterskey", Name: "k"}) + apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-filterskey", Name: "k"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-filterskey"}) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now()) - filters := usagestats.UsageLogFilters{APIKeyID: apiKey.ID} + filters := usagestats.UsageLogFilters{ApiKeyID: apiKey.ID} logs, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, filters) s.Require().NoError(err, "ListWithFilters apiKey") s.Require().Len(logs, 1) @@ -855,7 +858,7 @@ func (s *UsageLogRepoSuite) TestListWithFilters_APIKeyFilter() { func (s *UsageLogRepoSuite) TestListWithFilters_TimeRange() { user := mustCreateUser(s.T(), s.client, &service.User{Email: "filterstime@test.com"}) - apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-filterstime", Name: "k"}) + apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-filterstime", Name: "k"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-filterstime"}) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) @@ -874,7 +877,7 @@ func (s *UsageLogRepoSuite) TestListWithFilters_TimeRange() { func (s *UsageLogRepoSuite) TestListWithFilters_CombinedFilters() { user := mustCreateUser(s.T(), s.client, &service.User{Email: "filterscombined@test.com"}) - apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-filterscombined", Name: "k"}) + apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-filterscombined", Name: "k"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-filterscombined"}) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) @@ -885,7 +888,7 @@ func (s *UsageLogRepoSuite) TestListWithFilters_CombinedFilters() { endTime := base.Add(2 * time.Hour) filters := usagestats.UsageLogFilters{ UserID: user.ID, - APIKeyID: apiKey.ID, + ApiKeyID: apiKey.ID, StartTime: &startTime, EndTime: &endTime, } diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index f944458e..2fe00ad5 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -91,7 +91,7 @@ func TestAPIContracts(t *testing.T) { name: "GET /api/v1/keys (paginated)", setup: func(t *testing.T, deps *contractDeps) { t.Helper() - deps.apiKeyRepo.MustSeed(&service.APIKey{ + deps.apiKeyRepo.MustSeed(&service.ApiKey{ ID: 100, UserID: 1, Key: "sk_custom_1234567890", @@ -135,7 +135,7 @@ func TestAPIContracts(t *testing.T) { { ID: 1, UserID: 1, - APIKeyID: 100, + ApiKeyID: 100, AccountID: 200, Model: "claude-3", InputTokens: 10, @@ -150,7 +150,7 @@ func TestAPIContracts(t *testing.T) { { ID: 2, UserID: 1, - APIKeyID: 100, + ApiKeyID: 100, AccountID: 200, Model: "claude-3", InputTokens: 5, @@ -188,7 +188,7 @@ func TestAPIContracts(t *testing.T) { { ID: 1, UserID: 1, - APIKeyID: 100, + ApiKeyID: 100, AccountID: 200, RequestID: "req_123", Model: "claude-3", @@ -259,13 +259,13 @@ func TestAPIContracts(t *testing.T) { service.SettingKeyRegistrationEnabled: "true", service.SettingKeyEmailVerifyEnabled: "false", - service.SettingKeySMTPHost: "smtp.example.com", - service.SettingKeySMTPPort: "587", - service.SettingKeySMTPUsername: "user", - service.SettingKeySMTPPassword: "secret", - service.SettingKeySMTPFrom: "no-reply@example.com", - service.SettingKeySMTPFromName: "Sub2API", - service.SettingKeySMTPUseTLS: "true", + service.SettingKeySmtpHost: "smtp.example.com", + service.SettingKeySmtpPort: "587", + service.SettingKeySmtpUsername: "user", + service.SettingKeySmtpPassword: "secret", + service.SettingKeySmtpFrom: "no-reply@example.com", + service.SettingKeySmtpFromName: "Sub2API", + service.SettingKeySmtpUseTLS: "true", service.SettingKeyTurnstileEnabled: "true", service.SettingKeyTurnstileSiteKey: "site-key", @@ -274,9 +274,9 @@ func TestAPIContracts(t *testing.T) { service.SettingKeySiteName: "Sub2API", service.SettingKeySiteLogo: "", service.SettingKeySiteSubtitle: "Subtitle", - service.SettingKeyAPIBaseURL: "https://api.example.com", + service.SettingKeyApiBaseUrl: "https://api.example.com", service.SettingKeyContactInfo: "support", - service.SettingKeyDocURL: "https://docs.example.com", + service.SettingKeyDocUrl: "https://docs.example.com", service.SettingKeyDefaultConcurrency: "5", service.SettingKeyDefaultBalance: "1.25", @@ -308,7 +308,12 @@ func TestAPIContracts(t *testing.T) { "contact_info": "support", "doc_url": "https://docs.example.com", "default_concurrency": 5, - "default_balance": 1.25 + "default_balance": 1.25, + "enable_model_fallback": false, + "fallback_model_anthropic": "", + "fallback_model_antigravity": "", + "fallback_model_gemini": "", + "fallback_model_openai": "" } }`, }, @@ -331,7 +336,7 @@ func TestAPIContracts(t *testing.T) { type contractDeps struct { now time.Time router http.Handler - apiKeyRepo *stubAPIKeyRepo + apiKeyRepo *stubApiKeyRepo usageRepo *stubUsageLogRepo settingRepo *stubSettingRepo } @@ -359,20 +364,20 @@ func newContractDeps(t *testing.T) *contractDeps { }, } - apiKeyRepo := newStubAPIKeyRepo(now) - apiKeyCache := stubAPIKeyCache{} + apiKeyRepo := newStubApiKeyRepo(now) + apiKeyCache := stubApiKeyCache{} groupRepo := stubGroupRepo{} userSubRepo := stubUserSubscriptionRepo{} cfg := &config.Config{ Default: config.DefaultConfig{ - APIKeyPrefix: "sk-", + ApiKeyPrefix: "sk-", }, RunMode: config.RunModeStandard, } userService := service.NewUserService(userRepo) - apiKeyService := service.NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, apiKeyCache, cfg) + apiKeyService := service.NewApiKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, apiKeyCache, cfg) usageRepo := newStubUsageLogRepo() usageService := service.NewUsageService(usageRepo, userRepo) @@ -525,25 +530,25 @@ func (r *stubUserRepo) RemoveGroupFromAllowedGroups(ctx context.Context, groupID return 0, errors.New("not implemented") } -type stubAPIKeyCache struct{} +type stubApiKeyCache struct{} -func (stubAPIKeyCache) GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) { +func (stubApiKeyCache) GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) { return 0, nil } -func (stubAPIKeyCache) IncrementCreateAttemptCount(ctx context.Context, userID int64) error { +func (stubApiKeyCache) IncrementCreateAttemptCount(ctx context.Context, userID int64) error { return nil } -func (stubAPIKeyCache) DeleteCreateAttemptCount(ctx context.Context, userID int64) error { +func (stubApiKeyCache) DeleteCreateAttemptCount(ctx context.Context, userID int64) error { return nil } -func (stubAPIKeyCache) IncrementDailyUsage(ctx context.Context, apiKey string) error { +func (stubApiKeyCache) IncrementDailyUsage(ctx context.Context, apiKey string) error { return nil } -func (stubAPIKeyCache) SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error { +func (stubApiKeyCache) SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error { return nil } @@ -660,24 +665,24 @@ func (stubUserSubscriptionRepo) BatchUpdateExpiredStatus(ctx context.Context) (i return 0, errors.New("not implemented") } -type stubAPIKeyRepo struct { +type stubApiKeyRepo struct { now time.Time nextID int64 - byID map[int64]*service.APIKey - byKey map[string]*service.APIKey + byID map[int64]*service.ApiKey + byKey map[string]*service.ApiKey } -func newStubAPIKeyRepo(now time.Time) *stubAPIKeyRepo { - return &stubAPIKeyRepo{ +func newStubApiKeyRepo(now time.Time) *stubApiKeyRepo { + return &stubApiKeyRepo{ now: now, nextID: 100, - byID: make(map[int64]*service.APIKey), - byKey: make(map[string]*service.APIKey), + byID: make(map[int64]*service.ApiKey), + byKey: make(map[string]*service.ApiKey), } } -func (r *stubAPIKeyRepo) MustSeed(key *service.APIKey) { +func (r *stubApiKeyRepo) MustSeed(key *service.ApiKey) { if key == nil { return } @@ -686,7 +691,7 @@ func (r *stubAPIKeyRepo) MustSeed(key *service.APIKey) { r.byKey[clone.Key] = &clone } -func (r *stubAPIKeyRepo) Create(ctx context.Context, key *service.APIKey) error { +func (r *stubApiKeyRepo) Create(ctx context.Context, key *service.ApiKey) error { if key == nil { return errors.New("nil key") } @@ -706,38 +711,38 @@ func (r *stubAPIKeyRepo) Create(ctx context.Context, key *service.APIKey) error return nil } -func (r *stubAPIKeyRepo) GetByID(ctx context.Context, id int64) (*service.APIKey, error) { +func (r *stubApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) { key, ok := r.byID[id] if !ok { - return nil, service.ErrAPIKeyNotFound + return nil, service.ErrApiKeyNotFound } clone := *key return &clone, nil } -func (r *stubAPIKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) { +func (r *stubApiKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) { key, ok := r.byID[id] if !ok { - return 0, service.ErrAPIKeyNotFound + return 0, service.ErrApiKeyNotFound } return key.UserID, nil } -func (r *stubAPIKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIKey, error) { +func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) { found, ok := r.byKey[key] if !ok { - return nil, service.ErrAPIKeyNotFound + return nil, service.ErrApiKeyNotFound } clone := *found return &clone, nil } -func (r *stubAPIKeyRepo) Update(ctx context.Context, key *service.APIKey) error { +func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.ApiKey) error { if key == nil { return errors.New("nil key") } if _, ok := r.byID[key.ID]; !ok { - return service.ErrAPIKeyNotFound + return service.ErrApiKeyNotFound } if key.UpdatedAt.IsZero() { key.UpdatedAt = r.now @@ -748,17 +753,17 @@ func (r *stubAPIKeyRepo) Update(ctx context.Context, key *service.APIKey) error return nil } -func (r *stubAPIKeyRepo) Delete(ctx context.Context, id int64) error { +func (r *stubApiKeyRepo) Delete(ctx context.Context, id int64) error { key, ok := r.byID[id] if !ok { - return service.ErrAPIKeyNotFound + return service.ErrApiKeyNotFound } delete(r.byID, id) delete(r.byKey, key.Key) return nil } -func (r *stubAPIKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) { +func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) { ids := make([]int64, 0, len(r.byID)) for id := range r.byID { if r.byID[id].UserID == userID { @@ -776,7 +781,7 @@ func (r *stubAPIKeyRepo) ListByUserID(ctx context.Context, userID int64, params end = len(ids) } - out := make([]service.APIKey, 0, end-start) + out := make([]service.ApiKey, 0, end-start) for _, id := range ids[start:end] { clone := *r.byID[id] out = append(out, clone) @@ -796,7 +801,7 @@ func (r *stubAPIKeyRepo) ListByUserID(ctx context.Context, userID int64, params }, nil } -func (r *stubAPIKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) { +func (r *stubApiKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) { if len(apiKeyIDs) == 0 { return []int64{}, nil } @@ -815,7 +820,7 @@ func (r *stubAPIKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiK return out, nil } -func (r *stubAPIKeyRepo) CountByUserID(ctx context.Context, userID int64) (int64, error) { +func (r *stubApiKeyRepo) CountByUserID(ctx context.Context, userID int64) (int64, error) { var count int64 for _, key := range r.byID { if key.UserID == userID { @@ -825,24 +830,24 @@ func (r *stubAPIKeyRepo) CountByUserID(ctx context.Context, userID int64) (int64 return count, nil } -func (r *stubAPIKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, error) { +func (r *stubApiKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, error) { _, ok := r.byKey[key] return ok, nil } -func (r *stubAPIKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) { +func (r *stubApiKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) { return nil, nil, errors.New("not implemented") } -func (r *stubAPIKeyRepo) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.APIKey, error) { +func (r *stubApiKeyRepo) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) { return nil, errors.New("not implemented") } -func (r *stubAPIKeyRepo) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) { +func (r *stubApiKeyRepo) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) { return 0, errors.New("not implemented") } -func (r *stubAPIKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) { +func (r *stubApiKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) { return 0, errors.New("not implemented") } @@ -877,7 +882,7 @@ func (r *stubUsageLogRepo) ListByUser(ctx context.Context, userID int64, params return out, paginationResult(total, params), nil } -func (r *stubUsageLogRepo) ListByAPIKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) { +func (r *stubUsageLogRepo) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) { return nil, nil, errors.New("not implemented") } @@ -890,7 +895,7 @@ func (r *stubUsageLogRepo) ListByUserAndTimeRange(ctx context.Context, userID in return logs, paginationResult(int64(len(logs)), pagination.PaginationParams{Page: 1, PageSize: 100}), nil } -func (r *stubUsageLogRepo) ListByAPIKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) { +func (r *stubUsageLogRepo) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) { return nil, nil, errors.New("not implemented") } @@ -922,7 +927,7 @@ func (r *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTi return nil, errors.New("not implemented") } -func (r *stubUsageLogRepo) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) { +func (r *stubUsageLogRepo) GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.ApiKeyUsageTrendPoint, error) { return nil, errors.New("not implemented") } @@ -975,7 +980,7 @@ func (r *stubUsageLogRepo) GetUserStatsAggregated(ctx context.Context, userID in }, nil } -func (r *stubUsageLogRepo) GetAPIKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) { +func (r *stubUsageLogRepo) GetApiKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) { return nil, errors.New("not implemented") } @@ -995,7 +1000,7 @@ func (r *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs [ return nil, errors.New("not implemented") } -func (r *stubUsageLogRepo) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) { +func (r *stubUsageLogRepo) GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchApiKeyUsageStats, error) { return nil, errors.New("not implemented") } @@ -1017,8 +1022,8 @@ func (r *stubUsageLogRepo) ListWithFilters(ctx context.Context, params paginatio // Apply filters var filtered []service.UsageLog for _, log := range logs { - // Apply APIKeyID filter - if filters.APIKeyID > 0 && log.APIKeyID != filters.APIKeyID { + // Apply ApiKeyID filter + if filters.ApiKeyID > 0 && log.ApiKeyID != filters.ApiKeyID { continue } // Apply Model filter @@ -1151,8 +1156,8 @@ func paginationResult(total int64, params pagination.PaginationParams) *paginati // Ensure compile-time interface compliance. var ( _ service.UserRepository = (*stubUserRepo)(nil) - _ service.APIKeyRepository = (*stubAPIKeyRepo)(nil) - _ service.APIKeyCache = (*stubAPIKeyCache)(nil) + _ service.ApiKeyRepository = (*stubApiKeyRepo)(nil) + _ service.ApiKeyCache = (*stubApiKeyCache)(nil) _ service.GroupRepository = (*stubGroupRepo)(nil) _ service.UserSubscriptionRepository = (*stubUserSubscriptionRepo)(nil) _ service.UsageLogRepository = (*stubUsageLogRepo)(nil) diff --git a/backend/internal/service/gemini_multiplatform_test.go b/backend/internal/service/gemini_multiplatform_test.go index 5c604f0f..c9b28a7b 100644 --- a/backend/internal/service/gemini_multiplatform_test.go +++ b/backend/internal/service/gemini_multiplatform_test.go @@ -121,6 +121,12 @@ func (m *mockAccountRepoForGemini) SetRateLimited(ctx context.Context, id int64, func (m *mockAccountRepoForGemini) SetOverloaded(ctx context.Context, id int64, until time.Time) error { return nil } +func (m *mockAccountRepoForGemini) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error { + return nil +} +func (m *mockAccountRepoForGemini) ClearTempUnschedulable(ctx context.Context, id int64) error { + return nil +} func (m *mockAccountRepoForGemini) ClearRateLimit(ctx context.Context, id int64) error { return nil } func (m *mockAccountRepoForGemini) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error { return nil @@ -275,7 +281,7 @@ func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_OAuthPr repo := &mockAccountRepoForGemini{ accounts: []Account{ - {ID: 1, Platform: PlatformGemini, Type: AccountTypeAPIKey, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: nil}, + {ID: 1, Platform: PlatformGemini, Type: AccountTypeApiKey, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: nil}, {ID: 2, Platform: PlatformGemini, Type: AccountTypeOAuth, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: nil}, }, accountsByID: map[int64]*Account{}, diff --git a/backend/internal/service/token_refresher_test.go b/backend/internal/service/token_refresher_test.go index c7505037..0a5135ac 100644 --- a/backend/internal/service/token_refresher_test.go +++ b/backend/internal/service/token_refresher_test.go @@ -197,7 +197,7 @@ func TestClaudeTokenRefresher_CanRefresh(t *testing.T) { { name: "anthropic api-key - cannot refresh", platform: PlatformAnthropic, - accType: AccountTypeAPIKey, + accType: AccountTypeApiKey, want: false, }, { diff --git a/backend/migrations/026_ops_metrics_aggregation_tables.sql b/backend/migrations/026_ops_metrics_aggregation_tables.sql new file mode 100644 index 00000000..e0e47265 --- /dev/null +++ b/backend/migrations/026_ops_metrics_aggregation_tables.sql @@ -0,0 +1,104 @@ +-- Ops monitoring: pre-aggregation tables for dashboard queries +-- +-- Problem: +-- The ops dashboard currently runs percentile_cont + GROUP BY queries over large raw tables +-- (usage_logs, ops_error_logs). These will get slower as data grows. +-- +-- This migration adds schema-only aggregation tables that can be populated by a future background job. +-- No triggers/functions/jobs are created here (schema only). + +-- ============================================ +-- Hourly aggregates (per provider/platform) +-- ============================================ + +CREATE TABLE IF NOT EXISTS ops_metrics_hourly ( + -- Start of the hour bucket (recommended: UTC). + bucket_start TIMESTAMPTZ NOT NULL, + + -- Provider/platform label (e.g. anthropic/openai/gemini). Mirrors ops_* queries that GROUP BY platform. + platform VARCHAR(50) NOT NULL, + + -- Traffic counts (use these to compute rates reliably across ranges). + request_count BIGINT NOT NULL DEFAULT 0, + success_count BIGINT NOT NULL DEFAULT 0, + error_count BIGINT NOT NULL DEFAULT 0, + + -- Error breakdown used by provider health UI. + error_4xx_count BIGINT NOT NULL DEFAULT 0, + error_5xx_count BIGINT NOT NULL DEFAULT 0, + timeout_count BIGINT NOT NULL DEFAULT 0, + + -- Latency aggregates (ms). + avg_latency_ms DOUBLE PRECISION, + p99_latency_ms DOUBLE PRECISION, + + -- Convenience rate (percentage, 0-100). Still keep counts as source of truth. + error_rate DOUBLE PRECISION NOT NULL DEFAULT 0, + + -- When this row was last (re)computed by the background job. + computed_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + + PRIMARY KEY (bucket_start, platform) +); + +CREATE INDEX IF NOT EXISTS idx_ops_metrics_hourly_platform_bucket_start + ON ops_metrics_hourly (platform, bucket_start DESC); + +COMMENT ON TABLE ops_metrics_hourly IS 'Pre-aggregated hourly ops metrics by provider/platform to speed up dashboard queries.'; +COMMENT ON COLUMN ops_metrics_hourly.bucket_start IS 'Start timestamp of the hour bucket (recommended UTC).'; +COMMENT ON COLUMN ops_metrics_hourly.platform IS 'Provider/platform label (anthropic/openai/gemini, etc).'; +COMMENT ON COLUMN ops_metrics_hourly.error_rate IS 'Error rate percentage for the bucket (0-100). Counts remain the source of truth.'; +COMMENT ON COLUMN ops_metrics_hourly.computed_at IS 'When the row was last computed/refreshed.'; + +-- ============================================ +-- Daily aggregates (per provider/platform) +-- ============================================ + +CREATE TABLE IF NOT EXISTS ops_metrics_daily ( + -- Day bucket (recommended: UTC date). + bucket_date DATE NOT NULL, + platform VARCHAR(50) NOT NULL, + + request_count BIGINT NOT NULL DEFAULT 0, + success_count BIGINT NOT NULL DEFAULT 0, + error_count BIGINT NOT NULL DEFAULT 0, + + error_4xx_count BIGINT NOT NULL DEFAULT 0, + error_5xx_count BIGINT NOT NULL DEFAULT 0, + timeout_count BIGINT NOT NULL DEFAULT 0, + + avg_latency_ms DOUBLE PRECISION, + p99_latency_ms DOUBLE PRECISION, + + error_rate DOUBLE PRECISION NOT NULL DEFAULT 0, + computed_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + + PRIMARY KEY (bucket_date, platform) +); + +CREATE INDEX IF NOT EXISTS idx_ops_metrics_daily_platform_bucket_date + ON ops_metrics_daily (platform, bucket_date DESC); + +COMMENT ON TABLE ops_metrics_daily IS 'Pre-aggregated daily ops metrics by provider/platform for longer-term trends.'; +COMMENT ON COLUMN ops_metrics_daily.bucket_date IS 'UTC date of the day bucket (recommended).'; + +-- ============================================ +-- Population strategy (future background job) +-- ============================================ +-- +-- Suggested approach: +-- 1) Compute hourly buckets from raw logs using UTC time-bucketing, then UPSERT into ops_metrics_hourly. +-- 2) Compute daily buckets either directly from raw logs or by rolling up ops_metrics_hourly. +-- +-- Notes: +-- - Ensure the job uses a consistent timezone (recommended: SET TIME ZONE ''UTC'') to avoid bucket drift. +-- - Derive the provider/platform similarly to existing dashboard queries: +-- usage_logs: COALESCE(NULLIF(groups.platform, ''), accounts.platform, '') +-- ops_error_logs: COALESCE(NULLIF(ops_error_logs.platform, ''), groups.platform, accounts.platform, '') +-- - Keep request_count/success_count/error_count as the authoritative values; compute error_rate from counts. +-- +-- Example (hourly) shape (pseudo-SQL): +-- INSERT INTO ops_metrics_hourly (...) +-- SELECT date_trunc('hour', created_at) AS bucket_start, platform, ... +-- FROM (/* aggregate usage_logs + ops_error_logs */) s +-- ON CONFLICT (bucket_start, platform) DO UPDATE SET ...; diff --git a/backend/migrations/027_usage_billing_consistency.sql b/backend/migrations/027_usage_billing_consistency.sql new file mode 100644 index 00000000..eba68512 --- /dev/null +++ b/backend/migrations/027_usage_billing_consistency.sql @@ -0,0 +1,58 @@ +-- 027_usage_billing_consistency.sql +-- Ensure usage_logs idempotency (request_id, api_key_id) and add reconciliation infrastructure. + +-- ----------------------------------------------------------------------------- +-- 1) Normalize legacy request_id values +-- ----------------------------------------------------------------------------- +-- Historically request_id may be inserted as empty string. Convert it to NULL so +-- the upcoming unique index does not break on repeated "" values. +UPDATE usage_logs +SET request_id = NULL +WHERE request_id = ''; + +-- If duplicates already exist for the same (request_id, api_key_id), keep the +-- first row and NULL-out request_id for the rest so the unique index can be +-- created without deleting historical logs. +WITH ranked AS ( + SELECT + id, + ROW_NUMBER() OVER (PARTITION BY api_key_id, request_id ORDER BY id) AS rn + FROM usage_logs + WHERE request_id IS NOT NULL +) +UPDATE usage_logs ul +SET request_id = NULL +FROM ranked r +WHERE ul.id = r.id + AND r.rn > 1; + +-- ----------------------------------------------------------------------------- +-- 2) Idempotency constraint for usage_logs +-- ----------------------------------------------------------------------------- +CREATE UNIQUE INDEX IF NOT EXISTS idx_usage_logs_request_id_api_key_unique + ON usage_logs (request_id, api_key_id); + +-- ----------------------------------------------------------------------------- +-- 3) Reconciliation infrastructure: billing ledger for usage charges +-- ----------------------------------------------------------------------------- +CREATE TABLE IF NOT EXISTS billing_usage_entries ( + id BIGSERIAL PRIMARY KEY, + usage_log_id BIGINT NOT NULL REFERENCES usage_logs(id) ON DELETE CASCADE, + user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE, + api_key_id BIGINT NOT NULL REFERENCES api_keys(id) ON DELETE CASCADE, + subscription_id BIGINT REFERENCES user_subscriptions(id) ON DELETE SET NULL, + billing_type SMALLINT NOT NULL, + applied BOOLEAN NOT NULL DEFAULT TRUE, + delta_usd DECIMAL(20, 10) NOT NULL DEFAULT 0, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE UNIQUE INDEX IF NOT EXISTS billing_usage_entries_usage_log_id_unique + ON billing_usage_entries (usage_log_id); + +CREATE INDEX IF NOT EXISTS idx_billing_usage_entries_user_time + ON billing_usage_entries (user_id, created_at); + +CREATE INDEX IF NOT EXISTS idx_billing_usage_entries_created_at + ON billing_usage_entries (created_at); +