Files
sub2api/backend/internal/service/openai_oauth_service_state_test.go
yangjianbo 900cce20a1 feat(sora): 对齐 Sora OAuth 流程并隔离网关请求路径
- 新增并接通 Sora 专用 OAuth 接口与 ST/RT 换取能力
- 完成前端 Sora 授权、RT/ST 手动导入与账号创建流程
- 强化 Sora token 恢复、转发日志与网关路由隔离行为
- 补充后端服务层与路由层相关测试覆盖

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-19 08:02:56 +08:00

103 lines
2.9 KiB
Go

package service
import (
"context"
"errors"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/stretchr/testify/require"
)
// openaiOAuthClientStateStub is a minimal fake OAuth client handed to
// NewOpenAIOAuthService in the state-validation tests. ExchangeCode always
// succeeds with a fixed token set and counts how often it was invoked; both
// refresh methods always fail.
type openaiOAuthClientStateStub struct {
	// exchangeCalled is the number of ExchangeCode invocations; it is only
	// accessed through sync/atomic.
	exchangeCalled int32
}

// ExchangeCode increments the invocation counter and returns a canned
// response (AccessToken "at", RefreshToken "rt", ExpiresIn 3600). Every
// argument is ignored.
func (s *openaiOAuthClientStateStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) {
	atomic.AddInt32(&s.exchangeCalled, 1)

	resp := &openai.TokenResponse{
		AccessToken:  "at",
		RefreshToken: "rt",
		ExpiresIn:    3600,
	}
	return resp, nil
}

// RefreshToken is unsupported by this stub and always reports an error.
func (s *openaiOAuthClientStateStub) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) {
	return nil, errors.New("not implemented")
}

// RefreshTokenWithClientID discards clientID and delegates to RefreshToken,
// so it fails in exactly the same way.
func (s *openaiOAuthClientStateStub) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL, clientID string) (*openai.TokenResponse, error) {
	return s.RefreshToken(ctx, refreshToken, proxyURL)
}
func TestOpenAIOAuthService_ExchangeCode_StateRequired(t *testing.T) {
client := &openaiOAuthClientStateStub{}
svc := NewOpenAIOAuthService(nil, client)
defer svc.Stop()
svc.sessionStore.Set("sid", &openai.OAuthSession{
State: "expected-state",
CodeVerifier: "verifier",
RedirectURI: openai.DefaultRedirectURI,
CreatedAt: time.Now(),
})
_, err := svc.ExchangeCode(context.Background(), &OpenAIExchangeCodeInput{
SessionID: "sid",
Code: "auth-code",
})
require.Error(t, err)
require.Contains(t, err.Error(), "oauth state is required")
require.Equal(t, int32(0), atomic.LoadInt32(&client.exchangeCalled))
}
func TestOpenAIOAuthService_ExchangeCode_StateMismatch(t *testing.T) {
client := &openaiOAuthClientStateStub{}
svc := NewOpenAIOAuthService(nil, client)
defer svc.Stop()
svc.sessionStore.Set("sid", &openai.OAuthSession{
State: "expected-state",
CodeVerifier: "verifier",
RedirectURI: openai.DefaultRedirectURI,
CreatedAt: time.Now(),
})
_, err := svc.ExchangeCode(context.Background(), &OpenAIExchangeCodeInput{
SessionID: "sid",
Code: "auth-code",
State: "wrong-state",
})
require.Error(t, err)
require.Contains(t, err.Error(), "invalid oauth state")
require.Equal(t, int32(0), atomic.LoadInt32(&client.exchangeCalled))
}
func TestOpenAIOAuthService_ExchangeCode_StateMatch(t *testing.T) {
client := &openaiOAuthClientStateStub{}
svc := NewOpenAIOAuthService(nil, client)
defer svc.Stop()
svc.sessionStore.Set("sid", &openai.OAuthSession{
State: "expected-state",
CodeVerifier: "verifier",
RedirectURI: openai.DefaultRedirectURI,
CreatedAt: time.Now(),
})
info, err := svc.ExchangeCode(context.Background(), &OpenAIExchangeCodeInput{
SessionID: "sid",
Code: "auth-code",
State: "expected-state",
})
require.NoError(t, err)
require.NotNil(t, info)
require.Equal(t, "at", info.AccessToken)
require.Equal(t, int32(1), atomic.LoadInt32(&client.exchangeCalled))
_, ok := svc.sessionStore.Get("sid")
require.False(t, ok)
}