feat: codex channel (#2652)
* feat: codex channel * feat: codex channel * feat: codex oauth flow * feat: codex refresh cred * feat: codex usage * fix: codex err message detail * fix: codex setting ui * feat: codex refresh cred task * fix: import err * fix: codex store must be false * fix: chat -> responses tool call * fix: chat -> responses tool call
This commit is contained in:
243
controller/codex_oauth.go
Normal file
243
controller/codex_oauth.go
Normal file
@@ -0,0 +1,243 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/relay/channel/codex"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type codexOAuthCompleteRequest struct {
|
||||
Input string `json:"input"`
|
||||
}
|
||||
|
||||
func codexOAuthSessionKey(channelID int, field string) string {
|
||||
return fmt.Sprintf("codex_oauth_%s_%d", field, channelID)
|
||||
}
|
||||
|
||||
func parseCodexAuthorizationInput(input string) (code string, state string, err error) {
|
||||
v := strings.TrimSpace(input)
|
||||
if v == "" {
|
||||
return "", "", errors.New("empty input")
|
||||
}
|
||||
if strings.Contains(v, "#") {
|
||||
parts := strings.SplitN(v, "#", 2)
|
||||
code = strings.TrimSpace(parts[0])
|
||||
state = strings.TrimSpace(parts[1])
|
||||
return code, state, nil
|
||||
}
|
||||
if strings.Contains(v, "code=") {
|
||||
u, parseErr := url.Parse(v)
|
||||
if parseErr == nil {
|
||||
q := u.Query()
|
||||
code = strings.TrimSpace(q.Get("code"))
|
||||
state = strings.TrimSpace(q.Get("state"))
|
||||
return code, state, nil
|
||||
}
|
||||
q, parseErr := url.ParseQuery(v)
|
||||
if parseErr == nil {
|
||||
code = strings.TrimSpace(q.Get("code"))
|
||||
state = strings.TrimSpace(q.Get("state"))
|
||||
return code, state, nil
|
||||
}
|
||||
}
|
||||
|
||||
code = v
|
||||
return code, "", nil
|
||||
}
|
||||
|
||||
func StartCodexOAuth(c *gin.Context) {
|
||||
startCodexOAuthWithChannelID(c, 0)
|
||||
}
|
||||
|
||||
func StartCodexOAuthForChannel(c *gin.Context) {
|
||||
channelID, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
common.ApiError(c, fmt.Errorf("invalid channel id: %w", err))
|
||||
return
|
||||
}
|
||||
startCodexOAuthWithChannelID(c, channelID)
|
||||
}
|
||||
|
||||
func startCodexOAuthWithChannelID(c *gin.Context, channelID int) {
|
||||
if channelID > 0 {
|
||||
ch, err := model.GetChannelById(channelID, false)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if ch == nil {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": "channel not found"})
|
||||
return
|
||||
}
|
||||
if ch.Type != constant.ChannelTypeCodex {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": "channel type is not Codex"})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
flow, err := service.CreateCodexOAuthAuthorizationFlow()
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
session := sessions.Default(c)
|
||||
session.Set(codexOAuthSessionKey(channelID, "state"), flow.State)
|
||||
session.Set(codexOAuthSessionKey(channelID, "verifier"), flow.Verifier)
|
||||
session.Set(codexOAuthSessionKey(channelID, "created_at"), time.Now().Unix())
|
||||
_ = session.Save()
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": gin.H{
|
||||
"authorize_url": flow.AuthorizeURL,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func CompleteCodexOAuth(c *gin.Context) {
|
||||
completeCodexOAuthWithChannelID(c, 0)
|
||||
}
|
||||
|
||||
func CompleteCodexOAuthForChannel(c *gin.Context) {
|
||||
channelID, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
common.ApiError(c, fmt.Errorf("invalid channel id: %w", err))
|
||||
return
|
||||
}
|
||||
completeCodexOAuthWithChannelID(c, channelID)
|
||||
}
|
||||
|
||||
func completeCodexOAuthWithChannelID(c *gin.Context, channelID int) {
|
||||
req := codexOAuthCompleteRequest{}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
code, state, err := parseCodexAuthorizationInput(req.Input)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(code) == "" {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": "missing authorization code"})
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(state) == "" {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": "missing state in input"})
|
||||
return
|
||||
}
|
||||
|
||||
if channelID > 0 {
|
||||
ch, err := model.GetChannelById(channelID, false)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if ch == nil {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": "channel not found"})
|
||||
return
|
||||
}
|
||||
if ch.Type != constant.ChannelTypeCodex {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": "channel type is not Codex"})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
session := sessions.Default(c)
|
||||
expectedState, _ := session.Get(codexOAuthSessionKey(channelID, "state")).(string)
|
||||
verifier, _ := session.Get(codexOAuthSessionKey(channelID, "verifier")).(string)
|
||||
if strings.TrimSpace(expectedState) == "" || strings.TrimSpace(verifier) == "" {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": "oauth flow not started or session expired"})
|
||||
return
|
||||
}
|
||||
if state != expectedState {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": "state mismatch"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
tokenRes, err := service.ExchangeCodexAuthorizationCode(ctx, code, verifier)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
accountID, ok := service.ExtractCodexAccountIDFromJWT(tokenRes.AccessToken)
|
||||
if !ok {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": "failed to extract account_id from access_token"})
|
||||
return
|
||||
}
|
||||
email, _ := service.ExtractEmailFromJWT(tokenRes.AccessToken)
|
||||
|
||||
key := codex.OAuthKey{
|
||||
AccessToken: tokenRes.AccessToken,
|
||||
RefreshToken: tokenRes.RefreshToken,
|
||||
AccountID: accountID,
|
||||
LastRefresh: time.Now().Format(time.RFC3339),
|
||||
Expired: tokenRes.ExpiresAt.Format(time.RFC3339),
|
||||
Email: email,
|
||||
Type: "codex",
|
||||
}
|
||||
encoded, err := common.Marshal(key)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
session.Delete(codexOAuthSessionKey(channelID, "state"))
|
||||
session.Delete(codexOAuthSessionKey(channelID, "verifier"))
|
||||
session.Delete(codexOAuthSessionKey(channelID, "created_at"))
|
||||
_ = session.Save()
|
||||
|
||||
if channelID > 0 {
|
||||
if err := model.DB.Model(&model.Channel{}).Where("id = ?", channelID).Update("key", string(encoded)).Error; err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
model.InitChannelCache()
|
||||
service.ResetProxyClientCache()
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "saved",
|
||||
"data": gin.H{
|
||||
"channel_id": channelID,
|
||||
"account_id": accountID,
|
||||
"email": email,
|
||||
"expires_at": key.Expired,
|
||||
"last_refresh": key.LastRefresh,
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "generated",
|
||||
"data": gin.H{
|
||||
"key": string(encoded),
|
||||
"account_id": accountID,
|
||||
"email": email,
|
||||
"expires_at": key.Expired,
|
||||
"last_refresh": key.LastRefresh,
|
||||
},
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user