feat: codex channel (#2652)
* feat: codex channel * feat: codex channel * feat: codex oauth flow * feat: codex refresh cred * feat: codex usage * fix: codex err message detail * fix: codex setting ui * feat: codex refresh cred task * fix: import err * fix: codex store must be false * fix: chat -> responses tool call * fix: chat -> responses tool call
This commit is contained in:
@@ -1,11 +1,13 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
@@ -604,9 +606,60 @@ func validateChannel(channel *model.Channel, isAdd bool) error {
|
||||
}
|
||||
}
|
||||
|
||||
// Codex OAuth key validation (optional, only when JSON object is provided)
|
||||
if channel.Type == constant.ChannelTypeCodex {
|
||||
trimmedKey := strings.TrimSpace(channel.Key)
|
||||
if isAdd || trimmedKey != "" {
|
||||
if !strings.HasPrefix(trimmedKey, "{") {
|
||||
return fmt.Errorf("Codex key must be a valid JSON object")
|
||||
}
|
||||
var keyMap map[string]any
|
||||
if err := common.Unmarshal([]byte(trimmedKey), &keyMap); err != nil {
|
||||
return fmt.Errorf("Codex key must be a valid JSON object")
|
||||
}
|
||||
if v, ok := keyMap["access_token"]; !ok || v == nil || strings.TrimSpace(fmt.Sprintf("%v", v)) == "" {
|
||||
return fmt.Errorf("Codex key JSON must include access_token")
|
||||
}
|
||||
if v, ok := keyMap["account_id"]; !ok || v == nil || strings.TrimSpace(fmt.Sprintf("%v", v)) == "" {
|
||||
return fmt.Errorf("Codex key JSON must include account_id")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func RefreshCodexChannelCredential(c *gin.Context) {
|
||||
channelId, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
common.ApiError(c, fmt.Errorf("invalid channel id: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
oauthKey, ch, err := service.RefreshCodexChannelCredential(ctx, channelId, service.CodexCredentialRefreshOptions{ResetCaches: true})
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "refreshed",
|
||||
"data": gin.H{
|
||||
"expires_at": oauthKey.Expired,
|
||||
"last_refresh": oauthKey.LastRefresh,
|
||||
"account_id": oauthKey.AccountID,
|
||||
"email": oauthKey.Email,
|
||||
"channel_id": ch.Id,
|
||||
"channel_type": ch.Type,
|
||||
"channel_name": ch.Name,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
type AddChannelRequest struct {
|
||||
Mode string `json:"mode"`
|
||||
MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"`
|
||||
|
||||
243
controller/codex_oauth.go
Normal file
243
controller/codex_oauth.go
Normal file
@@ -0,0 +1,243 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/relay/channel/codex"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type codexOAuthCompleteRequest struct {
|
||||
Input string `json:"input"`
|
||||
}
|
||||
|
||||
func codexOAuthSessionKey(channelID int, field string) string {
|
||||
return fmt.Sprintf("codex_oauth_%s_%d", field, channelID)
|
||||
}
|
||||
|
||||
func parseCodexAuthorizationInput(input string) (code string, state string, err error) {
|
||||
v := strings.TrimSpace(input)
|
||||
if v == "" {
|
||||
return "", "", errors.New("empty input")
|
||||
}
|
||||
if strings.Contains(v, "#") {
|
||||
parts := strings.SplitN(v, "#", 2)
|
||||
code = strings.TrimSpace(parts[0])
|
||||
state = strings.TrimSpace(parts[1])
|
||||
return code, state, nil
|
||||
}
|
||||
if strings.Contains(v, "code=") {
|
||||
u, parseErr := url.Parse(v)
|
||||
if parseErr == nil {
|
||||
q := u.Query()
|
||||
code = strings.TrimSpace(q.Get("code"))
|
||||
state = strings.TrimSpace(q.Get("state"))
|
||||
return code, state, nil
|
||||
}
|
||||
q, parseErr := url.ParseQuery(v)
|
||||
if parseErr == nil {
|
||||
code = strings.TrimSpace(q.Get("code"))
|
||||
state = strings.TrimSpace(q.Get("state"))
|
||||
return code, state, nil
|
||||
}
|
||||
}
|
||||
|
||||
code = v
|
||||
return code, "", nil
|
||||
}
|
||||
|
||||
func StartCodexOAuth(c *gin.Context) {
|
||||
startCodexOAuthWithChannelID(c, 0)
|
||||
}
|
||||
|
||||
func StartCodexOAuthForChannel(c *gin.Context) {
|
||||
channelID, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
common.ApiError(c, fmt.Errorf("invalid channel id: %w", err))
|
||||
return
|
||||
}
|
||||
startCodexOAuthWithChannelID(c, channelID)
|
||||
}
|
||||
|
||||
func startCodexOAuthWithChannelID(c *gin.Context, channelID int) {
|
||||
if channelID > 0 {
|
||||
ch, err := model.GetChannelById(channelID, false)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if ch == nil {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": "channel not found"})
|
||||
return
|
||||
}
|
||||
if ch.Type != constant.ChannelTypeCodex {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": "channel type is not Codex"})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
flow, err := service.CreateCodexOAuthAuthorizationFlow()
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
session := sessions.Default(c)
|
||||
session.Set(codexOAuthSessionKey(channelID, "state"), flow.State)
|
||||
session.Set(codexOAuthSessionKey(channelID, "verifier"), flow.Verifier)
|
||||
session.Set(codexOAuthSessionKey(channelID, "created_at"), time.Now().Unix())
|
||||
_ = session.Save()
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": gin.H{
|
||||
"authorize_url": flow.AuthorizeURL,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func CompleteCodexOAuth(c *gin.Context) {
|
||||
completeCodexOAuthWithChannelID(c, 0)
|
||||
}
|
||||
|
||||
func CompleteCodexOAuthForChannel(c *gin.Context) {
|
||||
channelID, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
common.ApiError(c, fmt.Errorf("invalid channel id: %w", err))
|
||||
return
|
||||
}
|
||||
completeCodexOAuthWithChannelID(c, channelID)
|
||||
}
|
||||
|
||||
func completeCodexOAuthWithChannelID(c *gin.Context, channelID int) {
|
||||
req := codexOAuthCompleteRequest{}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
code, state, err := parseCodexAuthorizationInput(req.Input)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(code) == "" {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": "missing authorization code"})
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(state) == "" {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": "missing state in input"})
|
||||
return
|
||||
}
|
||||
|
||||
if channelID > 0 {
|
||||
ch, err := model.GetChannelById(channelID, false)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if ch == nil {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": "channel not found"})
|
||||
return
|
||||
}
|
||||
if ch.Type != constant.ChannelTypeCodex {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": "channel type is not Codex"})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
session := sessions.Default(c)
|
||||
expectedState, _ := session.Get(codexOAuthSessionKey(channelID, "state")).(string)
|
||||
verifier, _ := session.Get(codexOAuthSessionKey(channelID, "verifier")).(string)
|
||||
if strings.TrimSpace(expectedState) == "" || strings.TrimSpace(verifier) == "" {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": "oauth flow not started or session expired"})
|
||||
return
|
||||
}
|
||||
if state != expectedState {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": "state mismatch"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
tokenRes, err := service.ExchangeCodexAuthorizationCode(ctx, code, verifier)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
accountID, ok := service.ExtractCodexAccountIDFromJWT(tokenRes.AccessToken)
|
||||
if !ok {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": "failed to extract account_id from access_token"})
|
||||
return
|
||||
}
|
||||
email, _ := service.ExtractEmailFromJWT(tokenRes.AccessToken)
|
||||
|
||||
key := codex.OAuthKey{
|
||||
AccessToken: tokenRes.AccessToken,
|
||||
RefreshToken: tokenRes.RefreshToken,
|
||||
AccountID: accountID,
|
||||
LastRefresh: time.Now().Format(time.RFC3339),
|
||||
Expired: tokenRes.ExpiresAt.Format(time.RFC3339),
|
||||
Email: email,
|
||||
Type: "codex",
|
||||
}
|
||||
encoded, err := common.Marshal(key)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
session.Delete(codexOAuthSessionKey(channelID, "state"))
|
||||
session.Delete(codexOAuthSessionKey(channelID, "verifier"))
|
||||
session.Delete(codexOAuthSessionKey(channelID, "created_at"))
|
||||
_ = session.Save()
|
||||
|
||||
if channelID > 0 {
|
||||
if err := model.DB.Model(&model.Channel{}).Where("id = ?", channelID).Update("key", string(encoded)).Error; err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
model.InitChannelCache()
|
||||
service.ResetProxyClientCache()
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "saved",
|
||||
"data": gin.H{
|
||||
"channel_id": channelID,
|
||||
"account_id": accountID,
|
||||
"email": email,
|
||||
"expires_at": key.Expired,
|
||||
"last_refresh": key.LastRefresh,
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "generated",
|
||||
"data": gin.H{
|
||||
"key": string(encoded),
|
||||
"account_id": accountID,
|
||||
"email": email,
|
||||
"expires_at": key.Expired,
|
||||
"last_refresh": key.LastRefresh,
|
||||
},
|
||||
})
|
||||
}
|
||||
124
controller/codex_usage.go
Normal file
124
controller/codex_usage.go
Normal file
@@ -0,0 +1,124 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/relay/channel/codex"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func GetCodexChannelUsage(c *gin.Context) {
|
||||
channelId, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
common.ApiError(c, fmt.Errorf("invalid channel id: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
ch, err := model.GetChannelById(channelId, true)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if ch == nil {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": "channel not found"})
|
||||
return
|
||||
}
|
||||
if ch.Type != constant.ChannelTypeCodex {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": "channel type is not Codex"})
|
||||
return
|
||||
}
|
||||
if ch.ChannelInfo.IsMultiKey {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": "multi-key channel is not supported"})
|
||||
return
|
||||
}
|
||||
|
||||
oauthKey, err := codex.ParseOAuthKey(strings.TrimSpace(ch.Key))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
accessToken := strings.TrimSpace(oauthKey.AccessToken)
|
||||
accountID := strings.TrimSpace(oauthKey.AccountID)
|
||||
if accessToken == "" {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": "codex channel: access_token is required"})
|
||||
return
|
||||
}
|
||||
if accountID == "" {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": "codex channel: account_id is required"})
|
||||
return
|
||||
}
|
||||
|
||||
client, err := service.NewProxyHttpClient(ch.GetSetting().Proxy)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
statusCode, body, err := service.FetchCodexWhamUsage(ctx, client, ch.GetBaseURL(), accessToken, accountID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if (statusCode == http.StatusUnauthorized || statusCode == http.StatusForbidden) && strings.TrimSpace(oauthKey.RefreshToken) != "" {
|
||||
refreshCtx, refreshCancel := context.WithTimeout(c.Request.Context(), 10*time.Second)
|
||||
defer refreshCancel()
|
||||
|
||||
res, refreshErr := service.RefreshCodexOAuthToken(refreshCtx, oauthKey.RefreshToken)
|
||||
if refreshErr == nil {
|
||||
oauthKey.AccessToken = res.AccessToken
|
||||
oauthKey.RefreshToken = res.RefreshToken
|
||||
oauthKey.LastRefresh = time.Now().Format(time.RFC3339)
|
||||
oauthKey.Expired = res.ExpiresAt.Format(time.RFC3339)
|
||||
if strings.TrimSpace(oauthKey.Type) == "" {
|
||||
oauthKey.Type = "codex"
|
||||
}
|
||||
|
||||
encoded, encErr := common.Marshal(oauthKey)
|
||||
if encErr == nil {
|
||||
_ = model.DB.Model(&model.Channel{}).Where("id = ?", ch.Id).Update("key", string(encoded)).Error
|
||||
model.InitChannelCache()
|
||||
service.ResetProxyClientCache()
|
||||
}
|
||||
|
||||
ctx2, cancel2 := context.WithTimeout(c.Request.Context(), 15*time.Second)
|
||||
defer cancel2()
|
||||
statusCode, body, err = service.FetchCodexWhamUsage(ctx2, client, ch.GetBaseURL(), oauthKey.AccessToken, accountID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var payload any
|
||||
if json.Unmarshal(body, &payload) != nil {
|
||||
payload = string(body)
|
||||
}
|
||||
|
||||
ok := statusCode >= 200 && statusCode < 300
|
||||
resp := gin.H{
|
||||
"success": ok,
|
||||
"message": "",
|
||||
"upstream_status": statusCode,
|
||||
"data": payload,
|
||||
}
|
||||
if !ok {
|
||||
resp["message"] = fmt.Sprintf("upstream status: %d", statusCode)
|
||||
}
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
Reference in New Issue
Block a user