fix(test): 修复测试和添加数据库迁移

测试修复:
- 修复集成测试中的重复键冲突问题
- 移除 JSON 中多余的尾随逗号
- 新增 inprocess_transport_test.go
- 更新 haiku 模型映射测试用例

数据库迁移:
- 026: 运营指标聚合表
- 027: 使用量与计费一致性约束
This commit is contained in:
ianshaw
2026-01-03 06:36:35 -08:00
parent ff3f514f6b
commit b1702de522
16 changed files with 495 additions and 244 deletions

View File

@@ -5,28 +5,20 @@ import (
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
"github.com/imroc/req/v3"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type ClaudeOAuthServiceSuite struct {
suite.Suite
srv *httptest.Server
client *claudeOAuthService
}
func (s *ClaudeOAuthServiceSuite) TearDownTest() {
if s.srv != nil {
s.srv.Close()
s.srv = nil
}
}
// requestCapture holds captured request data for assertions in the main goroutine.
type requestCapture struct {
path string
@@ -37,6 +29,12 @@ type requestCapture struct {
contentType string
}
func newTestReqClient(rt http.RoundTripper) *req.Client {
c := req.C()
c.GetClient().Transport = rt
return c
}
func (s *ClaudeOAuthServiceSuite) TestGetOrganizationUUID() {
tests := []struct {
name string
@@ -83,17 +81,17 @@ func (s *ClaudeOAuthServiceSuite) TestGetOrganizationUUID() {
s.Run(tt.name, func() {
var captured requestCapture
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
rt := newInProcessTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
captured.path = r.URL.Path
captured.cookies = r.Cookies()
tt.handler(w, r)
}))
defer s.srv.Close()
}), nil)
client, ok := NewClaudeOAuthClient().(*claudeOAuthService)
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client.baseURL = s.srv.URL
s.client.baseURL = "http://in-process"
s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) }
got, err := s.client.GetOrganizationUUID(context.Background(), "sess", "")
@@ -158,20 +156,20 @@ func (s *ClaudeOAuthServiceSuite) TestGetAuthorizationCode() {
s.Run(tt.name, func() {
var captured requestCapture
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
rt := newInProcessTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
captured.path = r.URL.Path
captured.method = r.Method
captured.cookies = r.Cookies()
captured.body, _ = io.ReadAll(r.Body)
_ = json.Unmarshal(captured.body, &captured.bodyJSON)
tt.handler(w, r)
}))
defer s.srv.Close()
}), nil)
client, ok := NewClaudeOAuthClient().(*claudeOAuthService)
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client.baseURL = s.srv.URL
s.client.baseURL = "http://in-process"
s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) }
code, err := s.client.GetAuthorizationCode(context.Background(), "sess", "org-1", oauth.ScopeProfile, "cc", "st", "")
@@ -266,19 +264,19 @@ func (s *ClaudeOAuthServiceSuite) TestExchangeCodeForToken() {
s.Run(tt.name, func() {
var captured requestCapture
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
rt := newInProcessTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
captured.method = r.Method
captured.contentType = r.Header.Get("Content-Type")
captured.body, _ = io.ReadAll(r.Body)
_ = json.Unmarshal(captured.body, &captured.bodyJSON)
tt.handler(w, r)
}))
defer s.srv.Close()
}), nil)
client, ok := NewClaudeOAuthClient().(*claudeOAuthService)
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client.tokenURL = s.srv.URL
s.client.tokenURL = "http://in-process/token"
s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) }
resp, err := s.client.ExchangeCodeForToken(context.Background(), tt.code, "ver", "", "", tt.isSetupToken)
@@ -362,19 +360,19 @@ func (s *ClaudeOAuthServiceSuite) TestRefreshToken() {
s.Run(tt.name, func() {
var captured requestCapture
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
rt := newInProcessTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
captured.method = r.Method
captured.contentType = r.Header.Get("Content-Type")
captured.body, _ = io.ReadAll(r.Body)
_ = json.Unmarshal(captured.body, &captured.bodyJSON)
tt.handler(w, r)
}))
defer s.srv.Close()
}), nil)
client, ok := NewClaudeOAuthClient().(*claudeOAuthService)
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client.tokenURL = s.srv.URL
s.client.tokenURL = "http://in-process/token"
s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) }
resp, err := s.client.RefreshToken(context.Background(), "rt", "")

View File

@@ -33,7 +33,7 @@ type usageRequestCapture struct {
func (s *ClaudeUsageServiceSuite) TestFetchUsage_Success() {
var captured usageRequestCapture
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
captured.authorization = r.Header.Get("Authorization")
captured.anthropicBeta = r.Header.Get("anthropic-beta")
@@ -59,7 +59,7 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_Success() {
}
func (s *ClaudeUsageServiceSuite) TestFetchUsage_NonOK() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
_, _ = io.WriteString(w, "nope")
}))
@@ -73,7 +73,7 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_NonOK() {
}
func (s *ClaudeUsageServiceSuite) TestFetchUsage_BadJSON() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, "not-json")
}))
@@ -86,7 +86,7 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_BadJSON() {
}
func (s *ClaudeUsageServiceSuite) TestFetchUsage_ContextCancel() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Never respond - simulate slow server
<-r.Context().Done()
}))

View File

@@ -49,7 +49,7 @@ func (s *GitHubReleaseServiceSuite) TearDownTest() {
}
func (s *GitHubReleaseServiceSuite) TestDownloadFile_EnforcesMaxSize_ContentLength() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Length", "100")
w.WriteHeader(http.StatusOK)
_, _ = w.Write(bytes.Repeat([]byte("a"), 100))
@@ -68,7 +68,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_EnforcesMaxSize_ContentLeng
}
func (s *GitHubReleaseServiceSuite) TestDownloadFile_EnforcesMaxSize_Chunked() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Force chunked encoding (unknown Content-Length) by flushing headers before writing.
w.WriteHeader(http.StatusOK)
if fl, ok := w.(http.Flusher); ok {
@@ -95,7 +95,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_EnforcesMaxSize_Chunked() {
}
func (s *GitHubReleaseServiceSuite) TestDownloadFile_Success() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
if fl, ok := w.(http.Flusher); ok {
fl.Flush()
@@ -123,7 +123,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_Success() {
}
func (s *GitHubReleaseServiceSuite) TestDownloadFile_404() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
}))
@@ -140,7 +140,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_404() {
}
func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_Success() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("sum"))
}))
@@ -155,7 +155,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_Success() {
}
func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_Non200() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
}))
@@ -168,7 +168,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_Non200() {
}
func (s *GitHubReleaseServiceSuite) TestDownloadFile_ContextCancel() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
<-r.Context().Done()
}))
@@ -195,7 +195,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_InvalidURL() {
}
func (s *GitHubReleaseServiceSuite) TestDownloadFile_InvalidDestPath() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("content"))
}))
@@ -233,7 +233,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Success() {
]
}`
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(s.T(), "/repos/test/repo/releases/latest", r.URL.Path)
require.Equal(s.T(), "application/vnd.github.v3+json", r.Header.Get("Accept"))
require.Equal(s.T(), "Sub2API-Updater", r.Header.Get("User-Agent"))
@@ -258,7 +258,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Success() {
}
func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Non200() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
}))
@@ -274,7 +274,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Non200() {
}
func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_InvalidJSON() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("not valid json"))
}))
@@ -290,7 +290,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_InvalidJSON() {
}
func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_ContextCancel() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
<-r.Context().Done()
}))
@@ -308,7 +308,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_ContextCancel() {
}
func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_ContextCancel() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
<-r.Context().Done()
}))

View File

@@ -3,7 +3,6 @@ package repository
import (
"io"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"time"
@@ -93,7 +92,7 @@ func (s *HTTPUpstreamSuite) TestAcquireClient_OverLimitReturnsError() {
// 验证空代理 URL 时请求直接发送到目标服务器
func (s *HTTPUpstreamSuite) TestDo_WithoutProxy_GoesDirect() {
// 创建模拟上游服务器
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
upstream := newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.WriteString(w, "direct")
}))
s.T().Cleanup(upstream.Close)
@@ -115,7 +114,7 @@ func (s *HTTPUpstreamSuite) TestDo_WithHTTPProxy_UsesProxy() {
// 用于接收代理请求的通道
seen := make(chan string, 1)
// 创建模拟代理服务器
proxySrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
proxySrv := newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
seen <- r.RequestURI // 记录请求 URI
_, _ = io.WriteString(w, "proxied")
}))
@@ -145,7 +144,7 @@ func (s *HTTPUpstreamSuite) TestDo_WithHTTPProxy_UsesProxy() {
// TestDo_EmptyProxy_UsesDirect 测试空代理字符串
// 验证空字符串代理等同于直连
func (s *HTTPUpstreamSuite) TestDo_EmptyProxy_UsesDirect() {
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
upstream := newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.WriteString(w, "direct-empty")
}))
s.T().Cleanup(upstream.Close)

View File

@@ -0,0 +1,63 @@
package repository
import (
"bytes"
"io"
"net"
"net/http"
"net/http/httptest"
"sync"
"testing"
)
type roundTripFunc func(*http.Request) (*http.Response, error)
func (f roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) { return f(r) }
// newInProcessTransport adapts an http.HandlerFunc into an http.RoundTripper without opening sockets.
// It captures the request body (if any) and then rewinds it before invoking the handler.
func newInProcessTransport(handler http.HandlerFunc, capture func(r *http.Request, body []byte)) http.RoundTripper {
return roundTripFunc(func(r *http.Request) (*http.Response, error) {
var body []byte
if r.Body != nil {
body, _ = io.ReadAll(r.Body)
_ = r.Body.Close()
r.Body = io.NopCloser(bytes.NewReader(body))
}
if capture != nil {
capture(r, body)
}
rec := httptest.NewRecorder()
handler(rec, r)
return rec.Result(), nil
})
}
var (
canListenOnce sync.Once
canListen bool
canListenErr error
)
func localListenerAvailable() bool {
canListenOnce.Do(func() {
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
canListenErr = err
canListen = false
return
}
_ = ln.Close()
canListen = true
})
return canListen
}
func newLocalTestServer(tb testing.TB, handler http.Handler) *httptest.Server {
tb.Helper()
if !localListenerAvailable() {
tb.Skipf("local listeners are not permitted in this environment: %v", canListenErr)
}
return httptest.NewServer(handler)
}

View File

@@ -34,7 +34,7 @@ func (s *OpenAIOAuthServiceSuite) TearDownTest() {
}
func (s *OpenAIOAuthServiceSuite) setupServer(handler http.HandlerFunc) {
s.srv = httptest.NewServer(handler)
s.srv = newLocalTestServer(s.T(), handler)
s.svc = &openaiOAuthService{tokenURL: s.srv.URL}
}

View File

@@ -32,7 +32,7 @@ func (s *PricingServiceSuite) TearDownTest() {
}
func (s *PricingServiceSuite) setupServer(handler http.HandlerFunc) {
s.srv = httptest.NewServer(handler)
s.srv = newLocalTestServer(s.T(), handler)
}
func (s *PricingServiceSuite) TestFetchPricingJSON_Success() {

View File

@@ -31,7 +31,7 @@ func (s *ProxyProbeServiceSuite) TearDownTest() {
}
func (s *ProxyProbeServiceSuite) setupProxyServer(handler http.HandlerFunc) {
s.proxySrv = httptest.NewServer(handler)
s.proxySrv = newLocalTestServer(s.T(), handler)
}
func (s *ProxyProbeServiceSuite) TestProbeProxy_InvalidProxyURL() {

View File

@@ -3,9 +3,9 @@ package repository
import (
"context"
"encoding/json"
"errors"
"io"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
@@ -18,7 +18,6 @@ import (
type TurnstileServiceSuite struct {
suite.Suite
ctx context.Context
srv *httptest.Server
verifier *turnstileVerifier
received chan url.Values
}
@@ -31,20 +30,15 @@ func (s *TurnstileServiceSuite) SetupTest() {
s.verifier = verifier
}
func (s *TurnstileServiceSuite) TearDownTest() {
if s.srv != nil {
s.srv.Close()
s.srv = nil
func (s *TurnstileServiceSuite) setupTransport(handler http.HandlerFunc) {
s.verifier.verifyURL = "http://in-process/turnstile"
s.verifier.httpClient = &http.Client{
Transport: newInProcessTransport(handler, nil),
}
}
func (s *TurnstileServiceSuite) setupServer(handler http.HandlerFunc) {
s.srv = httptest.NewServer(handler)
s.verifier.verifyURL = s.srv.URL
}
func (s *TurnstileServiceSuite) TestVerifyToken_SendsFormAndDecodesJSON() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
s.setupTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Capture form data in main goroutine context later
body, _ := io.ReadAll(r.Body)
values, _ := url.ParseQuery(string(body))
@@ -72,7 +66,7 @@ func (s *TurnstileServiceSuite) TestVerifyToken_SendsFormAndDecodesJSON() {
func (s *TurnstileServiceSuite) TestVerifyToken_ContentType() {
var contentType string
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
s.setupTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
contentType = r.Header.Get("Content-Type")
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(service.TurnstileVerifyResponse{Success: true})
@@ -84,7 +78,7 @@ func (s *TurnstileServiceSuite) TestVerifyToken_ContentType() {
}
func (s *TurnstileServiceSuite) TestVerifyToken_EmptyRemoteIP_NotSent() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
s.setupTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
values, _ := url.ParseQuery(string(body))
s.received <- values
@@ -105,15 +99,19 @@ func (s *TurnstileServiceSuite) TestVerifyToken_EmptyRemoteIP_NotSent() {
}
func (s *TurnstileServiceSuite) TestVerifyToken_RequestError() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
s.srv.Close()
s.verifier.verifyURL = "http://in-process/turnstile"
s.verifier.httpClient = &http.Client{
Transport: roundTripFunc(func(*http.Request) (*http.Response, error) {
return nil, errors.New("dial failed")
}),
}
_, err := s.verifier.VerifyToken(s.ctx, "sk", "token", "1.1.1.1")
require.Error(s.T(), err, "expected error when server is closed")
}
func (s *TurnstileServiceSuite) TestVerifyToken_InvalidJSON() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
s.setupTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, "not-valid-json")
}))
@@ -123,7 +121,7 @@ func (s *TurnstileServiceSuite) TestVerifyToken_InvalidJSON() {
}
func (s *TurnstileServiceSuite) TestVerifyToken_SuccessFalse() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
s.setupTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(service.TurnstileVerifyResponse{
Success: false,

View File

@@ -3,6 +3,7 @@ package repository
import (
"context"
"database/sql"
"errors"
"fmt"
"os"
"strings"
@@ -70,6 +71,9 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
createdAt = time.Now()
}
requestID := strings.TrimSpace(log.RequestID)
log.RequestID = requestID
rateMultiplier := log.RateMultiplier
query := `
@@ -107,6 +111,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
$14, $15, $16, $17, $18, $19,
$20, $21, $22, $23, $24, $25
)
ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id, created_at
`
@@ -115,11 +120,16 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
duration := nullInt(log.DurationMs)
firstToken := nullInt(log.FirstTokenMs)
var requestIDArg any
if requestID != "" {
requestIDArg = requestID
}
args := []any{
log.UserID,
log.APIKeyID,
log.ApiKeyID,
log.AccountID,
log.RequestID,
requestIDArg,
log.Model,
groupID,
subscriptionID,
@@ -143,7 +153,14 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
createdAt,
}
if err := scanSingleRow(ctx, r.sql, query, args, &log.ID, &log.CreatedAt); err != nil {
return err
if errors.Is(err, sql.ErrNoRows) && requestID != "" {
selectQuery := "SELECT id, created_at FROM usage_logs WHERE request_id = $1 AND api_key_id = $2"
if err := scanSingleRow(ctx, r.sql, selectQuery, []any{requestID, log.ApiKeyID}, &log.ID, &log.CreatedAt); err != nil {
return err
}
} else {
return err
}
}
log.RateMultiplier = rateMultiplier
return nil
@@ -183,7 +200,7 @@ func (r *usageLogRepository) ListByUser(ctx context.Context, userID int64, param
return r.listUsageLogsWithPagination(ctx, "WHERE user_id = $1", []any{userID}, params)
}
func (r *usageLogRepository) ListByAPIKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
func (r *usageLogRepository) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
return r.listUsageLogsWithPagination(ctx, "WHERE api_key_id = $1", []any{apiKeyID}, params)
}
@@ -270,8 +287,8 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
r.sql,
apiKeyStatsQuery,
[]any{service.StatusActive},
&stats.TotalAPIKeys,
&stats.ActiveAPIKeys,
&stats.TotalApiKeys,
&stats.ActiveApiKeys,
); err != nil {
return nil, err
}
@@ -418,8 +435,8 @@ func (r *usageLogRepository) GetUserStatsAggregated(ctx context.Context, userID
return &stats, nil
}
// GetAPIKeyStatsAggregated returns aggregated usage statistics for an API key using database-level aggregation
func (r *usageLogRepository) GetAPIKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
// GetApiKeyStatsAggregated returns aggregated usage statistics for an API key using database-level aggregation
func (r *usageLogRepository) GetApiKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
query := `
SELECT
COUNT(*) as total_requests,
@@ -623,7 +640,7 @@ func resolveUsageStatsTimezone() string {
return "UTC"
}
func (r *usageLogRepository) ListByAPIKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
func (r *usageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE api_key_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC"
logs, err := r.queryUsageLogs(ctx, query, apiKeyID, startTime, endTime)
return logs, nil, err
@@ -709,11 +726,11 @@ type ModelStat = usagestats.ModelStat
// UserUsageTrendPoint represents user usage trend data point
type UserUsageTrendPoint = usagestats.UserUsageTrendPoint
// APIKeyUsageTrendPoint represents API key usage trend data point
type APIKeyUsageTrendPoint = usagestats.APIKeyUsageTrendPoint
// ApiKeyUsageTrendPoint represents API key usage trend data point
type ApiKeyUsageTrendPoint = usagestats.ApiKeyUsageTrendPoint
// GetAPIKeyUsageTrend returns usage trend data grouped by API key and date
func (r *usageLogRepository) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) (results []APIKeyUsageTrendPoint, err error) {
// GetApiKeyUsageTrend returns usage trend data grouped by API key and date
func (r *usageLogRepository) GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) (results []ApiKeyUsageTrendPoint, err error) {
dateFormat := "YYYY-MM-DD"
if granularity == "hour" {
dateFormat = "YYYY-MM-DD HH24:00"
@@ -755,10 +772,10 @@ func (r *usageLogRepository) GetAPIKeyUsageTrend(ctx context.Context, startTime,
}
}()
results = make([]APIKeyUsageTrendPoint, 0)
results = make([]ApiKeyUsageTrendPoint, 0)
for rows.Next() {
var row APIKeyUsageTrendPoint
if err = rows.Scan(&row.Date, &row.APIKeyID, &row.KeyName, &row.Requests, &row.Tokens); err != nil {
var row ApiKeyUsageTrendPoint
if err = rows.Scan(&row.Date, &row.ApiKeyID, &row.KeyName, &row.Requests, &row.Tokens); err != nil {
return nil, err
}
results = append(results, row)
@@ -844,7 +861,7 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
r.sql,
"SELECT COUNT(*) FROM api_keys WHERE user_id = $1 AND deleted_at IS NULL",
[]any{userID},
&stats.TotalAPIKeys,
&stats.TotalApiKeys,
); err != nil {
return nil, err
}
@@ -853,7 +870,7 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
r.sql,
"SELECT COUNT(*) FROM api_keys WHERE user_id = $1 AND status = $2 AND deleted_at IS NULL",
[]any{userID, service.StatusActive},
&stats.ActiveAPIKeys,
&stats.ActiveApiKeys,
); err != nil {
return nil, err
}
@@ -1023,9 +1040,9 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat
conditions = append(conditions, fmt.Sprintf("user_id = $%d", len(args)+1))
args = append(args, filters.UserID)
}
if filters.APIKeyID > 0 {
if filters.ApiKeyID > 0 {
conditions = append(conditions, fmt.Sprintf("api_key_id = $%d", len(args)+1))
args = append(args, filters.APIKeyID)
args = append(args, filters.ApiKeyID)
}
if filters.AccountID > 0 {
conditions = append(conditions, fmt.Sprintf("account_id = $%d", len(args)+1))
@@ -1145,18 +1162,18 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs
return result, nil
}
// BatchAPIKeyUsageStats represents usage stats for a single API key
type BatchAPIKeyUsageStats = usagestats.BatchAPIKeyUsageStats
// BatchApiKeyUsageStats represents usage stats for a single API key
type BatchApiKeyUsageStats = usagestats.BatchApiKeyUsageStats
// GetBatchAPIKeyUsageStats gets today and total actual_cost for multiple API keys
func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*BatchAPIKeyUsageStats, error) {
result := make(map[int64]*BatchAPIKeyUsageStats)
// GetBatchApiKeyUsageStats gets today and total actual_cost for multiple API keys
func (r *usageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*BatchApiKeyUsageStats, error) {
result := make(map[int64]*BatchApiKeyUsageStats)
if len(apiKeyIDs) == 0 {
return result, nil
}
for _, id := range apiKeyIDs {
result[id] = &BatchAPIKeyUsageStats{APIKeyID: id}
result[id] = &BatchApiKeyUsageStats{ApiKeyID: id}
}
query := `
@@ -1582,7 +1599,7 @@ func (r *usageLogRepository) hydrateUsageLogAssociations(ctx context.Context, lo
if err != nil {
return err
}
apiKeys, err := r.loadAPIKeys(ctx, ids.apiKeyIDs)
apiKeys, err := r.loadApiKeys(ctx, ids.apiKeyIDs)
if err != nil {
return err
}
@@ -1603,8 +1620,8 @@ func (r *usageLogRepository) hydrateUsageLogAssociations(ctx context.Context, lo
if user, ok := users[logs[i].UserID]; ok {
logs[i].User = user
}
if key, ok := apiKeys[logs[i].APIKeyID]; ok {
logs[i].APIKey = key
if key, ok := apiKeys[logs[i].ApiKeyID]; ok {
logs[i].ApiKey = key
}
if acc, ok := accounts[logs[i].AccountID]; ok {
logs[i].Account = acc
@@ -1642,7 +1659,7 @@ func collectUsageLogIDs(logs []service.UsageLog) usageLogIDs {
for i := range logs {
userIDs[logs[i].UserID] = struct{}{}
apiKeyIDs[logs[i].APIKeyID] = struct{}{}
apiKeyIDs[logs[i].ApiKeyID] = struct{}{}
accountIDs[logs[i].AccountID] = struct{}{}
if logs[i].GroupID != nil {
groupIDs[*logs[i].GroupID] = struct{}{}
@@ -1676,12 +1693,12 @@ func (r *usageLogRepository) loadUsers(ctx context.Context, ids []int64) (map[in
return out, nil
}
func (r *usageLogRepository) loadAPIKeys(ctx context.Context, ids []int64) (map[int64]*service.APIKey, error) {
out := make(map[int64]*service.APIKey)
func (r *usageLogRepository) loadApiKeys(ctx context.Context, ids []int64) (map[int64]*service.ApiKey, error) {
out := make(map[int64]*service.ApiKey)
if len(ids) == 0 {
return out, nil
}
models, err := r.client.APIKey.Query().Where(dbapikey.IDIn(ids...)).All(ctx)
models, err := r.client.ApiKey.Query().Where(dbapikey.IDIn(ids...)).All(ctx)
if err != nil {
return nil, err
}
@@ -1800,7 +1817,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
log := &service.UsageLog{
ID: id,
UserID: userID,
APIKeyID: apiKeyID,
ApiKeyID: apiKeyID,
AccountID: accountID,
Model: model,
InputTokens: inputTokens,

View File

@@ -7,6 +7,8 @@ import (
"testing"
"time"
"github.com/google/uuid"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
@@ -35,11 +37,12 @@ func TestUsageLogRepoSuite(t *testing.T) {
suite.Run(t, new(UsageLogRepoSuite))
}
func (s *UsageLogRepoSuite) createUsageLog(user *service.User, apiKey *service.APIKey, account *service.Account, inputTokens, outputTokens int, cost float64, createdAt time.Time) *service.UsageLog {
func (s *UsageLogRepoSuite) createUsageLog(user *service.User, apiKey *service.ApiKey, account *service.Account, inputTokens, outputTokens int, cost float64, createdAt time.Time) *service.UsageLog {
log := &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
ApiKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: uuid.New().String(), // Generate unique RequestID for each log
Model: "claude-3",
InputTokens: inputTokens,
OutputTokens: outputTokens,
@@ -55,12 +58,12 @@ func (s *UsageLogRepoSuite) createUsageLog(user *service.User, apiKey *service.A
func (s *UsageLogRepoSuite) TestCreate() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "create@test.com"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-create", Name: "k"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-create", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-create"})
log := &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
ApiKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3",
InputTokens: 10,
@@ -76,7 +79,7 @@ func (s *UsageLogRepoSuite) TestCreate() {
func (s *UsageLogRepoSuite) TestGetByID() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "getbyid@test.com"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-getbyid", Name: "k"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-getbyid", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-getbyid"})
log := s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
@@ -96,7 +99,7 @@ func (s *UsageLogRepoSuite) TestGetByID_NotFound() {
func (s *UsageLogRepoSuite) TestDelete() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "delete@test.com"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-delete", Name: "k"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-delete", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-delete"})
log := s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
@@ -112,7 +115,7 @@ func (s *UsageLogRepoSuite) TestDelete() {
func (s *UsageLogRepoSuite) TestListByUser() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "listbyuser@test.com"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-listbyuser", Name: "k"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-listbyuser", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-listbyuser"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
@@ -124,18 +127,18 @@ func (s *UsageLogRepoSuite) TestListByUser() {
s.Require().Equal(int64(2), page.Total)
}
// --- ListByAPIKey ---
// --- ListByApiKey ---
func (s *UsageLogRepoSuite) TestListByAPIKey() {
func (s *UsageLogRepoSuite) TestListByApiKey() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "listbyapikey@test.com"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-listbyapikey", Name: "k"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-listbyapikey", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-listbyapikey"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, time.Now())
logs, page, err := s.repo.ListByAPIKey(s.ctx, apiKey.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "ListByAPIKey")
logs, page, err := s.repo.ListByApiKey(s.ctx, apiKey.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "ListByApiKey")
s.Require().Len(logs, 2)
s.Require().Equal(int64(2), page.Total)
}
@@ -144,7 +147,7 @@ func (s *UsageLogRepoSuite) TestListByAPIKey() {
func (s *UsageLogRepoSuite) TestListByAccount() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "listbyaccount@test.com"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-listbyaccount", Name: "k"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-listbyaccount", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-listbyaccount"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
@@ -159,7 +162,7 @@ func (s *UsageLogRepoSuite) TestListByAccount() {
func (s *UsageLogRepoSuite) TestGetUserStats() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "userstats@test.com"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-userstats", Name: "k"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-userstats", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-userstats"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -179,7 +182,7 @@ func (s *UsageLogRepoSuite) TestGetUserStats() {
func (s *UsageLogRepoSuite) TestListWithFilters() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "filters@test.com"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-filters", Name: "k"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-filters", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-filters"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
@@ -211,8 +214,8 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
})
group := mustCreateGroup(s.T(), s.client, &service.Group{Name: "g-ul"})
apiKey1 := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: userToday.ID, Key: "sk-ul-1", Name: "ul1"})
mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: userOld.ID, Key: "sk-ul-2", Name: "ul2", Status: service.StatusDisabled})
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: userToday.ID, Key: "sk-ul-1", Name: "ul1"})
mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: userOld.ID, Key: "sk-ul-2", Name: "ul2", Status: service.StatusDisabled})
resetAt := now.Add(10 * time.Minute)
accNormal := mustCreateAccount(s.T(), s.client, &service.Account{Name: "a-normal", Schedulable: true})
@@ -223,7 +226,7 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
d1, d2, d3 := 100, 200, 300
logToday := &service.UsageLog{
UserID: userToday.ID,
APIKeyID: apiKey1.ID,
ApiKeyID: apiKey1.ID,
AccountID: accNormal.ID,
Model: "claude-3",
GroupID: &group.ID,
@@ -240,7 +243,7 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
logOld := &service.UsageLog{
UserID: userOld.ID,
APIKeyID: apiKey1.ID,
ApiKeyID: apiKey1.ID,
AccountID: accNormal.ID,
Model: "claude-3",
InputTokens: 5,
@@ -254,7 +257,7 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
logPerf := &service.UsageLog{
UserID: userToday.ID,
APIKeyID: apiKey1.ID,
ApiKeyID: apiKey1.ID,
AccountID: accNormal.ID,
Model: "claude-3",
InputTokens: 1,
@@ -272,8 +275,8 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
s.Require().Equal(baseStats.TotalUsers+2, stats.TotalUsers, "TotalUsers mismatch")
s.Require().Equal(baseStats.TodayNewUsers+1, stats.TodayNewUsers, "TodayNewUsers mismatch")
s.Require().Equal(baseStats.ActiveUsers+1, stats.ActiveUsers, "ActiveUsers mismatch")
s.Require().Equal(baseStats.TotalAPIKeys+2, stats.TotalAPIKeys, "TotalAPIKeys mismatch")
s.Require().Equal(baseStats.ActiveAPIKeys+1, stats.ActiveAPIKeys, "ActiveAPIKeys mismatch")
s.Require().Equal(baseStats.TotalApiKeys+2, stats.TotalApiKeys, "TotalApiKeys mismatch")
s.Require().Equal(baseStats.ActiveApiKeys+1, stats.ActiveApiKeys, "ActiveApiKeys mismatch")
s.Require().Equal(baseStats.TotalAccounts+4, stats.TotalAccounts, "TotalAccounts mismatch")
s.Require().Equal(baseStats.ErrorAccounts+1, stats.ErrorAccounts, "ErrorAccounts mismatch")
s.Require().Equal(baseStats.RateLimitAccounts+1, stats.RateLimitAccounts, "RateLimitAccounts mismatch")
@@ -300,14 +303,14 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
func (s *UsageLogRepoSuite) TestGetUserDashboardStats() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "userdash@test.com"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-userdash", Name: "k"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-userdash", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-userdash"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
stats, err := s.repo.GetUserDashboardStats(s.ctx, user.ID)
s.Require().NoError(err, "GetUserDashboardStats")
s.Require().Equal(int64(1), stats.TotalAPIKeys)
s.Require().Equal(int64(1), stats.TotalApiKeys)
s.Require().Equal(int64(1), stats.TotalRequests)
}
@@ -315,7 +318,7 @@ func (s *UsageLogRepoSuite) TestGetUserDashboardStats() {
func (s *UsageLogRepoSuite) TestGetAccountTodayStats() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "acctoday@test.com"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-acctoday", Name: "k"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-acctoday", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-today"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
@@ -331,8 +334,8 @@ func (s *UsageLogRepoSuite) TestGetAccountTodayStats() {
func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats() {
user1 := mustCreateUser(s.T(), s.client, &service.User{Email: "batch1@test.com"})
user2 := mustCreateUser(s.T(), s.client, &service.User{Email: "batch2@test.com"})
apiKey1 := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user1.ID, Key: "sk-batch1", Name: "k"})
apiKey2 := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user2.ID, Key: "sk-batch2", Name: "k"})
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user1.ID, Key: "sk-batch1", Name: "k"})
apiKey2 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user2.ID, Key: "sk-batch2", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-batch"})
s.createUsageLog(user1, apiKey1, account, 10, 20, 0.5, time.Now())
@@ -351,24 +354,24 @@ func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats_Empty() {
s.Require().Empty(stats)
}
// --- GetBatchAPIKeyUsageStats ---
// --- GetBatchApiKeyUsageStats ---
func (s *UsageLogRepoSuite) TestGetBatchAPIKeyUsageStats() {
func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "batchkey@test.com"})
apiKey1 := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-batchkey1", Name: "k1"})
apiKey2 := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-batchkey2", Name: "k2"})
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-batchkey1", Name: "k1"})
apiKey2 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-batchkey2", Name: "k2"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-batchkey"})
s.createUsageLog(user, apiKey1, account, 10, 20, 0.5, time.Now())
s.createUsageLog(user, apiKey2, account, 15, 25, 0.6, time.Now())
stats, err := s.repo.GetBatchAPIKeyUsageStats(s.ctx, []int64{apiKey1.ID, apiKey2.ID})
s.Require().NoError(err, "GetBatchAPIKeyUsageStats")
stats, err := s.repo.GetBatchApiKeyUsageStats(s.ctx, []int64{apiKey1.ID, apiKey2.ID})
s.Require().NoError(err, "GetBatchApiKeyUsageStats")
s.Require().Len(stats, 2)
}
func (s *UsageLogRepoSuite) TestGetBatchAPIKeyUsageStats_Empty() {
stats, err := s.repo.GetBatchAPIKeyUsageStats(s.ctx, []int64{})
func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats_Empty() {
stats, err := s.repo.GetBatchApiKeyUsageStats(s.ctx, []int64{})
s.Require().NoError(err)
s.Require().Empty(stats)
}
@@ -377,7 +380,7 @@ func (s *UsageLogRepoSuite) TestGetBatchAPIKeyUsageStats_Empty() {
func (s *UsageLogRepoSuite) TestGetGlobalStats() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "global@test.com"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-global", Name: "k"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-global", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-global"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -402,7 +405,7 @@ func maxTime(a, b time.Time) time.Time {
func (s *UsageLogRepoSuite) TestListByUserAndTimeRange() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "timerange@test.com"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-timerange", Name: "k"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-timerange", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-timerange"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -417,11 +420,11 @@ func (s *UsageLogRepoSuite) TestListByUserAndTimeRange() {
s.Require().Len(logs, 2)
}
// --- ListByAPIKeyAndTimeRange ---
// --- ListByApiKeyAndTimeRange ---
func (s *UsageLogRepoSuite) TestListByAPIKeyAndTimeRange() {
func (s *UsageLogRepoSuite) TestListByApiKeyAndTimeRange() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "keytimerange@test.com"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-keytimerange", Name: "k"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-keytimerange", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-keytimerange"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -431,8 +434,8 @@ func (s *UsageLogRepoSuite) TestListByAPIKeyAndTimeRange() {
startTime := base.Add(-1 * time.Hour)
endTime := base.Add(2 * time.Hour)
logs, _, err := s.repo.ListByAPIKeyAndTimeRange(s.ctx, apiKey.ID, startTime, endTime)
s.Require().NoError(err, "ListByAPIKeyAndTimeRange")
logs, _, err := s.repo.ListByApiKeyAndTimeRange(s.ctx, apiKey.ID, startTime, endTime)
s.Require().NoError(err, "ListByApiKeyAndTimeRange")
s.Require().Len(logs, 2)
}
@@ -440,7 +443,7 @@ func (s *UsageLogRepoSuite) TestListByAPIKeyAndTimeRange() {
func (s *UsageLogRepoSuite) TestListByAccountAndTimeRange() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "acctimerange@test.com"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-acctimerange", Name: "k"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-acctimerange", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-acctimerange"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -459,7 +462,7 @@ func (s *UsageLogRepoSuite) TestListByAccountAndTimeRange() {
func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "modeltimerange@test.com"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-modeltimerange", Name: "k"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-modeltimerange", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-modeltimerange"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -467,7 +470,7 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
// Create logs with different models
log1 := &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
ApiKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3-opus",
InputTokens: 10,
@@ -480,7 +483,7 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
log2 := &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
ApiKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3-opus",
InputTokens: 15,
@@ -493,7 +496,7 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
log3 := &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
ApiKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3-sonnet",
InputTokens: 20,
@@ -515,7 +518,7 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
func (s *UsageLogRepoSuite) TestGetAccountWindowStats() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "windowstats@test.com"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-windowstats", Name: "k"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-windowstats", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-windowstats"})
now := time.Now()
@@ -535,7 +538,7 @@ func (s *UsageLogRepoSuite) TestGetAccountWindowStats() {
func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "usertrend@test.com"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-usertrend", Name: "k"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-usertrend", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-usertrend"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -552,7 +555,7 @@ func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID() {
func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID_HourlyGranularity() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "usertrendhourly@test.com"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-usertrendhourly", Name: "k"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-usertrendhourly", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-usertrendhourly"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -571,7 +574,7 @@ func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID_HourlyGranularity() {
func (s *UsageLogRepoSuite) TestGetUserModelStats() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "modelstats@test.com"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-modelstats", Name: "k"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-modelstats", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-modelstats"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -579,7 +582,7 @@ func (s *UsageLogRepoSuite) TestGetUserModelStats() {
// Create logs with different models
log1 := &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
ApiKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3-opus",
InputTokens: 100,
@@ -592,7 +595,7 @@ func (s *UsageLogRepoSuite) TestGetUserModelStats() {
log2 := &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
ApiKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3-sonnet",
InputTokens: 50,
@@ -618,7 +621,7 @@ func (s *UsageLogRepoSuite) TestGetUserModelStats() {
func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "trendfilters@test.com"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-trendfilters", Name: "k"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-trendfilters", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-trendfilters"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -646,7 +649,7 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() {
func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "trendfilters-h@test.com"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-trendfilters-h", Name: "k"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-trendfilters-h", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-trendfilters-h"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -665,14 +668,14 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() {
func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "modelfilters@test.com"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-modelfilters", Name: "k"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-modelfilters", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-modelfilters"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
log1 := &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
ApiKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3-opus",
InputTokens: 100,
@@ -685,7 +688,7 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
log2 := &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
ApiKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3-sonnet",
InputTokens: 50,
@@ -719,7 +722,7 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
func (s *UsageLogRepoSuite) TestGetAccountUsageStats() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "accstats@test.com"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-accstats", Name: "k"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-accstats", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-accstats"})
base := time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC)
@@ -727,7 +730,7 @@ func (s *UsageLogRepoSuite) TestGetAccountUsageStats() {
// Create logs on different days
log1 := &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
ApiKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3-opus",
InputTokens: 100,
@@ -740,7 +743,7 @@ func (s *UsageLogRepoSuite) TestGetAccountUsageStats() {
log2 := &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
ApiKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3-sonnet",
InputTokens: 50,
@@ -782,8 +785,8 @@ func (s *UsageLogRepoSuite) TestGetAccountUsageStats_EmptyRange() {
func (s *UsageLogRepoSuite) TestGetUserUsageTrend() {
user1 := mustCreateUser(s.T(), s.client, &service.User{Email: "usertrend1@test.com"})
user2 := mustCreateUser(s.T(), s.client, &service.User{Email: "usertrend2@test.com"})
apiKey1 := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user1.ID, Key: "sk-usertrend1", Name: "k1"})
apiKey2 := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user2.ID, Key: "sk-usertrend2", Name: "k2"})
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user1.ID, Key: "sk-usertrend1", Name: "k1"})
apiKey2 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user2.ID, Key: "sk-usertrend2", Name: "k2"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-usertrends"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -799,12 +802,12 @@ func (s *UsageLogRepoSuite) TestGetUserUsageTrend() {
s.Require().GreaterOrEqual(len(trend), 2)
}
// --- GetAPIKeyUsageTrend ---
// --- GetApiKeyUsageTrend ---
func (s *UsageLogRepoSuite) TestGetAPIKeyUsageTrend() {
func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "keytrend@test.com"})
apiKey1 := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-keytrend1", Name: "k1"})
apiKey2 := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-keytrend2", Name: "k2"})
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-keytrend1", Name: "k1"})
apiKey2 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-keytrend2", Name: "k2"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-keytrends"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -815,14 +818,14 @@ func (s *UsageLogRepoSuite) TestGetAPIKeyUsageTrend() {
startTime := base.Add(-1 * time.Hour)
endTime := base.Add(48 * time.Hour)
trend, err := s.repo.GetAPIKeyUsageTrend(s.ctx, startTime, endTime, "day", 10)
s.Require().NoError(err, "GetAPIKeyUsageTrend")
trend, err := s.repo.GetApiKeyUsageTrend(s.ctx, startTime, endTime, "day", 10)
s.Require().NoError(err, "GetApiKeyUsageTrend")
s.Require().GreaterOrEqual(len(trend), 2)
}
func (s *UsageLogRepoSuite) TestGetAPIKeyUsageTrend_HourlyGranularity() {
func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend_HourlyGranularity() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "keytrendh@test.com"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-keytrendh", Name: "k"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-keytrendh", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-keytrendh"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -832,21 +835,21 @@ func (s *UsageLogRepoSuite) TestGetAPIKeyUsageTrend_HourlyGranularity() {
startTime := base.Add(-1 * time.Hour)
endTime := base.Add(3 * time.Hour)
trend, err := s.repo.GetAPIKeyUsageTrend(s.ctx, startTime, endTime, "hour", 10)
s.Require().NoError(err, "GetAPIKeyUsageTrend hourly")
trend, err := s.repo.GetApiKeyUsageTrend(s.ctx, startTime, endTime, "hour", 10)
s.Require().NoError(err, "GetApiKeyUsageTrend hourly")
s.Require().Len(trend, 2)
}
// --- ListWithFilters (additional filter tests) ---
func (s *UsageLogRepoSuite) TestListWithFilters_APIKeyFilter() {
func (s *UsageLogRepoSuite) TestListWithFilters_ApiKeyFilter() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "filterskey@test.com"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-filterskey", Name: "k"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-filterskey", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-filterskey"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
filters := usagestats.UsageLogFilters{APIKeyID: apiKey.ID}
filters := usagestats.UsageLogFilters{ApiKeyID: apiKey.ID}
logs, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, filters)
s.Require().NoError(err, "ListWithFilters apiKey")
s.Require().Len(logs, 1)
@@ -855,7 +858,7 @@ func (s *UsageLogRepoSuite) TestListWithFilters_APIKeyFilter() {
func (s *UsageLogRepoSuite) TestListWithFilters_TimeRange() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "filterstime@test.com"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-filterstime", Name: "k"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-filterstime", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-filterstime"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -874,7 +877,7 @@ func (s *UsageLogRepoSuite) TestListWithFilters_TimeRange() {
func (s *UsageLogRepoSuite) TestListWithFilters_CombinedFilters() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "filterscombined@test.com"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-filterscombined", Name: "k"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-filterscombined", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-filterscombined"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -885,7 +888,7 @@ func (s *UsageLogRepoSuite) TestListWithFilters_CombinedFilters() {
endTime := base.Add(2 * time.Hour)
filters := usagestats.UsageLogFilters{
UserID: user.ID,
APIKeyID: apiKey.ID,
ApiKeyID: apiKey.ID,
StartTime: &startTime,
EndTime: &endTime,
}

View File

@@ -91,7 +91,7 @@ func TestAPIContracts(t *testing.T) {
name: "GET /api/v1/keys (paginated)",
setup: func(t *testing.T, deps *contractDeps) {
t.Helper()
deps.apiKeyRepo.MustSeed(&service.APIKey{
deps.apiKeyRepo.MustSeed(&service.ApiKey{
ID: 100,
UserID: 1,
Key: "sk_custom_1234567890",
@@ -135,7 +135,7 @@ func TestAPIContracts(t *testing.T) {
{
ID: 1,
UserID: 1,
APIKeyID: 100,
ApiKeyID: 100,
AccountID: 200,
Model: "claude-3",
InputTokens: 10,
@@ -150,7 +150,7 @@ func TestAPIContracts(t *testing.T) {
{
ID: 2,
UserID: 1,
APIKeyID: 100,
ApiKeyID: 100,
AccountID: 200,
Model: "claude-3",
InputTokens: 5,
@@ -188,7 +188,7 @@ func TestAPIContracts(t *testing.T) {
{
ID: 1,
UserID: 1,
APIKeyID: 100,
ApiKeyID: 100,
AccountID: 200,
RequestID: "req_123",
Model: "claude-3",
@@ -259,13 +259,13 @@ func TestAPIContracts(t *testing.T) {
service.SettingKeyRegistrationEnabled: "true",
service.SettingKeyEmailVerifyEnabled: "false",
service.SettingKeySMTPHost: "smtp.example.com",
service.SettingKeySMTPPort: "587",
service.SettingKeySMTPUsername: "user",
service.SettingKeySMTPPassword: "secret",
service.SettingKeySMTPFrom: "no-reply@example.com",
service.SettingKeySMTPFromName: "Sub2API",
service.SettingKeySMTPUseTLS: "true",
service.SettingKeySmtpHost: "smtp.example.com",
service.SettingKeySmtpPort: "587",
service.SettingKeySmtpUsername: "user",
service.SettingKeySmtpPassword: "secret",
service.SettingKeySmtpFrom: "no-reply@example.com",
service.SettingKeySmtpFromName: "Sub2API",
service.SettingKeySmtpUseTLS: "true",
service.SettingKeyTurnstileEnabled: "true",
service.SettingKeyTurnstileSiteKey: "site-key",
@@ -274,9 +274,9 @@ func TestAPIContracts(t *testing.T) {
service.SettingKeySiteName: "Sub2API",
service.SettingKeySiteLogo: "",
service.SettingKeySiteSubtitle: "Subtitle",
service.SettingKeyAPIBaseURL: "https://api.example.com",
service.SettingKeyApiBaseUrl: "https://api.example.com",
service.SettingKeyContactInfo: "support",
service.SettingKeyDocURL: "https://docs.example.com",
service.SettingKeyDocUrl: "https://docs.example.com",
service.SettingKeyDefaultConcurrency: "5",
service.SettingKeyDefaultBalance: "1.25",
@@ -308,7 +308,12 @@ func TestAPIContracts(t *testing.T) {
"contact_info": "support",
"doc_url": "https://docs.example.com",
"default_concurrency": 5,
"default_balance": 1.25
"default_balance": 1.25,
"enable_model_fallback": false,
"fallback_model_anthropic": "",
"fallback_model_antigravity": "",
"fallback_model_gemini": "",
"fallback_model_openai": ""
}
}`,
},
@@ -331,7 +336,7 @@ func TestAPIContracts(t *testing.T) {
type contractDeps struct {
now time.Time
router http.Handler
apiKeyRepo *stubAPIKeyRepo
apiKeyRepo *stubApiKeyRepo
usageRepo *stubUsageLogRepo
settingRepo *stubSettingRepo
}
@@ -359,20 +364,20 @@ func newContractDeps(t *testing.T) *contractDeps {
},
}
apiKeyRepo := newStubAPIKeyRepo(now)
apiKeyCache := stubAPIKeyCache{}
apiKeyRepo := newStubApiKeyRepo(now)
apiKeyCache := stubApiKeyCache{}
groupRepo := stubGroupRepo{}
userSubRepo := stubUserSubscriptionRepo{}
cfg := &config.Config{
Default: config.DefaultConfig{
APIKeyPrefix: "sk-",
ApiKeyPrefix: "sk-",
},
RunMode: config.RunModeStandard,
}
userService := service.NewUserService(userRepo)
apiKeyService := service.NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, apiKeyCache, cfg)
apiKeyService := service.NewApiKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, apiKeyCache, cfg)
usageRepo := newStubUsageLogRepo()
usageService := service.NewUsageService(usageRepo, userRepo)
@@ -525,25 +530,25 @@ func (r *stubUserRepo) RemoveGroupFromAllowedGroups(ctx context.Context, groupID
return 0, errors.New("not implemented")
}
type stubAPIKeyCache struct{}
type stubApiKeyCache struct{}
func (stubAPIKeyCache) GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) {
func (stubApiKeyCache) GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) {
return 0, nil
}
func (stubAPIKeyCache) IncrementCreateAttemptCount(ctx context.Context, userID int64) error {
func (stubApiKeyCache) IncrementCreateAttemptCount(ctx context.Context, userID int64) error {
return nil
}
func (stubAPIKeyCache) DeleteCreateAttemptCount(ctx context.Context, userID int64) error {
func (stubApiKeyCache) DeleteCreateAttemptCount(ctx context.Context, userID int64) error {
return nil
}
func (stubAPIKeyCache) IncrementDailyUsage(ctx context.Context, apiKey string) error {
func (stubApiKeyCache) IncrementDailyUsage(ctx context.Context, apiKey string) error {
return nil
}
func (stubAPIKeyCache) SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error {
func (stubApiKeyCache) SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error {
return nil
}
@@ -660,24 +665,24 @@ func (stubUserSubscriptionRepo) BatchUpdateExpiredStatus(ctx context.Context) (i
return 0, errors.New("not implemented")
}
type stubAPIKeyRepo struct {
type stubApiKeyRepo struct {
now time.Time
nextID int64
byID map[int64]*service.APIKey
byKey map[string]*service.APIKey
byID map[int64]*service.ApiKey
byKey map[string]*service.ApiKey
}
func newStubAPIKeyRepo(now time.Time) *stubAPIKeyRepo {
return &stubAPIKeyRepo{
func newStubApiKeyRepo(now time.Time) *stubApiKeyRepo {
return &stubApiKeyRepo{
now: now,
nextID: 100,
byID: make(map[int64]*service.APIKey),
byKey: make(map[string]*service.APIKey),
byID: make(map[int64]*service.ApiKey),
byKey: make(map[string]*service.ApiKey),
}
}
func (r *stubAPIKeyRepo) MustSeed(key *service.APIKey) {
func (r *stubApiKeyRepo) MustSeed(key *service.ApiKey) {
if key == nil {
return
}
@@ -686,7 +691,7 @@ func (r *stubAPIKeyRepo) MustSeed(key *service.APIKey) {
r.byKey[clone.Key] = &clone
}
func (r *stubAPIKeyRepo) Create(ctx context.Context, key *service.APIKey) error {
func (r *stubApiKeyRepo) Create(ctx context.Context, key *service.ApiKey) error {
if key == nil {
return errors.New("nil key")
}
@@ -706,38 +711,38 @@ func (r *stubAPIKeyRepo) Create(ctx context.Context, key *service.APIKey) error
return nil
}
func (r *stubAPIKeyRepo) GetByID(ctx context.Context, id int64) (*service.APIKey, error) {
func (r *stubApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) {
key, ok := r.byID[id]
if !ok {
return nil, service.ErrAPIKeyNotFound
return nil, service.ErrApiKeyNotFound
}
clone := *key
return &clone, nil
}
func (r *stubAPIKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) {
func (r *stubApiKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) {
key, ok := r.byID[id]
if !ok {
return 0, service.ErrAPIKeyNotFound
return 0, service.ErrApiKeyNotFound
}
return key.UserID, nil
}
func (r *stubAPIKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) {
found, ok := r.byKey[key]
if !ok {
return nil, service.ErrAPIKeyNotFound
return nil, service.ErrApiKeyNotFound
}
clone := *found
return &clone, nil
}
func (r *stubAPIKeyRepo) Update(ctx context.Context, key *service.APIKey) error {
func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.ApiKey) error {
if key == nil {
return errors.New("nil key")
}
if _, ok := r.byID[key.ID]; !ok {
return service.ErrAPIKeyNotFound
return service.ErrApiKeyNotFound
}
if key.UpdatedAt.IsZero() {
key.UpdatedAt = r.now
@@ -748,17 +753,17 @@ func (r *stubAPIKeyRepo) Update(ctx context.Context, key *service.APIKey) error
return nil
}
func (r *stubAPIKeyRepo) Delete(ctx context.Context, id int64) error {
func (r *stubApiKeyRepo) Delete(ctx context.Context, id int64) error {
key, ok := r.byID[id]
if !ok {
return service.ErrAPIKeyNotFound
return service.ErrApiKeyNotFound
}
delete(r.byID, id)
delete(r.byKey, key.Key)
return nil
}
func (r *stubAPIKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
ids := make([]int64, 0, len(r.byID))
for id := range r.byID {
if r.byID[id].UserID == userID {
@@ -776,7 +781,7 @@ func (r *stubAPIKeyRepo) ListByUserID(ctx context.Context, userID int64, params
end = len(ids)
}
out := make([]service.APIKey, 0, end-start)
out := make([]service.ApiKey, 0, end-start)
for _, id := range ids[start:end] {
clone := *r.byID[id]
out = append(out, clone)
@@ -796,7 +801,7 @@ func (r *stubAPIKeyRepo) ListByUserID(ctx context.Context, userID int64, params
}, nil
}
func (r *stubAPIKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
func (r *stubApiKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
if len(apiKeyIDs) == 0 {
return []int64{}, nil
}
@@ -815,7 +820,7 @@ func (r *stubAPIKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiK
return out, nil
}
func (r *stubAPIKeyRepo) CountByUserID(ctx context.Context, userID int64) (int64, error) {
func (r *stubApiKeyRepo) CountByUserID(ctx context.Context, userID int64) (int64, error) {
var count int64
for _, key := range r.byID {
if key.UserID == userID {
@@ -825,24 +830,24 @@ func (r *stubAPIKeyRepo) CountByUserID(ctx context.Context, userID int64) (int64
return count, nil
}
func (r *stubAPIKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, error) {
func (r *stubApiKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, error) {
_, ok := r.byKey[key]
return ok, nil
}
func (r *stubAPIKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
func (r *stubApiKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (r *stubAPIKeyRepo) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.APIKey, error) {
func (r *stubApiKeyRepo) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) {
return nil, errors.New("not implemented")
}
func (r *stubAPIKeyRepo) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
func (r *stubApiKeyRepo) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, errors.New("not implemented")
}
func (r *stubAPIKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
func (r *stubApiKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, errors.New("not implemented")
}
@@ -877,7 +882,7 @@ func (r *stubUsageLogRepo) ListByUser(ctx context.Context, userID int64, params
return out, paginationResult(total, params), nil
}
func (r *stubUsageLogRepo) ListByAPIKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
func (r *stubUsageLogRepo) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
@@ -890,7 +895,7 @@ func (r *stubUsageLogRepo) ListByUserAndTimeRange(ctx context.Context, userID in
return logs, paginationResult(int64(len(logs)), pagination.PaginationParams{Page: 1, PageSize: 100}), nil
}
func (r *stubUsageLogRepo) ListByAPIKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
func (r *stubUsageLogRepo) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
@@ -922,7 +927,7 @@ func (r *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTi
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) {
func (r *stubUsageLogRepo) GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.ApiKeyUsageTrendPoint, error) {
return nil, errors.New("not implemented")
}
@@ -975,7 +980,7 @@ func (r *stubUsageLogRepo) GetUserStatsAggregated(ctx context.Context, userID in
}, nil
}
func (r *stubUsageLogRepo) GetAPIKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
func (r *stubUsageLogRepo) GetApiKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
return nil, errors.New("not implemented")
}
@@ -995,7 +1000,7 @@ func (r *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs [
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
func (r *stubUsageLogRepo) GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchApiKeyUsageStats, error) {
return nil, errors.New("not implemented")
}
@@ -1017,8 +1022,8 @@ func (r *stubUsageLogRepo) ListWithFilters(ctx context.Context, params paginatio
// Apply filters
var filtered []service.UsageLog
for _, log := range logs {
// Apply APIKeyID filter
if filters.APIKeyID > 0 && log.APIKeyID != filters.APIKeyID {
// Apply ApiKeyID filter
if filters.ApiKeyID > 0 && log.ApiKeyID != filters.ApiKeyID {
continue
}
// Apply Model filter
@@ -1151,8 +1156,8 @@ func paginationResult(total int64, params pagination.PaginationParams) *paginati
// Ensure compile-time interface compliance.
var (
_ service.UserRepository = (*stubUserRepo)(nil)
_ service.APIKeyRepository = (*stubAPIKeyRepo)(nil)
_ service.APIKeyCache = (*stubAPIKeyCache)(nil)
_ service.ApiKeyRepository = (*stubApiKeyRepo)(nil)
_ service.ApiKeyCache = (*stubApiKeyCache)(nil)
_ service.GroupRepository = (*stubGroupRepo)(nil)
_ service.UserSubscriptionRepository = (*stubUserSubscriptionRepo)(nil)
_ service.UsageLogRepository = (*stubUsageLogRepo)(nil)

View File

@@ -121,6 +121,12 @@ func (m *mockAccountRepoForGemini) SetRateLimited(ctx context.Context, id int64,
func (m *mockAccountRepoForGemini) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
return nil
}
func (m *mockAccountRepoForGemini) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
return nil
}
func (m *mockAccountRepoForGemini) ClearTempUnschedulable(ctx context.Context, id int64) error {
return nil
}
func (m *mockAccountRepoForGemini) ClearRateLimit(ctx context.Context, id int64) error { return nil }
func (m *mockAccountRepoForGemini) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
return nil
@@ -275,7 +281,7 @@ func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_OAuthPr
repo := &mockAccountRepoForGemini{
accounts: []Account{
{ID: 1, Platform: PlatformGemini, Type: AccountTypeAPIKey, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: nil},
{ID: 1, Platform: PlatformGemini, Type: AccountTypeApiKey, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: nil},
{ID: 2, Platform: PlatformGemini, Type: AccountTypeOAuth, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: nil},
},
accountsByID: map[int64]*Account{},

View File

@@ -197,7 +197,7 @@ func TestClaudeTokenRefresher_CanRefresh(t *testing.T) {
{
name: "anthropic api-key - cannot refresh",
platform: PlatformAnthropic,
accType: AccountTypeAPIKey,
accType: AccountTypeApiKey,
want: false,
},
{

View File

@@ -0,0 +1,104 @@
-- Ops monitoring: pre-aggregation tables for dashboard queries
--
-- Problem:
-- The ops dashboard currently runs percentile_cont + GROUP BY queries over large raw tables
-- (usage_logs, ops_error_logs). These will get slower as data grows.
--
-- This migration adds schema-only aggregation tables that can be populated by a future background job.
-- No triggers/functions/jobs are created here (schema only).
-- ============================================
-- Hourly aggregates (per provider/platform)
-- ============================================
CREATE TABLE IF NOT EXISTS ops_metrics_hourly (
-- Start of the hour bucket (recommended: UTC).
bucket_start TIMESTAMPTZ NOT NULL,
-- Provider/platform label (e.g. anthropic/openai/gemini). Mirrors ops_* queries that GROUP BY platform.
platform VARCHAR(50) NOT NULL,
-- Traffic counts (use these to compute rates reliably across ranges).
request_count BIGINT NOT NULL DEFAULT 0,
success_count BIGINT NOT NULL DEFAULT 0,
error_count BIGINT NOT NULL DEFAULT 0,
-- Error breakdown used by provider health UI.
error_4xx_count BIGINT NOT NULL DEFAULT 0,
error_5xx_count BIGINT NOT NULL DEFAULT 0,
timeout_count BIGINT NOT NULL DEFAULT 0,
-- Latency aggregates (ms).
avg_latency_ms DOUBLE PRECISION,
p99_latency_ms DOUBLE PRECISION,
-- Convenience rate (percentage, 0-100). Still keep counts as source of truth.
error_rate DOUBLE PRECISION NOT NULL DEFAULT 0,
-- When this row was last (re)computed by the background job.
computed_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
PRIMARY KEY (bucket_start, platform)
);
CREATE INDEX IF NOT EXISTS idx_ops_metrics_hourly_platform_bucket_start
ON ops_metrics_hourly (platform, bucket_start DESC);
COMMENT ON TABLE ops_metrics_hourly IS 'Pre-aggregated hourly ops metrics by provider/platform to speed up dashboard queries.';
COMMENT ON COLUMN ops_metrics_hourly.bucket_start IS 'Start timestamp of the hour bucket (recommended UTC).';
COMMENT ON COLUMN ops_metrics_hourly.platform IS 'Provider/platform label (anthropic/openai/gemini, etc).';
COMMENT ON COLUMN ops_metrics_hourly.error_rate IS 'Error rate percentage for the bucket (0-100). Counts remain the source of truth.';
COMMENT ON COLUMN ops_metrics_hourly.computed_at IS 'When the row was last computed/refreshed.';
-- ============================================
-- Daily aggregates (per provider/platform)
-- ============================================
CREATE TABLE IF NOT EXISTS ops_metrics_daily (
-- Day bucket (recommended: UTC date).
bucket_date DATE NOT NULL,
platform VARCHAR(50) NOT NULL,
request_count BIGINT NOT NULL DEFAULT 0,
success_count BIGINT NOT NULL DEFAULT 0,
error_count BIGINT NOT NULL DEFAULT 0,
error_4xx_count BIGINT NOT NULL DEFAULT 0,
error_5xx_count BIGINT NOT NULL DEFAULT 0,
timeout_count BIGINT NOT NULL DEFAULT 0,
avg_latency_ms DOUBLE PRECISION,
p99_latency_ms DOUBLE PRECISION,
error_rate DOUBLE PRECISION NOT NULL DEFAULT 0,
computed_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
PRIMARY KEY (bucket_date, platform)
);
CREATE INDEX IF NOT EXISTS idx_ops_metrics_daily_platform_bucket_date
ON ops_metrics_daily (platform, bucket_date DESC);
COMMENT ON TABLE ops_metrics_daily IS 'Pre-aggregated daily ops metrics by provider/platform for longer-term trends.';
COMMENT ON COLUMN ops_metrics_daily.bucket_date IS 'UTC date of the day bucket (recommended).';
-- ============================================
-- Population strategy (future background job)
-- ============================================
--
-- Suggested approach:
-- 1) Compute hourly buckets from raw logs using UTC time-bucketing, then UPSERT into ops_metrics_hourly.
-- 2) Compute daily buckets either directly from raw logs or by rolling up ops_metrics_hourly.
--
-- Notes:
-- - Ensure the job uses a consistent timezone (recommended: SET TIME ZONE ''UTC'') to avoid bucket drift.
-- - Derive the provider/platform similarly to existing dashboard queries:
-- usage_logs: COALESCE(NULLIF(groups.platform, ''), accounts.platform, '')
-- ops_error_logs: COALESCE(NULLIF(ops_error_logs.platform, ''), groups.platform, accounts.platform, '')
-- - Keep request_count/success_count/error_count as the authoritative values; compute error_rate from counts.
--
-- Example (hourly) shape (pseudo-SQL):
-- INSERT INTO ops_metrics_hourly (...)
-- SELECT date_trunc('hour', created_at) AS bucket_start, platform, ...
-- FROM (/* aggregate usage_logs + ops_error_logs */) s
-- ON CONFLICT (bucket_start, platform) DO UPDATE SET ...;

View File

@@ -0,0 +1,58 @@
-- 027_usage_billing_consistency.sql
-- Ensure usage_logs idempotency (request_id, api_key_id) and add reconciliation infrastructure.
-- -----------------------------------------------------------------------------
-- 1) Normalize legacy request_id values
-- -----------------------------------------------------------------------------
-- Historically request_id may be inserted as empty string. Convert it to NULL so
-- the upcoming unique index does not break on repeated "" values.
UPDATE usage_logs
SET request_id = NULL
WHERE request_id = '';
-- If duplicates already exist for the same (request_id, api_key_id), keep the
-- first row and NULL-out request_id for the rest so the unique index can be
-- created without deleting historical logs.
WITH ranked AS (
SELECT
id,
ROW_NUMBER() OVER (PARTITION BY api_key_id, request_id ORDER BY id) AS rn
FROM usage_logs
WHERE request_id IS NOT NULL
)
UPDATE usage_logs ul
SET request_id = NULL
FROM ranked r
WHERE ul.id = r.id
AND r.rn > 1;
-- -----------------------------------------------------------------------------
-- 2) Idempotency constraint for usage_logs
-- -----------------------------------------------------------------------------
CREATE UNIQUE INDEX IF NOT EXISTS idx_usage_logs_request_id_api_key_unique
ON usage_logs (request_id, api_key_id);
-- -----------------------------------------------------------------------------
-- 3) Reconciliation infrastructure: billing ledger for usage charges
-- -----------------------------------------------------------------------------
CREATE TABLE IF NOT EXISTS billing_usage_entries (
id BIGSERIAL PRIMARY KEY,
usage_log_id BIGINT NOT NULL REFERENCES usage_logs(id) ON DELETE CASCADE,
user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
api_key_id BIGINT NOT NULL REFERENCES api_keys(id) ON DELETE CASCADE,
subscription_id BIGINT REFERENCES user_subscriptions(id) ON DELETE SET NULL,
billing_type SMALLINT NOT NULL,
applied BOOLEAN NOT NULL DEFAULT TRUE,
delta_usd DECIMAL(20, 10) NOT NULL DEFAULT 0,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE UNIQUE INDEX IF NOT EXISTS billing_usage_entries_usage_log_id_unique
ON billing_usage_entries (usage_log_id);
CREATE INDEX IF NOT EXISTS idx_billing_usage_entries_user_time
ON billing_usage_entries (user_id, created_at);
CREATE INDEX IF NOT EXISTS idx_billing_usage_entries_created_at
ON billing_usage_entries (created_at);