fix(test): 修复测试和添加数据库迁移

测试修复:
- 修复集成测试中的重复键冲突问题
- 移除 JSON 中多余的尾随逗号
- 新增 inprocess_transport_test.go
- 更新 haiku 模型映射测试用例

数据库迁移:
- 026: 运营指标聚合表
- 027: 使用量与计费一致性约束
This commit is contained in:
ianshaw
2026-01-03 06:36:35 -08:00
parent ff3f514f6b
commit b1702de522
16 changed files with 495 additions and 244 deletions

View File

@@ -5,28 +5,20 @@ import (
"encoding/json" "encoding/json"
"io" "io"
"net/http" "net/http"
"net/http/httptest"
"strings" "strings"
"testing" "testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth" "github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
"github.com/imroc/req/v3"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
) )
type ClaudeOAuthServiceSuite struct { type ClaudeOAuthServiceSuite struct {
suite.Suite suite.Suite
srv *httptest.Server
client *claudeOAuthService client *claudeOAuthService
} }
func (s *ClaudeOAuthServiceSuite) TearDownTest() {
if s.srv != nil {
s.srv.Close()
s.srv = nil
}
}
// requestCapture holds captured request data for assertions in the main goroutine. // requestCapture holds captured request data for assertions in the main goroutine.
type requestCapture struct { type requestCapture struct {
path string path string
@@ -37,6 +29,12 @@ type requestCapture struct {
contentType string contentType string
} }
func newTestReqClient(rt http.RoundTripper) *req.Client {
c := req.C()
c.GetClient().Transport = rt
return c
}
func (s *ClaudeOAuthServiceSuite) TestGetOrganizationUUID() { func (s *ClaudeOAuthServiceSuite) TestGetOrganizationUUID() {
tests := []struct { tests := []struct {
name string name string
@@ -83,17 +81,17 @@ func (s *ClaudeOAuthServiceSuite) TestGetOrganizationUUID() {
s.Run(tt.name, func() { s.Run(tt.name, func() {
var captured requestCapture var captured requestCapture
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { rt := newInProcessTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
captured.path = r.URL.Path captured.path = r.URL.Path
captured.cookies = r.Cookies() captured.cookies = r.Cookies()
tt.handler(w, r) tt.handler(w, r)
})) }), nil)
defer s.srv.Close()
client, ok := NewClaudeOAuthClient().(*claudeOAuthService) client, ok := NewClaudeOAuthClient().(*claudeOAuthService)
require.True(s.T(), ok, "type assertion failed") require.True(s.T(), ok, "type assertion failed")
s.client = client s.client = client
s.client.baseURL = s.srv.URL s.client.baseURL = "http://in-process"
s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) }
got, err := s.client.GetOrganizationUUID(context.Background(), "sess", "") got, err := s.client.GetOrganizationUUID(context.Background(), "sess", "")
@@ -158,20 +156,20 @@ func (s *ClaudeOAuthServiceSuite) TestGetAuthorizationCode() {
s.Run(tt.name, func() { s.Run(tt.name, func() {
var captured requestCapture var captured requestCapture
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { rt := newInProcessTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
captured.path = r.URL.Path captured.path = r.URL.Path
captured.method = r.Method captured.method = r.Method
captured.cookies = r.Cookies() captured.cookies = r.Cookies()
captured.body, _ = io.ReadAll(r.Body) captured.body, _ = io.ReadAll(r.Body)
_ = json.Unmarshal(captured.body, &captured.bodyJSON) _ = json.Unmarshal(captured.body, &captured.bodyJSON)
tt.handler(w, r) tt.handler(w, r)
})) }), nil)
defer s.srv.Close()
client, ok := NewClaudeOAuthClient().(*claudeOAuthService) client, ok := NewClaudeOAuthClient().(*claudeOAuthService)
require.True(s.T(), ok, "type assertion failed") require.True(s.T(), ok, "type assertion failed")
s.client = client s.client = client
s.client.baseURL = s.srv.URL s.client.baseURL = "http://in-process"
s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) }
code, err := s.client.GetAuthorizationCode(context.Background(), "sess", "org-1", oauth.ScopeProfile, "cc", "st", "") code, err := s.client.GetAuthorizationCode(context.Background(), "sess", "org-1", oauth.ScopeProfile, "cc", "st", "")
@@ -266,19 +264,19 @@ func (s *ClaudeOAuthServiceSuite) TestExchangeCodeForToken() {
s.Run(tt.name, func() { s.Run(tt.name, func() {
var captured requestCapture var captured requestCapture
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { rt := newInProcessTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
captured.method = r.Method captured.method = r.Method
captured.contentType = r.Header.Get("Content-Type") captured.contentType = r.Header.Get("Content-Type")
captured.body, _ = io.ReadAll(r.Body) captured.body, _ = io.ReadAll(r.Body)
_ = json.Unmarshal(captured.body, &captured.bodyJSON) _ = json.Unmarshal(captured.body, &captured.bodyJSON)
tt.handler(w, r) tt.handler(w, r)
})) }), nil)
defer s.srv.Close()
client, ok := NewClaudeOAuthClient().(*claudeOAuthService) client, ok := NewClaudeOAuthClient().(*claudeOAuthService)
require.True(s.T(), ok, "type assertion failed") require.True(s.T(), ok, "type assertion failed")
s.client = client s.client = client
s.client.tokenURL = s.srv.URL s.client.tokenURL = "http://in-process/token"
s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) }
resp, err := s.client.ExchangeCodeForToken(context.Background(), tt.code, "ver", "", "", tt.isSetupToken) resp, err := s.client.ExchangeCodeForToken(context.Background(), tt.code, "ver", "", "", tt.isSetupToken)
@@ -362,19 +360,19 @@ func (s *ClaudeOAuthServiceSuite) TestRefreshToken() {
s.Run(tt.name, func() { s.Run(tt.name, func() {
var captured requestCapture var captured requestCapture
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { rt := newInProcessTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
captured.method = r.Method captured.method = r.Method
captured.contentType = r.Header.Get("Content-Type") captured.contentType = r.Header.Get("Content-Type")
captured.body, _ = io.ReadAll(r.Body) captured.body, _ = io.ReadAll(r.Body)
_ = json.Unmarshal(captured.body, &captured.bodyJSON) _ = json.Unmarshal(captured.body, &captured.bodyJSON)
tt.handler(w, r) tt.handler(w, r)
})) }), nil)
defer s.srv.Close()
client, ok := NewClaudeOAuthClient().(*claudeOAuthService) client, ok := NewClaudeOAuthClient().(*claudeOAuthService)
require.True(s.T(), ok, "type assertion failed") require.True(s.T(), ok, "type assertion failed")
s.client = client s.client = client
s.client.tokenURL = s.srv.URL s.client.tokenURL = "http://in-process/token"
s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) }
resp, err := s.client.RefreshToken(context.Background(), "rt", "") resp, err := s.client.RefreshToken(context.Background(), "rt", "")

View File

@@ -33,7 +33,7 @@ type usageRequestCapture struct {
func (s *ClaudeUsageServiceSuite) TestFetchUsage_Success() { func (s *ClaudeUsageServiceSuite) TestFetchUsage_Success() {
var captured usageRequestCapture var captured usageRequestCapture
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
captured.authorization = r.Header.Get("Authorization") captured.authorization = r.Header.Get("Authorization")
captured.anthropicBeta = r.Header.Get("anthropic-beta") captured.anthropicBeta = r.Header.Get("anthropic-beta")
@@ -59,7 +59,7 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_Success() {
} }
func (s *ClaudeUsageServiceSuite) TestFetchUsage_NonOK() { func (s *ClaudeUsageServiceSuite) TestFetchUsage_NonOK() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized) w.WriteHeader(http.StatusUnauthorized)
_, _ = io.WriteString(w, "nope") _, _ = io.WriteString(w, "nope")
})) }))
@@ -73,7 +73,7 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_NonOK() {
} }
func (s *ClaudeUsageServiceSuite) TestFetchUsage_BadJSON() { func (s *ClaudeUsageServiceSuite) TestFetchUsage_BadJSON() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, "not-json") _, _ = io.WriteString(w, "not-json")
})) }))
@@ -86,7 +86,7 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_BadJSON() {
} }
func (s *ClaudeUsageServiceSuite) TestFetchUsage_ContextCancel() { func (s *ClaudeUsageServiceSuite) TestFetchUsage_ContextCancel() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Never respond - simulate slow server // Never respond - simulate slow server
<-r.Context().Done() <-r.Context().Done()
})) }))

View File

@@ -49,7 +49,7 @@ func (s *GitHubReleaseServiceSuite) TearDownTest() {
} }
func (s *GitHubReleaseServiceSuite) TestDownloadFile_EnforcesMaxSize_ContentLength() { func (s *GitHubReleaseServiceSuite) TestDownloadFile_EnforcesMaxSize_ContentLength() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Length", "100") w.Header().Set("Content-Length", "100")
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
_, _ = w.Write(bytes.Repeat([]byte("a"), 100)) _, _ = w.Write(bytes.Repeat([]byte("a"), 100))
@@ -68,7 +68,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_EnforcesMaxSize_ContentLeng
} }
func (s *GitHubReleaseServiceSuite) TestDownloadFile_EnforcesMaxSize_Chunked() { func (s *GitHubReleaseServiceSuite) TestDownloadFile_EnforcesMaxSize_Chunked() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Force chunked encoding (unknown Content-Length) by flushing headers before writing. // Force chunked encoding (unknown Content-Length) by flushing headers before writing.
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
if fl, ok := w.(http.Flusher); ok { if fl, ok := w.(http.Flusher); ok {
@@ -95,7 +95,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_EnforcesMaxSize_Chunked() {
} }
func (s *GitHubReleaseServiceSuite) TestDownloadFile_Success() { func (s *GitHubReleaseServiceSuite) TestDownloadFile_Success() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
if fl, ok := w.(http.Flusher); ok { if fl, ok := w.(http.Flusher); ok {
fl.Flush() fl.Flush()
@@ -123,7 +123,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_Success() {
} }
func (s *GitHubReleaseServiceSuite) TestDownloadFile_404() { func (s *GitHubReleaseServiceSuite) TestDownloadFile_404() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound) w.WriteHeader(http.StatusNotFound)
})) }))
@@ -140,7 +140,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_404() {
} }
func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_Success() { func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_Success() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("sum")) _, _ = w.Write([]byte("sum"))
})) }))
@@ -155,7 +155,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_Success() {
} }
func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_Non200() { func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_Non200() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError) w.WriteHeader(http.StatusInternalServerError)
})) }))
@@ -168,7 +168,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_Non200() {
} }
func (s *GitHubReleaseServiceSuite) TestDownloadFile_ContextCancel() { func (s *GitHubReleaseServiceSuite) TestDownloadFile_ContextCancel() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
<-r.Context().Done() <-r.Context().Done()
})) }))
@@ -195,7 +195,7 @@ func (s *GitHubReleaseServiceSuite) TestDownloadFile_InvalidURL() {
} }
func (s *GitHubReleaseServiceSuite) TestDownloadFile_InvalidDestPath() { func (s *GitHubReleaseServiceSuite) TestDownloadFile_InvalidDestPath() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("content")) _, _ = w.Write([]byte("content"))
})) }))
@@ -233,7 +233,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Success() {
] ]
}` }`
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(s.T(), "/repos/test/repo/releases/latest", r.URL.Path) require.Equal(s.T(), "/repos/test/repo/releases/latest", r.URL.Path)
require.Equal(s.T(), "application/vnd.github.v3+json", r.Header.Get("Accept")) require.Equal(s.T(), "application/vnd.github.v3+json", r.Header.Get("Accept"))
require.Equal(s.T(), "Sub2API-Updater", r.Header.Get("User-Agent")) require.Equal(s.T(), "Sub2API-Updater", r.Header.Get("User-Agent"))
@@ -258,7 +258,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Success() {
} }
func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Non200() { func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Non200() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound) w.WriteHeader(http.StatusNotFound)
})) }))
@@ -274,7 +274,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Non200() {
} }
func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_InvalidJSON() { func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_InvalidJSON() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("not valid json")) _, _ = w.Write([]byte("not valid json"))
})) }))
@@ -290,7 +290,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_InvalidJSON() {
} }
func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_ContextCancel() { func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_ContextCancel() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
<-r.Context().Done() <-r.Context().Done()
})) }))
@@ -308,7 +308,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_ContextCancel() {
} }
func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_ContextCancel() { func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_ContextCancel() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
<-r.Context().Done() <-r.Context().Done()
})) }))

View File

@@ -3,7 +3,6 @@ package repository
import ( import (
"io" "io"
"net/http" "net/http"
"net/http/httptest"
"sync/atomic" "sync/atomic"
"testing" "testing"
"time" "time"
@@ -93,7 +92,7 @@ func (s *HTTPUpstreamSuite) TestAcquireClient_OverLimitReturnsError() {
// 验证空代理 URL 时请求直接发送到目标服务器 // 验证空代理 URL 时请求直接发送到目标服务器
func (s *HTTPUpstreamSuite) TestDo_WithoutProxy_GoesDirect() { func (s *HTTPUpstreamSuite) TestDo_WithoutProxy_GoesDirect() {
// 创建模拟上游服务器 // 创建模拟上游服务器
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { upstream := newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.WriteString(w, "direct") _, _ = io.WriteString(w, "direct")
})) }))
s.T().Cleanup(upstream.Close) s.T().Cleanup(upstream.Close)
@@ -115,7 +114,7 @@ func (s *HTTPUpstreamSuite) TestDo_WithHTTPProxy_UsesProxy() {
// 用于接收代理请求的通道 // 用于接收代理请求的通道
seen := make(chan string, 1) seen := make(chan string, 1)
// 创建模拟代理服务器 // 创建模拟代理服务器
proxySrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { proxySrv := newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
seen <- r.RequestURI // 记录请求 URI seen <- r.RequestURI // 记录请求 URI
_, _ = io.WriteString(w, "proxied") _, _ = io.WriteString(w, "proxied")
})) }))
@@ -145,7 +144,7 @@ func (s *HTTPUpstreamSuite) TestDo_WithHTTPProxy_UsesProxy() {
// TestDo_EmptyProxy_UsesDirect 测试空代理字符串 // TestDo_EmptyProxy_UsesDirect 测试空代理字符串
// 验证空字符串代理等同于直连 // 验证空字符串代理等同于直连
func (s *HTTPUpstreamSuite) TestDo_EmptyProxy_UsesDirect() { func (s *HTTPUpstreamSuite) TestDo_EmptyProxy_UsesDirect() {
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { upstream := newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.WriteString(w, "direct-empty") _, _ = io.WriteString(w, "direct-empty")
})) }))
s.T().Cleanup(upstream.Close) s.T().Cleanup(upstream.Close)

View File

@@ -0,0 +1,63 @@
package repository
import (
"bytes"
"io"
"net"
"net/http"
"net/http/httptest"
"sync"
"testing"
)
type roundTripFunc func(*http.Request) (*http.Response, error)
func (f roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) { return f(r) }
// newInProcessTransport adapts an http.HandlerFunc into an http.RoundTripper without opening sockets.
// It captures the request body (if any) and then rewinds it before invoking the handler.
func newInProcessTransport(handler http.HandlerFunc, capture func(r *http.Request, body []byte)) http.RoundTripper {
return roundTripFunc(func(r *http.Request) (*http.Response, error) {
var body []byte
if r.Body != nil {
body, _ = io.ReadAll(r.Body)
_ = r.Body.Close()
r.Body = io.NopCloser(bytes.NewReader(body))
}
if capture != nil {
capture(r, body)
}
rec := httptest.NewRecorder()
handler(rec, r)
return rec.Result(), nil
})
}
var (
canListenOnce sync.Once
canListen bool
canListenErr error
)
func localListenerAvailable() bool {
canListenOnce.Do(func() {
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
canListenErr = err
canListen = false
return
}
_ = ln.Close()
canListen = true
})
return canListen
}
func newLocalTestServer(tb testing.TB, handler http.Handler) *httptest.Server {
tb.Helper()
if !localListenerAvailable() {
tb.Skipf("local listeners are not permitted in this environment: %v", canListenErr)
}
return httptest.NewServer(handler)
}

View File

@@ -34,7 +34,7 @@ func (s *OpenAIOAuthServiceSuite) TearDownTest() {
} }
func (s *OpenAIOAuthServiceSuite) setupServer(handler http.HandlerFunc) { func (s *OpenAIOAuthServiceSuite) setupServer(handler http.HandlerFunc) {
s.srv = httptest.NewServer(handler) s.srv = newLocalTestServer(s.T(), handler)
s.svc = &openaiOAuthService{tokenURL: s.srv.URL} s.svc = &openaiOAuthService{tokenURL: s.srv.URL}
} }

View File

@@ -32,7 +32,7 @@ func (s *PricingServiceSuite) TearDownTest() {
} }
func (s *PricingServiceSuite) setupServer(handler http.HandlerFunc) { func (s *PricingServiceSuite) setupServer(handler http.HandlerFunc) {
s.srv = httptest.NewServer(handler) s.srv = newLocalTestServer(s.T(), handler)
} }
func (s *PricingServiceSuite) TestFetchPricingJSON_Success() { func (s *PricingServiceSuite) TestFetchPricingJSON_Success() {

View File

@@ -31,7 +31,7 @@ func (s *ProxyProbeServiceSuite) TearDownTest() {
} }
func (s *ProxyProbeServiceSuite) setupProxyServer(handler http.HandlerFunc) { func (s *ProxyProbeServiceSuite) setupProxyServer(handler http.HandlerFunc) {
s.proxySrv = httptest.NewServer(handler) s.proxySrv = newLocalTestServer(s.T(), handler)
} }
func (s *ProxyProbeServiceSuite) TestProbeProxy_InvalidProxyURL() { func (s *ProxyProbeServiceSuite) TestProbeProxy_InvalidProxyURL() {

View File

@@ -3,9 +3,9 @@ package repository
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"errors"
"io" "io"
"net/http" "net/http"
"net/http/httptest"
"net/url" "net/url"
"strings" "strings"
"testing" "testing"
@@ -18,7 +18,6 @@ import (
type TurnstileServiceSuite struct { type TurnstileServiceSuite struct {
suite.Suite suite.Suite
ctx context.Context ctx context.Context
srv *httptest.Server
verifier *turnstileVerifier verifier *turnstileVerifier
received chan url.Values received chan url.Values
} }
@@ -31,20 +30,15 @@ func (s *TurnstileServiceSuite) SetupTest() {
s.verifier = verifier s.verifier = verifier
} }
func (s *TurnstileServiceSuite) TearDownTest() { func (s *TurnstileServiceSuite) setupTransport(handler http.HandlerFunc) {
if s.srv != nil { s.verifier.verifyURL = "http://in-process/turnstile"
s.srv.Close() s.verifier.httpClient = &http.Client{
s.srv = nil Transport: newInProcessTransport(handler, nil),
} }
} }
func (s *TurnstileServiceSuite) setupServer(handler http.HandlerFunc) {
s.srv = httptest.NewServer(handler)
s.verifier.verifyURL = s.srv.URL
}
func (s *TurnstileServiceSuite) TestVerifyToken_SendsFormAndDecodesJSON() { func (s *TurnstileServiceSuite) TestVerifyToken_SendsFormAndDecodesJSON() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s.setupTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Capture form data in main goroutine context later // Capture form data in main goroutine context later
body, _ := io.ReadAll(r.Body) body, _ := io.ReadAll(r.Body)
values, _ := url.ParseQuery(string(body)) values, _ := url.ParseQuery(string(body))
@@ -72,7 +66,7 @@ func (s *TurnstileServiceSuite) TestVerifyToken_SendsFormAndDecodesJSON() {
func (s *TurnstileServiceSuite) TestVerifyToken_ContentType() { func (s *TurnstileServiceSuite) TestVerifyToken_ContentType() {
var contentType string var contentType string
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s.setupTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
contentType = r.Header.Get("Content-Type") contentType = r.Header.Get("Content-Type")
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(service.TurnstileVerifyResponse{Success: true}) _ = json.NewEncoder(w).Encode(service.TurnstileVerifyResponse{Success: true})
@@ -84,7 +78,7 @@ func (s *TurnstileServiceSuite) TestVerifyToken_ContentType() {
} }
func (s *TurnstileServiceSuite) TestVerifyToken_EmptyRemoteIP_NotSent() { func (s *TurnstileServiceSuite) TestVerifyToken_EmptyRemoteIP_NotSent() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s.setupTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body) body, _ := io.ReadAll(r.Body)
values, _ := url.ParseQuery(string(body)) values, _ := url.ParseQuery(string(body))
s.received <- values s.received <- values
@@ -105,15 +99,19 @@ func (s *TurnstileServiceSuite) TestVerifyToken_EmptyRemoteIP_NotSent() {
} }
func (s *TurnstileServiceSuite) TestVerifyToken_RequestError() { func (s *TurnstileServiceSuite) TestVerifyToken_RequestError() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) s.verifier.verifyURL = "http://in-process/turnstile"
s.srv.Close() s.verifier.httpClient = &http.Client{
Transport: roundTripFunc(func(*http.Request) (*http.Response, error) {
return nil, errors.New("dial failed")
}),
}
_, err := s.verifier.VerifyToken(s.ctx, "sk", "token", "1.1.1.1") _, err := s.verifier.VerifyToken(s.ctx, "sk", "token", "1.1.1.1")
require.Error(s.T(), err, "expected error when server is closed") require.Error(s.T(), err, "expected error when server is closed")
} }
func (s *TurnstileServiceSuite) TestVerifyToken_InvalidJSON() { func (s *TurnstileServiceSuite) TestVerifyToken_InvalidJSON() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s.setupTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, "not-valid-json") _, _ = io.WriteString(w, "not-valid-json")
})) }))
@@ -123,7 +121,7 @@ func (s *TurnstileServiceSuite) TestVerifyToken_InvalidJSON() {
} }
func (s *TurnstileServiceSuite) TestVerifyToken_SuccessFalse() { func (s *TurnstileServiceSuite) TestVerifyToken_SuccessFalse() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s.setupTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(service.TurnstileVerifyResponse{ _ = json.NewEncoder(w).Encode(service.TurnstileVerifyResponse{
Success: false, Success: false,

View File

@@ -3,6 +3,7 @@ package repository
import ( import (
"context" "context"
"database/sql" "database/sql"
"errors"
"fmt" "fmt"
"os" "os"
"strings" "strings"
@@ -70,6 +71,9 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
createdAt = time.Now() createdAt = time.Now()
} }
requestID := strings.TrimSpace(log.RequestID)
log.RequestID = requestID
rateMultiplier := log.RateMultiplier rateMultiplier := log.RateMultiplier
query := ` query := `
@@ -107,6 +111,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
$14, $15, $16, $17, $18, $19, $14, $15, $16, $17, $18, $19,
$20, $21, $22, $23, $24, $25 $20, $21, $22, $23, $24, $25
) )
ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id, created_at RETURNING id, created_at
` `
@@ -115,11 +120,16 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
duration := nullInt(log.DurationMs) duration := nullInt(log.DurationMs)
firstToken := nullInt(log.FirstTokenMs) firstToken := nullInt(log.FirstTokenMs)
var requestIDArg any
if requestID != "" {
requestIDArg = requestID
}
args := []any{ args := []any{
log.UserID, log.UserID,
log.APIKeyID, log.ApiKeyID,
log.AccountID, log.AccountID,
log.RequestID, requestIDArg,
log.Model, log.Model,
groupID, groupID,
subscriptionID, subscriptionID,
@@ -143,7 +153,14 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
createdAt, createdAt,
} }
if err := scanSingleRow(ctx, r.sql, query, args, &log.ID, &log.CreatedAt); err != nil { if err := scanSingleRow(ctx, r.sql, query, args, &log.ID, &log.CreatedAt); err != nil {
return err if errors.Is(err, sql.ErrNoRows) && requestID != "" {
selectQuery := "SELECT id, created_at FROM usage_logs WHERE request_id = $1 AND api_key_id = $2"
if err := scanSingleRow(ctx, r.sql, selectQuery, []any{requestID, log.ApiKeyID}, &log.ID, &log.CreatedAt); err != nil {
return err
}
} else {
return err
}
} }
log.RateMultiplier = rateMultiplier log.RateMultiplier = rateMultiplier
return nil return nil
@@ -183,7 +200,7 @@ func (r *usageLogRepository) ListByUser(ctx context.Context, userID int64, param
return r.listUsageLogsWithPagination(ctx, "WHERE user_id = $1", []any{userID}, params) return r.listUsageLogsWithPagination(ctx, "WHERE user_id = $1", []any{userID}, params)
} }
func (r *usageLogRepository) ListByAPIKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) { func (r *usageLogRepository) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
return r.listUsageLogsWithPagination(ctx, "WHERE api_key_id = $1", []any{apiKeyID}, params) return r.listUsageLogsWithPagination(ctx, "WHERE api_key_id = $1", []any{apiKeyID}, params)
} }
@@ -270,8 +287,8 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
r.sql, r.sql,
apiKeyStatsQuery, apiKeyStatsQuery,
[]any{service.StatusActive}, []any{service.StatusActive},
&stats.TotalAPIKeys, &stats.TotalApiKeys,
&stats.ActiveAPIKeys, &stats.ActiveApiKeys,
); err != nil { ); err != nil {
return nil, err return nil, err
} }
@@ -418,8 +435,8 @@ func (r *usageLogRepository) GetUserStatsAggregated(ctx context.Context, userID
return &stats, nil return &stats, nil
} }
// GetAPIKeyStatsAggregated returns aggregated usage statistics for an API key using database-level aggregation // GetApiKeyStatsAggregated returns aggregated usage statistics for an API key using database-level aggregation
func (r *usageLogRepository) GetAPIKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) { func (r *usageLogRepository) GetApiKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
query := ` query := `
SELECT SELECT
COUNT(*) as total_requests, COUNT(*) as total_requests,
@@ -623,7 +640,7 @@ func resolveUsageStatsTimezone() string {
return "UTC" return "UTC"
} }
func (r *usageLogRepository) ListByAPIKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) { func (r *usageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE api_key_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC" query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE api_key_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC"
logs, err := r.queryUsageLogs(ctx, query, apiKeyID, startTime, endTime) logs, err := r.queryUsageLogs(ctx, query, apiKeyID, startTime, endTime)
return logs, nil, err return logs, nil, err
@@ -709,11 +726,11 @@ type ModelStat = usagestats.ModelStat
// UserUsageTrendPoint represents user usage trend data point // UserUsageTrendPoint represents user usage trend data point
type UserUsageTrendPoint = usagestats.UserUsageTrendPoint type UserUsageTrendPoint = usagestats.UserUsageTrendPoint
// APIKeyUsageTrendPoint represents API key usage trend data point // ApiKeyUsageTrendPoint represents API key usage trend data point
type APIKeyUsageTrendPoint = usagestats.APIKeyUsageTrendPoint type ApiKeyUsageTrendPoint = usagestats.ApiKeyUsageTrendPoint
// GetAPIKeyUsageTrend returns usage trend data grouped by API key and date // GetApiKeyUsageTrend returns usage trend data grouped by API key and date
func (r *usageLogRepository) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) (results []APIKeyUsageTrendPoint, err error) { func (r *usageLogRepository) GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) (results []ApiKeyUsageTrendPoint, err error) {
dateFormat := "YYYY-MM-DD" dateFormat := "YYYY-MM-DD"
if granularity == "hour" { if granularity == "hour" {
dateFormat = "YYYY-MM-DD HH24:00" dateFormat = "YYYY-MM-DD HH24:00"
@@ -755,10 +772,10 @@ func (r *usageLogRepository) GetAPIKeyUsageTrend(ctx context.Context, startTime,
} }
}() }()
results = make([]APIKeyUsageTrendPoint, 0) results = make([]ApiKeyUsageTrendPoint, 0)
for rows.Next() { for rows.Next() {
var row APIKeyUsageTrendPoint var row ApiKeyUsageTrendPoint
if err = rows.Scan(&row.Date, &row.APIKeyID, &row.KeyName, &row.Requests, &row.Tokens); err != nil { if err = rows.Scan(&row.Date, &row.ApiKeyID, &row.KeyName, &row.Requests, &row.Tokens); err != nil {
return nil, err return nil, err
} }
results = append(results, row) results = append(results, row)
@@ -844,7 +861,7 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
r.sql, r.sql,
"SELECT COUNT(*) FROM api_keys WHERE user_id = $1 AND deleted_at IS NULL", "SELECT COUNT(*) FROM api_keys WHERE user_id = $1 AND deleted_at IS NULL",
[]any{userID}, []any{userID},
&stats.TotalAPIKeys, &stats.TotalApiKeys,
); err != nil { ); err != nil {
return nil, err return nil, err
} }
@@ -853,7 +870,7 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
r.sql, r.sql,
"SELECT COUNT(*) FROM api_keys WHERE user_id = $1 AND status = $2 AND deleted_at IS NULL", "SELECT COUNT(*) FROM api_keys WHERE user_id = $1 AND status = $2 AND deleted_at IS NULL",
[]any{userID, service.StatusActive}, []any{userID, service.StatusActive},
&stats.ActiveAPIKeys, &stats.ActiveApiKeys,
); err != nil { ); err != nil {
return nil, err return nil, err
} }
@@ -1023,9 +1040,9 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat
conditions = append(conditions, fmt.Sprintf("user_id = $%d", len(args)+1)) conditions = append(conditions, fmt.Sprintf("user_id = $%d", len(args)+1))
args = append(args, filters.UserID) args = append(args, filters.UserID)
} }
if filters.APIKeyID > 0 { if filters.ApiKeyID > 0 {
conditions = append(conditions, fmt.Sprintf("api_key_id = $%d", len(args)+1)) conditions = append(conditions, fmt.Sprintf("api_key_id = $%d", len(args)+1))
args = append(args, filters.APIKeyID) args = append(args, filters.ApiKeyID)
} }
if filters.AccountID > 0 { if filters.AccountID > 0 {
conditions = append(conditions, fmt.Sprintf("account_id = $%d", len(args)+1)) conditions = append(conditions, fmt.Sprintf("account_id = $%d", len(args)+1))
@@ -1145,18 +1162,18 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs
return result, nil return result, nil
} }
// BatchAPIKeyUsageStats represents usage stats for a single API key // BatchApiKeyUsageStats represents usage stats for a single API key
type BatchAPIKeyUsageStats = usagestats.BatchAPIKeyUsageStats type BatchApiKeyUsageStats = usagestats.BatchApiKeyUsageStats
// GetBatchAPIKeyUsageStats gets today and total actual_cost for multiple API keys // GetBatchApiKeyUsageStats gets today and total actual_cost for multiple API keys
func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*BatchAPIKeyUsageStats, error) { func (r *usageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*BatchApiKeyUsageStats, error) {
result := make(map[int64]*BatchAPIKeyUsageStats) result := make(map[int64]*BatchApiKeyUsageStats)
if len(apiKeyIDs) == 0 { if len(apiKeyIDs) == 0 {
return result, nil return result, nil
} }
for _, id := range apiKeyIDs { for _, id := range apiKeyIDs {
result[id] = &BatchAPIKeyUsageStats{APIKeyID: id} result[id] = &BatchApiKeyUsageStats{ApiKeyID: id}
} }
query := ` query := `
@@ -1582,7 +1599,7 @@ func (r *usageLogRepository) hydrateUsageLogAssociations(ctx context.Context, lo
if err != nil { if err != nil {
return err return err
} }
apiKeys, err := r.loadAPIKeys(ctx, ids.apiKeyIDs) apiKeys, err := r.loadApiKeys(ctx, ids.apiKeyIDs)
if err != nil { if err != nil {
return err return err
} }
@@ -1603,8 +1620,8 @@ func (r *usageLogRepository) hydrateUsageLogAssociations(ctx context.Context, lo
if user, ok := users[logs[i].UserID]; ok { if user, ok := users[logs[i].UserID]; ok {
logs[i].User = user logs[i].User = user
} }
if key, ok := apiKeys[logs[i].APIKeyID]; ok { if key, ok := apiKeys[logs[i].ApiKeyID]; ok {
logs[i].APIKey = key logs[i].ApiKey = key
} }
if acc, ok := accounts[logs[i].AccountID]; ok { if acc, ok := accounts[logs[i].AccountID]; ok {
logs[i].Account = acc logs[i].Account = acc
@@ -1642,7 +1659,7 @@ func collectUsageLogIDs(logs []service.UsageLog) usageLogIDs {
for i := range logs { for i := range logs {
userIDs[logs[i].UserID] = struct{}{} userIDs[logs[i].UserID] = struct{}{}
apiKeyIDs[logs[i].APIKeyID] = struct{}{} apiKeyIDs[logs[i].ApiKeyID] = struct{}{}
accountIDs[logs[i].AccountID] = struct{}{} accountIDs[logs[i].AccountID] = struct{}{}
if logs[i].GroupID != nil { if logs[i].GroupID != nil {
groupIDs[*logs[i].GroupID] = struct{}{} groupIDs[*logs[i].GroupID] = struct{}{}
@@ -1676,12 +1693,12 @@ func (r *usageLogRepository) loadUsers(ctx context.Context, ids []int64) (map[in
return out, nil return out, nil
} }
func (r *usageLogRepository) loadAPIKeys(ctx context.Context, ids []int64) (map[int64]*service.APIKey, error) { func (r *usageLogRepository) loadApiKeys(ctx context.Context, ids []int64) (map[int64]*service.ApiKey, error) {
out := make(map[int64]*service.APIKey) out := make(map[int64]*service.ApiKey)
if len(ids) == 0 { if len(ids) == 0 {
return out, nil return out, nil
} }
models, err := r.client.APIKey.Query().Where(dbapikey.IDIn(ids...)).All(ctx) models, err := r.client.ApiKey.Query().Where(dbapikey.IDIn(ids...)).All(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -1800,7 +1817,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
log := &service.UsageLog{ log := &service.UsageLog{
ID: id, ID: id,
UserID: userID, UserID: userID,
APIKeyID: apiKeyID, ApiKeyID: apiKeyID,
AccountID: accountID, AccountID: accountID,
Model: model, Model: model,
InputTokens: inputTokens, InputTokens: inputTokens,

View File

@@ -7,6 +7,8 @@ import (
"testing" "testing"
"time" "time"
"github.com/google/uuid"
dbent "github.com/Wei-Shaw/sub2api/ent" dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
@@ -35,11 +37,12 @@ func TestUsageLogRepoSuite(t *testing.T) {
suite.Run(t, new(UsageLogRepoSuite)) suite.Run(t, new(UsageLogRepoSuite))
} }
func (s *UsageLogRepoSuite) createUsageLog(user *service.User, apiKey *service.APIKey, account *service.Account, inputTokens, outputTokens int, cost float64, createdAt time.Time) *service.UsageLog { func (s *UsageLogRepoSuite) createUsageLog(user *service.User, apiKey *service.ApiKey, account *service.Account, inputTokens, outputTokens int, cost float64, createdAt time.Time) *service.UsageLog {
log := &service.UsageLog{ log := &service.UsageLog{
UserID: user.ID, UserID: user.ID,
APIKeyID: apiKey.ID, ApiKeyID: apiKey.ID,
AccountID: account.ID, AccountID: account.ID,
RequestID: uuid.New().String(), // Generate unique RequestID for each log
Model: "claude-3", Model: "claude-3",
InputTokens: inputTokens, InputTokens: inputTokens,
OutputTokens: outputTokens, OutputTokens: outputTokens,
@@ -55,12 +58,12 @@ func (s *UsageLogRepoSuite) createUsageLog(user *service.User, apiKey *service.A
func (s *UsageLogRepoSuite) TestCreate() { func (s *UsageLogRepoSuite) TestCreate() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "create@test.com"}) user := mustCreateUser(s.T(), s.client, &service.User{Email: "create@test.com"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-create", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-create", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-create"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-create"})
log := &service.UsageLog{ log := &service.UsageLog{
UserID: user.ID, UserID: user.ID,
APIKeyID: apiKey.ID, ApiKeyID: apiKey.ID,
AccountID: account.ID, AccountID: account.ID,
Model: "claude-3", Model: "claude-3",
InputTokens: 10, InputTokens: 10,
@@ -76,7 +79,7 @@ func (s *UsageLogRepoSuite) TestCreate() {
func (s *UsageLogRepoSuite) TestGetByID() { func (s *UsageLogRepoSuite) TestGetByID() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "getbyid@test.com"}) user := mustCreateUser(s.T(), s.client, &service.User{Email: "getbyid@test.com"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-getbyid", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-getbyid", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-getbyid"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-getbyid"})
log := s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now()) log := s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
@@ -96,7 +99,7 @@ func (s *UsageLogRepoSuite) TestGetByID_NotFound() {
func (s *UsageLogRepoSuite) TestDelete() { func (s *UsageLogRepoSuite) TestDelete() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "delete@test.com"}) user := mustCreateUser(s.T(), s.client, &service.User{Email: "delete@test.com"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-delete", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-delete", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-delete"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-delete"})
log := s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now()) log := s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
@@ -112,7 +115,7 @@ func (s *UsageLogRepoSuite) TestDelete() {
func (s *UsageLogRepoSuite) TestListByUser() { func (s *UsageLogRepoSuite) TestListByUser() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "listbyuser@test.com"}) user := mustCreateUser(s.T(), s.client, &service.User{Email: "listbyuser@test.com"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-listbyuser", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-listbyuser", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-listbyuser"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-listbyuser"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now()) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
@@ -124,18 +127,18 @@ func (s *UsageLogRepoSuite) TestListByUser() {
s.Require().Equal(int64(2), page.Total) s.Require().Equal(int64(2), page.Total)
} }
// --- ListByAPIKey --- // --- ListByApiKey ---
func (s *UsageLogRepoSuite) TestListByAPIKey() { func (s *UsageLogRepoSuite) TestListByApiKey() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "listbyapikey@test.com"}) user := mustCreateUser(s.T(), s.client, &service.User{Email: "listbyapikey@test.com"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-listbyapikey", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-listbyapikey", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-listbyapikey"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-listbyapikey"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now()) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, time.Now()) s.createUsageLog(user, apiKey, account, 15, 25, 0.6, time.Now())
logs, page, err := s.repo.ListByAPIKey(s.ctx, apiKey.ID, pagination.PaginationParams{Page: 1, PageSize: 10}) logs, page, err := s.repo.ListByApiKey(s.ctx, apiKey.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "ListByAPIKey") s.Require().NoError(err, "ListByApiKey")
s.Require().Len(logs, 2) s.Require().Len(logs, 2)
s.Require().Equal(int64(2), page.Total) s.Require().Equal(int64(2), page.Total)
} }
@@ -144,7 +147,7 @@ func (s *UsageLogRepoSuite) TestListByAPIKey() {
func (s *UsageLogRepoSuite) TestListByAccount() { func (s *UsageLogRepoSuite) TestListByAccount() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "listbyaccount@test.com"}) user := mustCreateUser(s.T(), s.client, &service.User{Email: "listbyaccount@test.com"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-listbyaccount", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-listbyaccount", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-listbyaccount"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-listbyaccount"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now()) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
@@ -159,7 +162,7 @@ func (s *UsageLogRepoSuite) TestListByAccount() {
func (s *UsageLogRepoSuite) TestGetUserStats() { func (s *UsageLogRepoSuite) TestGetUserStats() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "userstats@test.com"}) user := mustCreateUser(s.T(), s.client, &service.User{Email: "userstats@test.com"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-userstats", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-userstats", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-userstats"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-userstats"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -179,7 +182,7 @@ func (s *UsageLogRepoSuite) TestGetUserStats() {
func (s *UsageLogRepoSuite) TestListWithFilters() { func (s *UsageLogRepoSuite) TestListWithFilters() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "filters@test.com"}) user := mustCreateUser(s.T(), s.client, &service.User{Email: "filters@test.com"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-filters", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-filters", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-filters"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-filters"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now()) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
@@ -211,8 +214,8 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
}) })
group := mustCreateGroup(s.T(), s.client, &service.Group{Name: "g-ul"}) group := mustCreateGroup(s.T(), s.client, &service.Group{Name: "g-ul"})
apiKey1 := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: userToday.ID, Key: "sk-ul-1", Name: "ul1"}) apiKey1 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: userToday.ID, Key: "sk-ul-1", Name: "ul1"})
mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: userOld.ID, Key: "sk-ul-2", Name: "ul2", Status: service.StatusDisabled}) mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: userOld.ID, Key: "sk-ul-2", Name: "ul2", Status: service.StatusDisabled})
resetAt := now.Add(10 * time.Minute) resetAt := now.Add(10 * time.Minute)
accNormal := mustCreateAccount(s.T(), s.client, &service.Account{Name: "a-normal", Schedulable: true}) accNormal := mustCreateAccount(s.T(), s.client, &service.Account{Name: "a-normal", Schedulable: true})
@@ -223,7 +226,7 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
d1, d2, d3 := 100, 200, 300 d1, d2, d3 := 100, 200, 300
logToday := &service.UsageLog{ logToday := &service.UsageLog{
UserID: userToday.ID, UserID: userToday.ID,
APIKeyID: apiKey1.ID, ApiKeyID: apiKey1.ID,
AccountID: accNormal.ID, AccountID: accNormal.ID,
Model: "claude-3", Model: "claude-3",
GroupID: &group.ID, GroupID: &group.ID,
@@ -240,7 +243,7 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
logOld := &service.UsageLog{ logOld := &service.UsageLog{
UserID: userOld.ID, UserID: userOld.ID,
APIKeyID: apiKey1.ID, ApiKeyID: apiKey1.ID,
AccountID: accNormal.ID, AccountID: accNormal.ID,
Model: "claude-3", Model: "claude-3",
InputTokens: 5, InputTokens: 5,
@@ -254,7 +257,7 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
logPerf := &service.UsageLog{ logPerf := &service.UsageLog{
UserID: userToday.ID, UserID: userToday.ID,
APIKeyID: apiKey1.ID, ApiKeyID: apiKey1.ID,
AccountID: accNormal.ID, AccountID: accNormal.ID,
Model: "claude-3", Model: "claude-3",
InputTokens: 1, InputTokens: 1,
@@ -272,8 +275,8 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
s.Require().Equal(baseStats.TotalUsers+2, stats.TotalUsers, "TotalUsers mismatch") s.Require().Equal(baseStats.TotalUsers+2, stats.TotalUsers, "TotalUsers mismatch")
s.Require().Equal(baseStats.TodayNewUsers+1, stats.TodayNewUsers, "TodayNewUsers mismatch") s.Require().Equal(baseStats.TodayNewUsers+1, stats.TodayNewUsers, "TodayNewUsers mismatch")
s.Require().Equal(baseStats.ActiveUsers+1, stats.ActiveUsers, "ActiveUsers mismatch") s.Require().Equal(baseStats.ActiveUsers+1, stats.ActiveUsers, "ActiveUsers mismatch")
s.Require().Equal(baseStats.TotalAPIKeys+2, stats.TotalAPIKeys, "TotalAPIKeys mismatch") s.Require().Equal(baseStats.TotalApiKeys+2, stats.TotalApiKeys, "TotalApiKeys mismatch")
s.Require().Equal(baseStats.ActiveAPIKeys+1, stats.ActiveAPIKeys, "ActiveAPIKeys mismatch") s.Require().Equal(baseStats.ActiveApiKeys+1, stats.ActiveApiKeys, "ActiveApiKeys mismatch")
s.Require().Equal(baseStats.TotalAccounts+4, stats.TotalAccounts, "TotalAccounts mismatch") s.Require().Equal(baseStats.TotalAccounts+4, stats.TotalAccounts, "TotalAccounts mismatch")
s.Require().Equal(baseStats.ErrorAccounts+1, stats.ErrorAccounts, "ErrorAccounts mismatch") s.Require().Equal(baseStats.ErrorAccounts+1, stats.ErrorAccounts, "ErrorAccounts mismatch")
s.Require().Equal(baseStats.RateLimitAccounts+1, stats.RateLimitAccounts, "RateLimitAccounts mismatch") s.Require().Equal(baseStats.RateLimitAccounts+1, stats.RateLimitAccounts, "RateLimitAccounts mismatch")
@@ -300,14 +303,14 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
func (s *UsageLogRepoSuite) TestGetUserDashboardStats() { func (s *UsageLogRepoSuite) TestGetUserDashboardStats() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "userdash@test.com"}) user := mustCreateUser(s.T(), s.client, &service.User{Email: "userdash@test.com"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-userdash", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-userdash", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-userdash"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-userdash"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now()) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
stats, err := s.repo.GetUserDashboardStats(s.ctx, user.ID) stats, err := s.repo.GetUserDashboardStats(s.ctx, user.ID)
s.Require().NoError(err, "GetUserDashboardStats") s.Require().NoError(err, "GetUserDashboardStats")
s.Require().Equal(int64(1), stats.TotalAPIKeys) s.Require().Equal(int64(1), stats.TotalApiKeys)
s.Require().Equal(int64(1), stats.TotalRequests) s.Require().Equal(int64(1), stats.TotalRequests)
} }
@@ -315,7 +318,7 @@ func (s *UsageLogRepoSuite) TestGetUserDashboardStats() {
func (s *UsageLogRepoSuite) TestGetAccountTodayStats() { func (s *UsageLogRepoSuite) TestGetAccountTodayStats() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "acctoday@test.com"}) user := mustCreateUser(s.T(), s.client, &service.User{Email: "acctoday@test.com"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-acctoday", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-acctoday", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-today"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-today"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now()) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
@@ -331,8 +334,8 @@ func (s *UsageLogRepoSuite) TestGetAccountTodayStats() {
func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats() { func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats() {
user1 := mustCreateUser(s.T(), s.client, &service.User{Email: "batch1@test.com"}) user1 := mustCreateUser(s.T(), s.client, &service.User{Email: "batch1@test.com"})
user2 := mustCreateUser(s.T(), s.client, &service.User{Email: "batch2@test.com"}) user2 := mustCreateUser(s.T(), s.client, &service.User{Email: "batch2@test.com"})
apiKey1 := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user1.ID, Key: "sk-batch1", Name: "k"}) apiKey1 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user1.ID, Key: "sk-batch1", Name: "k"})
apiKey2 := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user2.ID, Key: "sk-batch2", Name: "k"}) apiKey2 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user2.ID, Key: "sk-batch2", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-batch"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-batch"})
s.createUsageLog(user1, apiKey1, account, 10, 20, 0.5, time.Now()) s.createUsageLog(user1, apiKey1, account, 10, 20, 0.5, time.Now())
@@ -351,24 +354,24 @@ func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats_Empty() {
s.Require().Empty(stats) s.Require().Empty(stats)
} }
// --- GetBatchAPIKeyUsageStats --- // --- GetBatchApiKeyUsageStats ---
func (s *UsageLogRepoSuite) TestGetBatchAPIKeyUsageStats() { func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "batchkey@test.com"}) user := mustCreateUser(s.T(), s.client, &service.User{Email: "batchkey@test.com"})
apiKey1 := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-batchkey1", Name: "k1"}) apiKey1 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-batchkey1", Name: "k1"})
apiKey2 := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-batchkey2", Name: "k2"}) apiKey2 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-batchkey2", Name: "k2"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-batchkey"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-batchkey"})
s.createUsageLog(user, apiKey1, account, 10, 20, 0.5, time.Now()) s.createUsageLog(user, apiKey1, account, 10, 20, 0.5, time.Now())
s.createUsageLog(user, apiKey2, account, 15, 25, 0.6, time.Now()) s.createUsageLog(user, apiKey2, account, 15, 25, 0.6, time.Now())
stats, err := s.repo.GetBatchAPIKeyUsageStats(s.ctx, []int64{apiKey1.ID, apiKey2.ID}) stats, err := s.repo.GetBatchApiKeyUsageStats(s.ctx, []int64{apiKey1.ID, apiKey2.ID})
s.Require().NoError(err, "GetBatchAPIKeyUsageStats") s.Require().NoError(err, "GetBatchApiKeyUsageStats")
s.Require().Len(stats, 2) s.Require().Len(stats, 2)
} }
func (s *UsageLogRepoSuite) TestGetBatchAPIKeyUsageStats_Empty() { func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats_Empty() {
stats, err := s.repo.GetBatchAPIKeyUsageStats(s.ctx, []int64{}) stats, err := s.repo.GetBatchApiKeyUsageStats(s.ctx, []int64{})
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Empty(stats) s.Require().Empty(stats)
} }
@@ -377,7 +380,7 @@ func (s *UsageLogRepoSuite) TestGetBatchAPIKeyUsageStats_Empty() {
func (s *UsageLogRepoSuite) TestGetGlobalStats() { func (s *UsageLogRepoSuite) TestGetGlobalStats() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "global@test.com"}) user := mustCreateUser(s.T(), s.client, &service.User{Email: "global@test.com"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-global", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-global", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-global"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-global"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -402,7 +405,7 @@ func maxTime(a, b time.Time) time.Time {
func (s *UsageLogRepoSuite) TestListByUserAndTimeRange() { func (s *UsageLogRepoSuite) TestListByUserAndTimeRange() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "timerange@test.com"}) user := mustCreateUser(s.T(), s.client, &service.User{Email: "timerange@test.com"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-timerange", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-timerange", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-timerange"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-timerange"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -417,11 +420,11 @@ func (s *UsageLogRepoSuite) TestListByUserAndTimeRange() {
s.Require().Len(logs, 2) s.Require().Len(logs, 2)
} }
// --- ListByAPIKeyAndTimeRange --- // --- ListByApiKeyAndTimeRange ---
func (s *UsageLogRepoSuite) TestListByAPIKeyAndTimeRange() { func (s *UsageLogRepoSuite) TestListByApiKeyAndTimeRange() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "keytimerange@test.com"}) user := mustCreateUser(s.T(), s.client, &service.User{Email: "keytimerange@test.com"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-keytimerange", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-keytimerange", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-keytimerange"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-keytimerange"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -431,8 +434,8 @@ func (s *UsageLogRepoSuite) TestListByAPIKeyAndTimeRange() {
startTime := base.Add(-1 * time.Hour) startTime := base.Add(-1 * time.Hour)
endTime := base.Add(2 * time.Hour) endTime := base.Add(2 * time.Hour)
logs, _, err := s.repo.ListByAPIKeyAndTimeRange(s.ctx, apiKey.ID, startTime, endTime) logs, _, err := s.repo.ListByApiKeyAndTimeRange(s.ctx, apiKey.ID, startTime, endTime)
s.Require().NoError(err, "ListByAPIKeyAndTimeRange") s.Require().NoError(err, "ListByApiKeyAndTimeRange")
s.Require().Len(logs, 2) s.Require().Len(logs, 2)
} }
@@ -440,7 +443,7 @@ func (s *UsageLogRepoSuite) TestListByAPIKeyAndTimeRange() {
func (s *UsageLogRepoSuite) TestListByAccountAndTimeRange() { func (s *UsageLogRepoSuite) TestListByAccountAndTimeRange() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "acctimerange@test.com"}) user := mustCreateUser(s.T(), s.client, &service.User{Email: "acctimerange@test.com"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-acctimerange", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-acctimerange", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-acctimerange"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-acctimerange"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -459,7 +462,7 @@ func (s *UsageLogRepoSuite) TestListByAccountAndTimeRange() {
func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() { func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "modeltimerange@test.com"}) user := mustCreateUser(s.T(), s.client, &service.User{Email: "modeltimerange@test.com"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-modeltimerange", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-modeltimerange", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-modeltimerange"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-modeltimerange"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -467,7 +470,7 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
// Create logs with different models // Create logs with different models
log1 := &service.UsageLog{ log1 := &service.UsageLog{
UserID: user.ID, UserID: user.ID,
APIKeyID: apiKey.ID, ApiKeyID: apiKey.ID,
AccountID: account.ID, AccountID: account.ID,
Model: "claude-3-opus", Model: "claude-3-opus",
InputTokens: 10, InputTokens: 10,
@@ -480,7 +483,7 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
log2 := &service.UsageLog{ log2 := &service.UsageLog{
UserID: user.ID, UserID: user.ID,
APIKeyID: apiKey.ID, ApiKeyID: apiKey.ID,
AccountID: account.ID, AccountID: account.ID,
Model: "claude-3-opus", Model: "claude-3-opus",
InputTokens: 15, InputTokens: 15,
@@ -493,7 +496,7 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
log3 := &service.UsageLog{ log3 := &service.UsageLog{
UserID: user.ID, UserID: user.ID,
APIKeyID: apiKey.ID, ApiKeyID: apiKey.ID,
AccountID: account.ID, AccountID: account.ID,
Model: "claude-3-sonnet", Model: "claude-3-sonnet",
InputTokens: 20, InputTokens: 20,
@@ -515,7 +518,7 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
func (s *UsageLogRepoSuite) TestGetAccountWindowStats() { func (s *UsageLogRepoSuite) TestGetAccountWindowStats() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "windowstats@test.com"}) user := mustCreateUser(s.T(), s.client, &service.User{Email: "windowstats@test.com"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-windowstats", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-windowstats", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-windowstats"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-windowstats"})
now := time.Now() now := time.Now()
@@ -535,7 +538,7 @@ func (s *UsageLogRepoSuite) TestGetAccountWindowStats() {
func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID() { func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "usertrend@test.com"}) user := mustCreateUser(s.T(), s.client, &service.User{Email: "usertrend@test.com"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-usertrend", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-usertrend", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-usertrend"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-usertrend"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -552,7 +555,7 @@ func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID() {
func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID_HourlyGranularity() { func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID_HourlyGranularity() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "usertrendhourly@test.com"}) user := mustCreateUser(s.T(), s.client, &service.User{Email: "usertrendhourly@test.com"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-usertrendhourly", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-usertrendhourly", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-usertrendhourly"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-usertrendhourly"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -571,7 +574,7 @@ func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID_HourlyGranularity() {
func (s *UsageLogRepoSuite) TestGetUserModelStats() { func (s *UsageLogRepoSuite) TestGetUserModelStats() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "modelstats@test.com"}) user := mustCreateUser(s.T(), s.client, &service.User{Email: "modelstats@test.com"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-modelstats", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-modelstats", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-modelstats"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-modelstats"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -579,7 +582,7 @@ func (s *UsageLogRepoSuite) TestGetUserModelStats() {
// Create logs with different models // Create logs with different models
log1 := &service.UsageLog{ log1 := &service.UsageLog{
UserID: user.ID, UserID: user.ID,
APIKeyID: apiKey.ID, ApiKeyID: apiKey.ID,
AccountID: account.ID, AccountID: account.ID,
Model: "claude-3-opus", Model: "claude-3-opus",
InputTokens: 100, InputTokens: 100,
@@ -592,7 +595,7 @@ func (s *UsageLogRepoSuite) TestGetUserModelStats() {
log2 := &service.UsageLog{ log2 := &service.UsageLog{
UserID: user.ID, UserID: user.ID,
APIKeyID: apiKey.ID, ApiKeyID: apiKey.ID,
AccountID: account.ID, AccountID: account.ID,
Model: "claude-3-sonnet", Model: "claude-3-sonnet",
InputTokens: 50, InputTokens: 50,
@@ -618,7 +621,7 @@ func (s *UsageLogRepoSuite) TestGetUserModelStats() {
func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() { func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "trendfilters@test.com"}) user := mustCreateUser(s.T(), s.client, &service.User{Email: "trendfilters@test.com"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-trendfilters", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-trendfilters", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-trendfilters"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-trendfilters"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -646,7 +649,7 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() {
func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() { func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "trendfilters-h@test.com"}) user := mustCreateUser(s.T(), s.client, &service.User{Email: "trendfilters-h@test.com"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-trendfilters-h", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-trendfilters-h", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-trendfilters-h"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-trendfilters-h"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -665,14 +668,14 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() {
func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() { func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "modelfilters@test.com"}) user := mustCreateUser(s.T(), s.client, &service.User{Email: "modelfilters@test.com"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-modelfilters", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-modelfilters", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-modelfilters"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-modelfilters"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
log1 := &service.UsageLog{ log1 := &service.UsageLog{
UserID: user.ID, UserID: user.ID,
APIKeyID: apiKey.ID, ApiKeyID: apiKey.ID,
AccountID: account.ID, AccountID: account.ID,
Model: "claude-3-opus", Model: "claude-3-opus",
InputTokens: 100, InputTokens: 100,
@@ -685,7 +688,7 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
log2 := &service.UsageLog{ log2 := &service.UsageLog{
UserID: user.ID, UserID: user.ID,
APIKeyID: apiKey.ID, ApiKeyID: apiKey.ID,
AccountID: account.ID, AccountID: account.ID,
Model: "claude-3-sonnet", Model: "claude-3-sonnet",
InputTokens: 50, InputTokens: 50,
@@ -719,7 +722,7 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
func (s *UsageLogRepoSuite) TestGetAccountUsageStats() { func (s *UsageLogRepoSuite) TestGetAccountUsageStats() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "accstats@test.com"}) user := mustCreateUser(s.T(), s.client, &service.User{Email: "accstats@test.com"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-accstats", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-accstats", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-accstats"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-accstats"})
base := time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC)
@@ -727,7 +730,7 @@ func (s *UsageLogRepoSuite) TestGetAccountUsageStats() {
// Create logs on different days // Create logs on different days
log1 := &service.UsageLog{ log1 := &service.UsageLog{
UserID: user.ID, UserID: user.ID,
APIKeyID: apiKey.ID, ApiKeyID: apiKey.ID,
AccountID: account.ID, AccountID: account.ID,
Model: "claude-3-opus", Model: "claude-3-opus",
InputTokens: 100, InputTokens: 100,
@@ -740,7 +743,7 @@ func (s *UsageLogRepoSuite) TestGetAccountUsageStats() {
log2 := &service.UsageLog{ log2 := &service.UsageLog{
UserID: user.ID, UserID: user.ID,
APIKeyID: apiKey.ID, ApiKeyID: apiKey.ID,
AccountID: account.ID, AccountID: account.ID,
Model: "claude-3-sonnet", Model: "claude-3-sonnet",
InputTokens: 50, InputTokens: 50,
@@ -782,8 +785,8 @@ func (s *UsageLogRepoSuite) TestGetAccountUsageStats_EmptyRange() {
func (s *UsageLogRepoSuite) TestGetUserUsageTrend() { func (s *UsageLogRepoSuite) TestGetUserUsageTrend() {
user1 := mustCreateUser(s.T(), s.client, &service.User{Email: "usertrend1@test.com"}) user1 := mustCreateUser(s.T(), s.client, &service.User{Email: "usertrend1@test.com"})
user2 := mustCreateUser(s.T(), s.client, &service.User{Email: "usertrend2@test.com"}) user2 := mustCreateUser(s.T(), s.client, &service.User{Email: "usertrend2@test.com"})
apiKey1 := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user1.ID, Key: "sk-usertrend1", Name: "k1"}) apiKey1 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user1.ID, Key: "sk-usertrend1", Name: "k1"})
apiKey2 := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user2.ID, Key: "sk-usertrend2", Name: "k2"}) apiKey2 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user2.ID, Key: "sk-usertrend2", Name: "k2"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-usertrends"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-usertrends"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -799,12 +802,12 @@ func (s *UsageLogRepoSuite) TestGetUserUsageTrend() {
s.Require().GreaterOrEqual(len(trend), 2) s.Require().GreaterOrEqual(len(trend), 2)
} }
// --- GetAPIKeyUsageTrend --- // --- GetApiKeyUsageTrend ---
func (s *UsageLogRepoSuite) TestGetAPIKeyUsageTrend() { func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "keytrend@test.com"}) user := mustCreateUser(s.T(), s.client, &service.User{Email: "keytrend@test.com"})
apiKey1 := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-keytrend1", Name: "k1"}) apiKey1 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-keytrend1", Name: "k1"})
apiKey2 := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-keytrend2", Name: "k2"}) apiKey2 := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-keytrend2", Name: "k2"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-keytrends"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-keytrends"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -815,14 +818,14 @@ func (s *UsageLogRepoSuite) TestGetAPIKeyUsageTrend() {
startTime := base.Add(-1 * time.Hour) startTime := base.Add(-1 * time.Hour)
endTime := base.Add(48 * time.Hour) endTime := base.Add(48 * time.Hour)
trend, err := s.repo.GetAPIKeyUsageTrend(s.ctx, startTime, endTime, "day", 10) trend, err := s.repo.GetApiKeyUsageTrend(s.ctx, startTime, endTime, "day", 10)
s.Require().NoError(err, "GetAPIKeyUsageTrend") s.Require().NoError(err, "GetApiKeyUsageTrend")
s.Require().GreaterOrEqual(len(trend), 2) s.Require().GreaterOrEqual(len(trend), 2)
} }
func (s *UsageLogRepoSuite) TestGetAPIKeyUsageTrend_HourlyGranularity() { func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend_HourlyGranularity() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "keytrendh@test.com"}) user := mustCreateUser(s.T(), s.client, &service.User{Email: "keytrendh@test.com"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-keytrendh", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-keytrendh", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-keytrendh"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-keytrendh"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -832,21 +835,21 @@ func (s *UsageLogRepoSuite) TestGetAPIKeyUsageTrend_HourlyGranularity() {
startTime := base.Add(-1 * time.Hour) startTime := base.Add(-1 * time.Hour)
endTime := base.Add(3 * time.Hour) endTime := base.Add(3 * time.Hour)
trend, err := s.repo.GetAPIKeyUsageTrend(s.ctx, startTime, endTime, "hour", 10) trend, err := s.repo.GetApiKeyUsageTrend(s.ctx, startTime, endTime, "hour", 10)
s.Require().NoError(err, "GetAPIKeyUsageTrend hourly") s.Require().NoError(err, "GetApiKeyUsageTrend hourly")
s.Require().Len(trend, 2) s.Require().Len(trend, 2)
} }
// --- ListWithFilters (additional filter tests) --- // --- ListWithFilters (additional filter tests) ---
func (s *UsageLogRepoSuite) TestListWithFilters_APIKeyFilter() { func (s *UsageLogRepoSuite) TestListWithFilters_ApiKeyFilter() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "filterskey@test.com"}) user := mustCreateUser(s.T(), s.client, &service.User{Email: "filterskey@test.com"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-filterskey", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-filterskey", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-filterskey"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-filterskey"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now()) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
filters := usagestats.UsageLogFilters{APIKeyID: apiKey.ID} filters := usagestats.UsageLogFilters{ApiKeyID: apiKey.ID}
logs, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, filters) logs, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, filters)
s.Require().NoError(err, "ListWithFilters apiKey") s.Require().NoError(err, "ListWithFilters apiKey")
s.Require().Len(logs, 1) s.Require().Len(logs, 1)
@@ -855,7 +858,7 @@ func (s *UsageLogRepoSuite) TestListWithFilters_APIKeyFilter() {
func (s *UsageLogRepoSuite) TestListWithFilters_TimeRange() { func (s *UsageLogRepoSuite) TestListWithFilters_TimeRange() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "filterstime@test.com"}) user := mustCreateUser(s.T(), s.client, &service.User{Email: "filterstime@test.com"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-filterstime", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-filterstime", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-filterstime"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-filterstime"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -874,7 +877,7 @@ func (s *UsageLogRepoSuite) TestListWithFilters_TimeRange() {
func (s *UsageLogRepoSuite) TestListWithFilters_CombinedFilters() { func (s *UsageLogRepoSuite) TestListWithFilters_CombinedFilters() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "filterscombined@test.com"}) user := mustCreateUser(s.T(), s.client, &service.User{Email: "filterscombined@test.com"})
apiKey := mustCreateAPIKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-filterscombined", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.client, &service.ApiKey{UserID: user.ID, Key: "sk-filterscombined", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-filterscombined"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-filterscombined"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
@@ -885,7 +888,7 @@ func (s *UsageLogRepoSuite) TestListWithFilters_CombinedFilters() {
endTime := base.Add(2 * time.Hour) endTime := base.Add(2 * time.Hour)
filters := usagestats.UsageLogFilters{ filters := usagestats.UsageLogFilters{
UserID: user.ID, UserID: user.ID,
APIKeyID: apiKey.ID, ApiKeyID: apiKey.ID,
StartTime: &startTime, StartTime: &startTime,
EndTime: &endTime, EndTime: &endTime,
} }

View File

@@ -91,7 +91,7 @@ func TestAPIContracts(t *testing.T) {
name: "GET /api/v1/keys (paginated)", name: "GET /api/v1/keys (paginated)",
setup: func(t *testing.T, deps *contractDeps) { setup: func(t *testing.T, deps *contractDeps) {
t.Helper() t.Helper()
deps.apiKeyRepo.MustSeed(&service.APIKey{ deps.apiKeyRepo.MustSeed(&service.ApiKey{
ID: 100, ID: 100,
UserID: 1, UserID: 1,
Key: "sk_custom_1234567890", Key: "sk_custom_1234567890",
@@ -135,7 +135,7 @@ func TestAPIContracts(t *testing.T) {
{ {
ID: 1, ID: 1,
UserID: 1, UserID: 1,
APIKeyID: 100, ApiKeyID: 100,
AccountID: 200, AccountID: 200,
Model: "claude-3", Model: "claude-3",
InputTokens: 10, InputTokens: 10,
@@ -150,7 +150,7 @@ func TestAPIContracts(t *testing.T) {
{ {
ID: 2, ID: 2,
UserID: 1, UserID: 1,
APIKeyID: 100, ApiKeyID: 100,
AccountID: 200, AccountID: 200,
Model: "claude-3", Model: "claude-3",
InputTokens: 5, InputTokens: 5,
@@ -188,7 +188,7 @@ func TestAPIContracts(t *testing.T) {
{ {
ID: 1, ID: 1,
UserID: 1, UserID: 1,
APIKeyID: 100, ApiKeyID: 100,
AccountID: 200, AccountID: 200,
RequestID: "req_123", RequestID: "req_123",
Model: "claude-3", Model: "claude-3",
@@ -259,13 +259,13 @@ func TestAPIContracts(t *testing.T) {
service.SettingKeyRegistrationEnabled: "true", service.SettingKeyRegistrationEnabled: "true",
service.SettingKeyEmailVerifyEnabled: "false", service.SettingKeyEmailVerifyEnabled: "false",
service.SettingKeySMTPHost: "smtp.example.com", service.SettingKeySmtpHost: "smtp.example.com",
service.SettingKeySMTPPort: "587", service.SettingKeySmtpPort: "587",
service.SettingKeySMTPUsername: "user", service.SettingKeySmtpUsername: "user",
service.SettingKeySMTPPassword: "secret", service.SettingKeySmtpPassword: "secret",
service.SettingKeySMTPFrom: "no-reply@example.com", service.SettingKeySmtpFrom: "no-reply@example.com",
service.SettingKeySMTPFromName: "Sub2API", service.SettingKeySmtpFromName: "Sub2API",
service.SettingKeySMTPUseTLS: "true", service.SettingKeySmtpUseTLS: "true",
service.SettingKeyTurnstileEnabled: "true", service.SettingKeyTurnstileEnabled: "true",
service.SettingKeyTurnstileSiteKey: "site-key", service.SettingKeyTurnstileSiteKey: "site-key",
@@ -274,9 +274,9 @@ func TestAPIContracts(t *testing.T) {
service.SettingKeySiteName: "Sub2API", service.SettingKeySiteName: "Sub2API",
service.SettingKeySiteLogo: "", service.SettingKeySiteLogo: "",
service.SettingKeySiteSubtitle: "Subtitle", service.SettingKeySiteSubtitle: "Subtitle",
service.SettingKeyAPIBaseURL: "https://api.example.com", service.SettingKeyApiBaseUrl: "https://api.example.com",
service.SettingKeyContactInfo: "support", service.SettingKeyContactInfo: "support",
service.SettingKeyDocURL: "https://docs.example.com", service.SettingKeyDocUrl: "https://docs.example.com",
service.SettingKeyDefaultConcurrency: "5", service.SettingKeyDefaultConcurrency: "5",
service.SettingKeyDefaultBalance: "1.25", service.SettingKeyDefaultBalance: "1.25",
@@ -308,7 +308,12 @@ func TestAPIContracts(t *testing.T) {
"contact_info": "support", "contact_info": "support",
"doc_url": "https://docs.example.com", "doc_url": "https://docs.example.com",
"default_concurrency": 5, "default_concurrency": 5,
"default_balance": 1.25 "default_balance": 1.25,
"enable_model_fallback": false,
"fallback_model_anthropic": "",
"fallback_model_antigravity": "",
"fallback_model_gemini": "",
"fallback_model_openai": ""
} }
}`, }`,
}, },
@@ -331,7 +336,7 @@ func TestAPIContracts(t *testing.T) {
type contractDeps struct { type contractDeps struct {
now time.Time now time.Time
router http.Handler router http.Handler
apiKeyRepo *stubAPIKeyRepo apiKeyRepo *stubApiKeyRepo
usageRepo *stubUsageLogRepo usageRepo *stubUsageLogRepo
settingRepo *stubSettingRepo settingRepo *stubSettingRepo
} }
@@ -359,20 +364,20 @@ func newContractDeps(t *testing.T) *contractDeps {
}, },
} }
apiKeyRepo := newStubAPIKeyRepo(now) apiKeyRepo := newStubApiKeyRepo(now)
apiKeyCache := stubAPIKeyCache{} apiKeyCache := stubApiKeyCache{}
groupRepo := stubGroupRepo{} groupRepo := stubGroupRepo{}
userSubRepo := stubUserSubscriptionRepo{} userSubRepo := stubUserSubscriptionRepo{}
cfg := &config.Config{ cfg := &config.Config{
Default: config.DefaultConfig{ Default: config.DefaultConfig{
APIKeyPrefix: "sk-", ApiKeyPrefix: "sk-",
}, },
RunMode: config.RunModeStandard, RunMode: config.RunModeStandard,
} }
userService := service.NewUserService(userRepo) userService := service.NewUserService(userRepo)
apiKeyService := service.NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, apiKeyCache, cfg) apiKeyService := service.NewApiKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, apiKeyCache, cfg)
usageRepo := newStubUsageLogRepo() usageRepo := newStubUsageLogRepo()
usageService := service.NewUsageService(usageRepo, userRepo) usageService := service.NewUsageService(usageRepo, userRepo)
@@ -525,25 +530,25 @@ func (r *stubUserRepo) RemoveGroupFromAllowedGroups(ctx context.Context, groupID
return 0, errors.New("not implemented") return 0, errors.New("not implemented")
} }
type stubAPIKeyCache struct{} type stubApiKeyCache struct{}
func (stubAPIKeyCache) GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) { func (stubApiKeyCache) GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) {
return 0, nil return 0, nil
} }
func (stubAPIKeyCache) IncrementCreateAttemptCount(ctx context.Context, userID int64) error { func (stubApiKeyCache) IncrementCreateAttemptCount(ctx context.Context, userID int64) error {
return nil return nil
} }
func (stubAPIKeyCache) DeleteCreateAttemptCount(ctx context.Context, userID int64) error { func (stubApiKeyCache) DeleteCreateAttemptCount(ctx context.Context, userID int64) error {
return nil return nil
} }
func (stubAPIKeyCache) IncrementDailyUsage(ctx context.Context, apiKey string) error { func (stubApiKeyCache) IncrementDailyUsage(ctx context.Context, apiKey string) error {
return nil return nil
} }
func (stubAPIKeyCache) SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error { func (stubApiKeyCache) SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error {
return nil return nil
} }
@@ -660,24 +665,24 @@ func (stubUserSubscriptionRepo) BatchUpdateExpiredStatus(ctx context.Context) (i
return 0, errors.New("not implemented") return 0, errors.New("not implemented")
} }
type stubAPIKeyRepo struct { type stubApiKeyRepo struct {
now time.Time now time.Time
nextID int64 nextID int64
byID map[int64]*service.APIKey byID map[int64]*service.ApiKey
byKey map[string]*service.APIKey byKey map[string]*service.ApiKey
} }
func newStubAPIKeyRepo(now time.Time) *stubAPIKeyRepo { func newStubApiKeyRepo(now time.Time) *stubApiKeyRepo {
return &stubAPIKeyRepo{ return &stubApiKeyRepo{
now: now, now: now,
nextID: 100, nextID: 100,
byID: make(map[int64]*service.APIKey), byID: make(map[int64]*service.ApiKey),
byKey: make(map[string]*service.APIKey), byKey: make(map[string]*service.ApiKey),
} }
} }
func (r *stubAPIKeyRepo) MustSeed(key *service.APIKey) { func (r *stubApiKeyRepo) MustSeed(key *service.ApiKey) {
if key == nil { if key == nil {
return return
} }
@@ -686,7 +691,7 @@ func (r *stubAPIKeyRepo) MustSeed(key *service.APIKey) {
r.byKey[clone.Key] = &clone r.byKey[clone.Key] = &clone
} }
func (r *stubAPIKeyRepo) Create(ctx context.Context, key *service.APIKey) error { func (r *stubApiKeyRepo) Create(ctx context.Context, key *service.ApiKey) error {
if key == nil { if key == nil {
return errors.New("nil key") return errors.New("nil key")
} }
@@ -706,38 +711,38 @@ func (r *stubAPIKeyRepo) Create(ctx context.Context, key *service.APIKey) error
return nil return nil
} }
func (r *stubAPIKeyRepo) GetByID(ctx context.Context, id int64) (*service.APIKey, error) { func (r *stubApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) {
key, ok := r.byID[id] key, ok := r.byID[id]
if !ok { if !ok {
return nil, service.ErrAPIKeyNotFound return nil, service.ErrApiKeyNotFound
} }
clone := *key clone := *key
return &clone, nil return &clone, nil
} }
func (r *stubAPIKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) { func (r *stubApiKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) {
key, ok := r.byID[id] key, ok := r.byID[id]
if !ok { if !ok {
return 0, service.ErrAPIKeyNotFound return 0, service.ErrApiKeyNotFound
} }
return key.UserID, nil return key.UserID, nil
} }
func (r *stubAPIKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIKey, error) { func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) {
found, ok := r.byKey[key] found, ok := r.byKey[key]
if !ok { if !ok {
return nil, service.ErrAPIKeyNotFound return nil, service.ErrApiKeyNotFound
} }
clone := *found clone := *found
return &clone, nil return &clone, nil
} }
func (r *stubAPIKeyRepo) Update(ctx context.Context, key *service.APIKey) error { func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.ApiKey) error {
if key == nil { if key == nil {
return errors.New("nil key") return errors.New("nil key")
} }
if _, ok := r.byID[key.ID]; !ok { if _, ok := r.byID[key.ID]; !ok {
return service.ErrAPIKeyNotFound return service.ErrApiKeyNotFound
} }
if key.UpdatedAt.IsZero() { if key.UpdatedAt.IsZero() {
key.UpdatedAt = r.now key.UpdatedAt = r.now
@@ -748,17 +753,17 @@ func (r *stubAPIKeyRepo) Update(ctx context.Context, key *service.APIKey) error
return nil return nil
} }
func (r *stubAPIKeyRepo) Delete(ctx context.Context, id int64) error { func (r *stubApiKeyRepo) Delete(ctx context.Context, id int64) error {
key, ok := r.byID[id] key, ok := r.byID[id]
if !ok { if !ok {
return service.ErrAPIKeyNotFound return service.ErrApiKeyNotFound
} }
delete(r.byID, id) delete(r.byID, id)
delete(r.byKey, key.Key) delete(r.byKey, key.Key)
return nil return nil
} }
func (r *stubAPIKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) { func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
ids := make([]int64, 0, len(r.byID)) ids := make([]int64, 0, len(r.byID))
for id := range r.byID { for id := range r.byID {
if r.byID[id].UserID == userID { if r.byID[id].UserID == userID {
@@ -776,7 +781,7 @@ func (r *stubAPIKeyRepo) ListByUserID(ctx context.Context, userID int64, params
end = len(ids) end = len(ids)
} }
out := make([]service.APIKey, 0, end-start) out := make([]service.ApiKey, 0, end-start)
for _, id := range ids[start:end] { for _, id := range ids[start:end] {
clone := *r.byID[id] clone := *r.byID[id]
out = append(out, clone) out = append(out, clone)
@@ -796,7 +801,7 @@ func (r *stubAPIKeyRepo) ListByUserID(ctx context.Context, userID int64, params
}, nil }, nil
} }
func (r *stubAPIKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) { func (r *stubApiKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
if len(apiKeyIDs) == 0 { if len(apiKeyIDs) == 0 {
return []int64{}, nil return []int64{}, nil
} }
@@ -815,7 +820,7 @@ func (r *stubAPIKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiK
return out, nil return out, nil
} }
func (r *stubAPIKeyRepo) CountByUserID(ctx context.Context, userID int64) (int64, error) { func (r *stubApiKeyRepo) CountByUserID(ctx context.Context, userID int64) (int64, error) {
var count int64 var count int64
for _, key := range r.byID { for _, key := range r.byID {
if key.UserID == userID { if key.UserID == userID {
@@ -825,24 +830,24 @@ func (r *stubAPIKeyRepo) CountByUserID(ctx context.Context, userID int64) (int64
return count, nil return count, nil
} }
func (r *stubAPIKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, error) { func (r *stubApiKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, error) {
_, ok := r.byKey[key] _, ok := r.byKey[key]
return ok, nil return ok, nil
} }
func (r *stubAPIKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) { func (r *stubApiKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented") return nil, nil, errors.New("not implemented")
} }
func (r *stubAPIKeyRepo) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.APIKey, error) { func (r *stubApiKeyRepo) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) {
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
func (r *stubAPIKeyRepo) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) { func (r *stubApiKeyRepo) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, errors.New("not implemented") return 0, errors.New("not implemented")
} }
func (r *stubAPIKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) { func (r *stubApiKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, errors.New("not implemented") return 0, errors.New("not implemented")
} }
@@ -877,7 +882,7 @@ func (r *stubUsageLogRepo) ListByUser(ctx context.Context, userID int64, params
return out, paginationResult(total, params), nil return out, paginationResult(total, params), nil
} }
func (r *stubUsageLogRepo) ListByAPIKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) { func (r *stubUsageLogRepo) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented") return nil, nil, errors.New("not implemented")
} }
@@ -890,7 +895,7 @@ func (r *stubUsageLogRepo) ListByUserAndTimeRange(ctx context.Context, userID in
return logs, paginationResult(int64(len(logs)), pagination.PaginationParams{Page: 1, PageSize: 100}), nil return logs, paginationResult(int64(len(logs)), pagination.PaginationParams{Page: 1, PageSize: 100}), nil
} }
func (r *stubUsageLogRepo) ListByAPIKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) { func (r *stubUsageLogRepo) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented") return nil, nil, errors.New("not implemented")
} }
@@ -922,7 +927,7 @@ func (r *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTi
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
func (r *stubUsageLogRepo) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) { func (r *stubUsageLogRepo) GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.ApiKeyUsageTrendPoint, error) {
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
@@ -975,7 +980,7 @@ func (r *stubUsageLogRepo) GetUserStatsAggregated(ctx context.Context, userID in
}, nil }, nil
} }
func (r *stubUsageLogRepo) GetAPIKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) { func (r *stubUsageLogRepo) GetApiKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
@@ -995,7 +1000,7 @@ func (r *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs [
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
func (r *stubUsageLogRepo) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) { func (r *stubUsageLogRepo) GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchApiKeyUsageStats, error) {
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
@@ -1017,8 +1022,8 @@ func (r *stubUsageLogRepo) ListWithFilters(ctx context.Context, params paginatio
// Apply filters // Apply filters
var filtered []service.UsageLog var filtered []service.UsageLog
for _, log := range logs { for _, log := range logs {
// Apply APIKeyID filter // Apply ApiKeyID filter
if filters.APIKeyID > 0 && log.APIKeyID != filters.APIKeyID { if filters.ApiKeyID > 0 && log.ApiKeyID != filters.ApiKeyID {
continue continue
} }
// Apply Model filter // Apply Model filter
@@ -1151,8 +1156,8 @@ func paginationResult(total int64, params pagination.PaginationParams) *paginati
// Ensure compile-time interface compliance. // Ensure compile-time interface compliance.
var ( var (
_ service.UserRepository = (*stubUserRepo)(nil) _ service.UserRepository = (*stubUserRepo)(nil)
_ service.APIKeyRepository = (*stubAPIKeyRepo)(nil) _ service.ApiKeyRepository = (*stubApiKeyRepo)(nil)
_ service.APIKeyCache = (*stubAPIKeyCache)(nil) _ service.ApiKeyCache = (*stubApiKeyCache)(nil)
_ service.GroupRepository = (*stubGroupRepo)(nil) _ service.GroupRepository = (*stubGroupRepo)(nil)
_ service.UserSubscriptionRepository = (*stubUserSubscriptionRepo)(nil) _ service.UserSubscriptionRepository = (*stubUserSubscriptionRepo)(nil)
_ service.UsageLogRepository = (*stubUsageLogRepo)(nil) _ service.UsageLogRepository = (*stubUsageLogRepo)(nil)

View File

@@ -121,6 +121,12 @@ func (m *mockAccountRepoForGemini) SetRateLimited(ctx context.Context, id int64,
func (m *mockAccountRepoForGemini) SetOverloaded(ctx context.Context, id int64, until time.Time) error { func (m *mockAccountRepoForGemini) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
return nil return nil
} }
func (m *mockAccountRepoForGemini) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
return nil
}
func (m *mockAccountRepoForGemini) ClearTempUnschedulable(ctx context.Context, id int64) error {
return nil
}
func (m *mockAccountRepoForGemini) ClearRateLimit(ctx context.Context, id int64) error { return nil } func (m *mockAccountRepoForGemini) ClearRateLimit(ctx context.Context, id int64) error { return nil }
func (m *mockAccountRepoForGemini) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error { func (m *mockAccountRepoForGemini) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
return nil return nil
@@ -275,7 +281,7 @@ func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_OAuthPr
repo := &mockAccountRepoForGemini{ repo := &mockAccountRepoForGemini{
accounts: []Account{ accounts: []Account{
{ID: 1, Platform: PlatformGemini, Type: AccountTypeAPIKey, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: nil}, {ID: 1, Platform: PlatformGemini, Type: AccountTypeApiKey, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: nil},
{ID: 2, Platform: PlatformGemini, Type: AccountTypeOAuth, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: nil}, {ID: 2, Platform: PlatformGemini, Type: AccountTypeOAuth, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: nil},
}, },
accountsByID: map[int64]*Account{}, accountsByID: map[int64]*Account{},

View File

@@ -197,7 +197,7 @@ func TestClaudeTokenRefresher_CanRefresh(t *testing.T) {
{ {
name: "anthropic api-key - cannot refresh", name: "anthropic api-key - cannot refresh",
platform: PlatformAnthropic, platform: PlatformAnthropic,
accType: AccountTypeAPIKey, accType: AccountTypeApiKey,
want: false, want: false,
}, },
{ {

View File

@@ -0,0 +1,104 @@
-- Ops monitoring: pre-aggregation tables for dashboard queries
--
-- Problem:
-- The ops dashboard currently runs percentile_cont + GROUP BY queries over large raw tables
-- (usage_logs, ops_error_logs). These will get slower as data grows.
--
-- This migration adds schema-only aggregation tables that can be populated by a future background job.
-- No triggers/functions/jobs are created here (schema only).
-- ============================================
-- Hourly aggregates (per provider/platform)
-- ============================================
CREATE TABLE IF NOT EXISTS ops_metrics_hourly (
-- Start of the hour bucket (recommended: UTC).
bucket_start TIMESTAMPTZ NOT NULL,
-- Provider/platform label (e.g. anthropic/openai/gemini). Mirrors ops_* queries that GROUP BY platform.
platform VARCHAR(50) NOT NULL,
-- Traffic counts (use these to compute rates reliably across ranges).
request_count BIGINT NOT NULL DEFAULT 0,
success_count BIGINT NOT NULL DEFAULT 0,
error_count BIGINT NOT NULL DEFAULT 0,
-- Error breakdown used by provider health UI.
error_4xx_count BIGINT NOT NULL DEFAULT 0,
error_5xx_count BIGINT NOT NULL DEFAULT 0,
timeout_count BIGINT NOT NULL DEFAULT 0,
-- Latency aggregates (ms).
avg_latency_ms DOUBLE PRECISION,
p99_latency_ms DOUBLE PRECISION,
-- Convenience rate (percentage, 0-100). Still keep counts as source of truth.
error_rate DOUBLE PRECISION NOT NULL DEFAULT 0,
-- When this row was last (re)computed by the background job.
computed_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
PRIMARY KEY (bucket_start, platform)
);
CREATE INDEX IF NOT EXISTS idx_ops_metrics_hourly_platform_bucket_start
ON ops_metrics_hourly (platform, bucket_start DESC);
COMMENT ON TABLE ops_metrics_hourly IS 'Pre-aggregated hourly ops metrics by provider/platform to speed up dashboard queries.';
COMMENT ON COLUMN ops_metrics_hourly.bucket_start IS 'Start timestamp of the hour bucket (recommended UTC).';
COMMENT ON COLUMN ops_metrics_hourly.platform IS 'Provider/platform label (anthropic/openai/gemini, etc).';
COMMENT ON COLUMN ops_metrics_hourly.error_rate IS 'Error rate percentage for the bucket (0-100). Counts remain the source of truth.';
COMMENT ON COLUMN ops_metrics_hourly.computed_at IS 'When the row was last computed/refreshed.';
-- ============================================
-- Daily aggregates (per provider/platform)
-- ============================================
CREATE TABLE IF NOT EXISTS ops_metrics_daily (
-- Day bucket (recommended: UTC date).
bucket_date DATE NOT NULL,
platform VARCHAR(50) NOT NULL,
request_count BIGINT NOT NULL DEFAULT 0,
success_count BIGINT NOT NULL DEFAULT 0,
error_count BIGINT NOT NULL DEFAULT 0,
error_4xx_count BIGINT NOT NULL DEFAULT 0,
error_5xx_count BIGINT NOT NULL DEFAULT 0,
timeout_count BIGINT NOT NULL DEFAULT 0,
avg_latency_ms DOUBLE PRECISION,
p99_latency_ms DOUBLE PRECISION,
error_rate DOUBLE PRECISION NOT NULL DEFAULT 0,
computed_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
PRIMARY KEY (bucket_date, platform)
);
CREATE INDEX IF NOT EXISTS idx_ops_metrics_daily_platform_bucket_date
ON ops_metrics_daily (platform, bucket_date DESC);
COMMENT ON TABLE ops_metrics_daily IS 'Pre-aggregated daily ops metrics by provider/platform for longer-term trends.';
COMMENT ON COLUMN ops_metrics_daily.bucket_date IS 'UTC date of the day bucket (recommended).';
-- ============================================
-- Population strategy (future background job)
-- ============================================
--
-- Suggested approach:
-- 1) Compute hourly buckets from raw logs using UTC time-bucketing, then UPSERT into ops_metrics_hourly.
-- 2) Compute daily buckets either directly from raw logs or by rolling up ops_metrics_hourly.
--
-- Notes:
-- - Ensure the job uses a consistent timezone (recommended: SET TIME ZONE ''UTC'') to avoid bucket drift.
-- - Derive the provider/platform similarly to existing dashboard queries:
-- usage_logs: COALESCE(NULLIF(groups.platform, ''), accounts.platform, '')
-- ops_error_logs: COALESCE(NULLIF(ops_error_logs.platform, ''), groups.platform, accounts.platform, '')
-- - Keep request_count/success_count/error_count as the authoritative values; compute error_rate from counts.
--
-- Example (hourly) shape (pseudo-SQL):
-- INSERT INTO ops_metrics_hourly (...)
-- SELECT date_trunc('hour', created_at) AS bucket_start, platform, ...
-- FROM (/* aggregate usage_logs + ops_error_logs */) s
-- ON CONFLICT (bucket_start, platform) DO UPDATE SET ...;

View File

@@ -0,0 +1,58 @@
-- 027_usage_billing_consistency.sql
-- Ensure usage_logs idempotency (request_id, api_key_id) and add reconciliation infrastructure.
-- -----------------------------------------------------------------------------
-- 1) Normalize legacy request_id values
-- -----------------------------------------------------------------------------
-- Historically request_id may be inserted as empty string. Convert it to NULL so
-- the upcoming unique index does not break on repeated "" values.
UPDATE usage_logs
SET request_id = NULL
WHERE request_id = '';
-- If duplicates already exist for the same (request_id, api_key_id), keep the
-- first row and NULL-out request_id for the rest so the unique index can be
-- created without deleting historical logs.
WITH ranked AS (
SELECT
id,
ROW_NUMBER() OVER (PARTITION BY api_key_id, request_id ORDER BY id) AS rn
FROM usage_logs
WHERE request_id IS NOT NULL
)
UPDATE usage_logs ul
SET request_id = NULL
FROM ranked r
WHERE ul.id = r.id
AND r.rn > 1;
-- -----------------------------------------------------------------------------
-- 2) Idempotency constraint for usage_logs
-- -----------------------------------------------------------------------------
CREATE UNIQUE INDEX IF NOT EXISTS idx_usage_logs_request_id_api_key_unique
ON usage_logs (request_id, api_key_id);
-- -----------------------------------------------------------------------------
-- 3) Reconciliation infrastructure: billing ledger for usage charges
-- -----------------------------------------------------------------------------
CREATE TABLE IF NOT EXISTS billing_usage_entries (
id BIGSERIAL PRIMARY KEY,
usage_log_id BIGINT NOT NULL REFERENCES usage_logs(id) ON DELETE CASCADE,
user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
api_key_id BIGINT NOT NULL REFERENCES api_keys(id) ON DELETE CASCADE,
subscription_id BIGINT REFERENCES user_subscriptions(id) ON DELETE SET NULL,
billing_type SMALLINT NOT NULL,
applied BOOLEAN NOT NULL DEFAULT TRUE,
delta_usd DECIMAL(20, 10) NOT NULL DEFAULT 0,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE UNIQUE INDEX IF NOT EXISTS billing_usage_entries_usage_log_id_unique
ON billing_usage_entries (usage_log_id);
CREATE INDEX IF NOT EXISTS idx_billing_usage_entries_user_time
ON billing_usage_entries (user_id, created_at);
CREATE INDEX IF NOT EXISTS idx_billing_usage_entries_created_at
ON billing_usage_entries (created_at);