diff --git a/common/api_type.go b/common/api_type.go
index d9071236..f045866a 100644
--- a/common/api_type.go
+++ b/common/api_type.go
@@ -63,6 +63,8 @@ func ChannelType2APIType(channelType int) (int, bool) {
apiType = constant.APITypeXai
case constant.ChannelTypeCoze:
apiType = constant.APITypeCoze
+ case constant.ChannelTypeJimeng:
+ apiType = constant.APITypeJimeng
}
if apiType == -1 {
return constant.APITypeOpenAI, false
diff --git a/common/constants.go b/common/constants.go
index e4f5f047..30522411 100644
--- a/common/constants.go
+++ b/common/constants.go
@@ -193,3 +193,9 @@ const (
ChannelStatusManuallyDisabled = 2 // also don't use 0
ChannelStatusAutoDisabled = 3
)
+
+const (
+ TopUpStatusPending = "pending"
+ TopUpStatusSuccess = "success"
+ TopUpStatusExpired = "expired"
+)
diff --git a/common/hash.go b/common/hash.go
new file mode 100644
index 00000000..50191938
--- /dev/null
+++ b/common/hash.go
@@ -0,0 +1,34 @@
+package common
+
+import (
+ "crypto/hmac"
+ "crypto/sha1"
+ "crypto/sha256"
+ "encoding/hex"
+)
+
+func Sha256Raw(data []byte) []byte {
+ h := sha256.New()
+ h.Write(data)
+ return h.Sum(nil)
+}
+
+func Sha1Raw(data []byte) []byte {
+ h := sha1.New()
+ h.Write(data)
+ return h.Sum(nil)
+}
+
+func Sha1(data []byte) string {
+ return hex.EncodeToString(Sha1Raw(data))
+}
+
+func HmacSha256Raw(message, key []byte) []byte {
+ h := hmac.New(sha256.New, key)
+ h.Write(message)
+ return h.Sum(nil)
+}
+
+func HmacSha256(message, key string) string {
+ return hex.EncodeToString(HmacSha256Raw([]byte(message), []byte(key)))
+}
diff --git a/common/logger.go b/common/logger.go
index 86d15fa4..0f6dc3c3 100644
--- a/common/logger.go
+++ b/common/logger.go
@@ -75,6 +75,9 @@ func logHelper(ctx context.Context, level string, msg string) {
writer = gin.DefaultWriter
}
id := ctx.Value(RequestIdKey)
+ if id == nil {
+ id = "SYSTEM"
+ }
now := time.Now()
_, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg)
logCount++ // we don't need accurate count, so no lock here
diff --git a/constant/api_type.go b/constant/api_type.go
index ae867870..6ba5f257 100644
--- a/constant/api_type.go
+++ b/constant/api_type.go
@@ -30,5 +30,6 @@ const (
APITypeXinference
APITypeXai
APITypeCoze
+ APITypeJimeng
APITypeDummy // this one is only for count, do not add any channel after this
)
diff --git a/controller/misc.go b/controller/misc.go
index e89909d1..a3ed9be9 100644
--- a/controller/misc.go
+++ b/controller/misc.go
@@ -57,7 +57,9 @@ func GetStatus(c *gin.Context) {
"wechat_login": common.WeChatAuthEnabled,
"server_address": setting.ServerAddress,
"price": setting.Price,
+ "stripe_unit_price": setting.StripeUnitPrice,
"min_topup": setting.MinTopUp,
+ "stripe_min_topup": setting.StripeMinTopUp,
"turnstile_check": common.TurnstileCheckEnabled,
"turnstile_site_key": common.TurnstileSiteKey,
"top_up_link": common.TopUpLink,
@@ -71,6 +73,7 @@ func GetStatus(c *gin.Context) {
"data_export_default_time": common.DataExportDefaultTime,
"default_collapse_sidebar": common.DefaultCollapseSidebar,
"enable_online_topup": setting.PayAddress != "" && setting.EpayId != "" && setting.EpayKey != "",
+ "enable_stripe_topup": setting.StripeApiSecret != "" && setting.StripeWebhookSecret != "" && setting.StripePriceId != "",
"mj_notify_enabled": setting.MjNotifyEnabled,
"chats": setting.Chats,
"demo_site_enabled": operation_setting.DemoSiteEnabled,
diff --git a/controller/swag_video.go b/controller/swag_video.go
new file mode 100644
index 00000000..185fd515
--- /dev/null
+++ b/controller/swag_video.go
@@ -0,0 +1,116 @@
+package controller
+
+import (
+ "github.com/gin-gonic/gin"
+)
+
+// VideoGenerations
+// @Summary 生成视频
+// @Description 调用视频生成接口生成视频
+// @Description 支持多种视频生成服务:
+// @Description - 可灵AI (Kling): https://app.klingai.com/cn/dev/document-api/apiReference/commonInfo
+// @Description - 即梦 (Jimeng): https://www.volcengine.com/docs/85621/1538636
+// @Tags Video
+// @Accept json
+// @Produce json
+// @Param Authorization header string true "用户认证令牌 (Aeess-Token: sk-xxxx)"
+// @Param request body dto.VideoRequest true "视频生成请求参数"
+// @Failure 400 {object} dto.OpenAIError "请求参数错误"
+// @Failure 401 {object} dto.OpenAIError "未授权"
+// @Failure 403 {object} dto.OpenAIError "无权限"
+// @Failure 500 {object} dto.OpenAIError "服务器内部错误"
+// @Router /v1/video/generations [post]
+func VideoGenerations(c *gin.Context) {
+}
+
+// VideoGenerationsTaskId
+// @Summary 查询视频
+// @Description 根据任务ID查询视频生成任务的状态和结果
+// @Tags Video
+// @Accept json
+// @Produce json
+// @Security BearerAuth
+// @Param task_id path string true "Task ID"
+// @Success 200 {object} dto.VideoTaskResponse "任务状态和结果"
+// @Failure 400 {object} dto.OpenAIError "请求参数错误"
+// @Failure 401 {object} dto.OpenAIError "未授权"
+// @Failure 403 {object} dto.OpenAIError "无权限"
+// @Failure 500 {object} dto.OpenAIError "服务器内部错误"
+// @Router /v1/video/generations/{task_id} [get]
+func VideoGenerationsTaskId(c *gin.Context) {
+}
+
+// KlingText2VideoGenerations
+// @Summary 可灵文生视频
+// @Description 调用可灵AI文生视频接口,生成视频内容
+// @Tags Video
+// @Accept json
+// @Produce json
+// @Param Authorization header string true "用户认证令牌 (Aeess-Token: sk-xxxx)"
+// @Param request body KlingText2VideoRequest true "视频生成请求参数"
+// @Success 200 {object} dto.VideoTaskResponse "任务状态和结果"
+// @Failure 400 {object} dto.OpenAIError "请求参数错误"
+// @Failure 401 {object} dto.OpenAIError "未授权"
+// @Failure 403 {object} dto.OpenAIError "无权限"
+// @Failure 500 {object} dto.OpenAIError "服务器内部错误"
+// @Router /kling/v1/videos/text2video [post]
+func KlingText2VideoGenerations(c *gin.Context) {
+}
+
+type KlingText2VideoRequest struct {
+ ModelName string `json:"model_name,omitempty" example:"kling-v1"`
+ Prompt string `json:"prompt" binding:"required" example:"A cat playing piano in the garden"`
+ NegativePrompt string `json:"negative_prompt,omitempty" example:"blurry, low quality"`
+ CfgScale float64 `json:"cfg_scale,omitempty" example:"0.7"`
+ Mode string `json:"mode,omitempty" example:"std"`
+ CameraControl *KlingCameraControl `json:"camera_control,omitempty"`
+ AspectRatio string `json:"aspect_ratio,omitempty" example:"16:9"`
+ Duration string `json:"duration,omitempty" example:"5"`
+ CallbackURL string `json:"callback_url,omitempty" example:"https://your.domain/callback"`
+ ExternalTaskId string `json:"external_task_id,omitempty" example:"custom-task-001"`
+}
+
+type KlingCameraControl struct {
+ Type string `json:"type,omitempty" example:"simple"`
+ Config *KlingCameraConfig `json:"config,omitempty"`
+}
+
+type KlingCameraConfig struct {
+ Horizontal float64 `json:"horizontal,omitempty" example:"2.5"`
+ Vertical float64 `json:"vertical,omitempty" example:"0"`
+ Pan float64 `json:"pan,omitempty" example:"0"`
+ Tilt float64 `json:"tilt,omitempty" example:"0"`
+ Roll float64 `json:"roll,omitempty" example:"0"`
+ Zoom float64 `json:"zoom,omitempty" example:"0"`
+}
+
+// KlingImage2VideoGenerations
+// @Summary 可灵官方-图生视频
+// @Description 调用可灵AI图生视频接口,生成视频内容
+// @Tags Video
+// @Accept json
+// @Produce json
+// @Param Authorization header string true "用户认证令牌 (Aeess-Token: sk-xxxx)"
+// @Param request body KlingImage2VideoRequest true "图生视频请求参数"
+// @Success 200 {object} dto.VideoTaskResponse "任务状态和结果"
+// @Failure 400 {object} dto.OpenAIError "请求参数错误"
+// @Failure 401 {object} dto.OpenAIError "未授权"
+// @Failure 403 {object} dto.OpenAIError "无权限"
+// @Failure 500 {object} dto.OpenAIError "服务器内部错误"
+// @Router /kling/v1/videos/image2video [post]
+func KlingImage2VideoGenerations(c *gin.Context) {
+}
+
+type KlingImage2VideoRequest struct {
+ ModelName string `json:"model_name,omitempty" example:"kling-v2-master"`
+ Image string `json:"image" binding:"required" example:"https://h2.inkwai.com/bs2/upload-ylab-stunt/se/ai_portal_queue_mmu_image_upscale_aiweb/3214b798-e1b4-4b00-b7af-72b5b0417420_raw_image_0.jpg"`
+ Prompt string `json:"prompt,omitempty" example:"A cat playing piano in the garden"`
+ NegativePrompt string `json:"negative_prompt,omitempty" example:"blurry, low quality"`
+ CfgScale float64 `json:"cfg_scale,omitempty" example:"0.7"`
+ Mode string `json:"mode,omitempty" example:"std"`
+ CameraControl *KlingCameraControl `json:"camera_control,omitempty"`
+ AspectRatio string `json:"aspect_ratio,omitempty" example:"16:9"`
+ Duration string `json:"duration,omitempty" example:"5"`
+ CallbackURL string `json:"callback_url,omitempty" example:"https://your.domain/callback"`
+ ExternalTaskId string `json:"external_task_id,omitempty" example:"custom-task-002"`
+}
diff --git a/controller/topup_stripe.go b/controller/topup_stripe.go
new file mode 100644
index 00000000..eb320809
--- /dev/null
+++ b/controller/topup_stripe.go
@@ -0,0 +1,275 @@
+package controller
+
+import (
+ "fmt"
+ "io"
+ "log"
+ "net/http"
+ "one-api/common"
+ "one-api/model"
+ "one-api/setting"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/gin-gonic/gin"
+ "github.com/stripe/stripe-go/v81"
+ "github.com/stripe/stripe-go/v81/checkout/session"
+ "github.com/stripe/stripe-go/v81/webhook"
+ "github.com/thanhpk/randstr"
+)
+
+const (
+ PaymentMethodStripe = "stripe"
+)
+
+var stripeAdaptor = &StripeAdaptor{}
+
+type StripePayRequest struct {
+ Amount int64 `json:"amount"`
+ PaymentMethod string `json:"payment_method"`
+}
+
+type StripeAdaptor struct {
+}
+
+func (*StripeAdaptor) RequestAmount(c *gin.Context, req *StripePayRequest) {
+ if req.Amount < getStripeMinTopup() {
+ c.JSON(200, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", getStripeMinTopup())})
+ return
+ }
+ id := c.GetInt("id")
+ group, err := model.GetUserGroup(id, true)
+ if err != nil {
+ c.JSON(200, gin.H{"message": "error", "data": "获取用户分组失败"})
+ return
+ }
+ payMoney := getStripePayMoney(float64(req.Amount), group)
+ if payMoney <= 0.01 {
+ c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"})
+ return
+ }
+ c.JSON(200, gin.H{"message": "success", "data": strconv.FormatFloat(payMoney, 'f', 2, 64)})
+}
+
+func (*StripeAdaptor) RequestPay(c *gin.Context, req *StripePayRequest) {
+ if req.PaymentMethod != PaymentMethodStripe {
+ c.JSON(200, gin.H{"message": "error", "data": "不支持的支付渠道"})
+ return
+ }
+ if req.Amount < getStripeMinTopup() {
+ c.JSON(200, gin.H{"message": fmt.Sprintf("充值数量不能小于 %d", getStripeMinTopup()), "data": 10})
+ return
+ }
+ if req.Amount > 10000 {
+ c.JSON(200, gin.H{"message": "充值数量不能大于 10000", "data": 10})
+ return
+ }
+
+ id := c.GetInt("id")
+ user, _ := model.GetUserById(id, false)
+ chargedMoney := GetChargedAmount(float64(req.Amount), *user)
+
+ reference := fmt.Sprintf("new-api-ref-%d-%d-%s", user.Id, time.Now().UnixMilli(), randstr.String(4))
+ referenceId := "ref_" + common.Sha1([]byte(reference))
+
+ payLink, err := genStripeLink(referenceId, user.StripeCustomer, user.Email, req.Amount)
+ if err != nil {
+ log.Println("获取Stripe Checkout支付链接失败", err)
+ c.JSON(200, gin.H{"message": "error", "data": "拉起支付失败"})
+ return
+ }
+
+ topUp := &model.TopUp{
+ UserId: id,
+ Amount: req.Amount,
+ Money: chargedMoney,
+ TradeNo: referenceId,
+ CreateTime: time.Now().Unix(),
+ Status: common.TopUpStatusPending,
+ }
+ err = topUp.Insert()
+ if err != nil {
+ c.JSON(200, gin.H{"message": "error", "data": "创建订单失败"})
+ return
+ }
+ c.JSON(200, gin.H{
+ "message": "success",
+ "data": gin.H{
+ "pay_link": payLink,
+ },
+ })
+}
+
+func RequestStripeAmount(c *gin.Context) {
+ var req StripePayRequest
+ err := c.ShouldBindJSON(&req)
+ if err != nil {
+ c.JSON(200, gin.H{"message": "error", "data": "参数错误"})
+ return
+ }
+ stripeAdaptor.RequestAmount(c, &req)
+}
+
+func RequestStripePay(c *gin.Context) {
+ var req StripePayRequest
+ err := c.ShouldBindJSON(&req)
+ if err != nil {
+ c.JSON(200, gin.H{"message": "error", "data": "参数错误"})
+ return
+ }
+ stripeAdaptor.RequestPay(c, &req)
+}
+
+func StripeWebhook(c *gin.Context) {
+ payload, err := io.ReadAll(c.Request.Body)
+ if err != nil {
+ log.Printf("解析Stripe Webhook参数失败: %v\n", err)
+ c.AbortWithStatus(http.StatusServiceUnavailable)
+ return
+ }
+
+ signature := c.GetHeader("Stripe-Signature")
+ endpointSecret := setting.StripeWebhookSecret
+ event, err := webhook.ConstructEventWithOptions(payload, signature, endpointSecret, webhook.ConstructEventOptions{
+ IgnoreAPIVersionMismatch: true,
+ })
+
+ if err != nil {
+ log.Printf("Stripe Webhook验签失败: %v\n", err)
+ c.AbortWithStatus(http.StatusBadRequest)
+ return
+ }
+
+ switch event.Type {
+ case stripe.EventTypeCheckoutSessionCompleted:
+ sessionCompleted(event)
+ case stripe.EventTypeCheckoutSessionExpired:
+ sessionExpired(event)
+ default:
+ log.Printf("不支持的Stripe Webhook事件类型: %s\n", event.Type)
+ }
+
+ c.Status(http.StatusOK)
+}
+
+func sessionCompleted(event stripe.Event) {
+ customerId := event.GetObjectValue("customer")
+ referenceId := event.GetObjectValue("client_reference_id")
+ status := event.GetObjectValue("status")
+ if "complete" != status {
+ log.Println("错误的Stripe Checkout完成状态:", status, ",", referenceId)
+ return
+ }
+
+ err := model.Recharge(referenceId, customerId)
+ if err != nil {
+ log.Println(err.Error(), referenceId)
+ return
+ }
+
+ total, _ := strconv.ParseFloat(event.GetObjectValue("amount_total"), 64)
+ currency := strings.ToUpper(event.GetObjectValue("currency"))
+ log.Printf("收到款项:%s, %.2f(%s)", referenceId, total/100, currency)
+}
+
+func sessionExpired(event stripe.Event) {
+ referenceId := event.GetObjectValue("client_reference_id")
+ status := event.GetObjectValue("status")
+ if "expired" != status {
+ log.Println("错误的Stripe Checkout过期状态:", status, ",", referenceId)
+ return
+ }
+
+ if len(referenceId) == 0 {
+ log.Println("未提供支付单号")
+ return
+ }
+
+ topUp := model.GetTopUpByTradeNo(referenceId)
+ if topUp == nil {
+ log.Println("充值订单不存在", referenceId)
+ return
+ }
+
+ if topUp.Status != common.TopUpStatusPending {
+ log.Println("充值订单状态错误", referenceId)
+ }
+
+ topUp.Status = common.TopUpStatusExpired
+ err := topUp.Update()
+ if err != nil {
+ log.Println("过期充值订单失败", referenceId, ", err:", err.Error())
+ return
+ }
+
+ log.Println("充值订单已过期", referenceId)
+}
+
+func genStripeLink(referenceId string, customerId string, email string, amount int64) (string, error) {
+ if !strings.HasPrefix(setting.StripeApiSecret, "sk_") && !strings.HasPrefix(setting.StripeApiSecret, "rk_") {
+ return "", fmt.Errorf("无效的Stripe API密钥")
+ }
+
+ stripe.Key = setting.StripeApiSecret
+
+ params := &stripe.CheckoutSessionParams{
+ ClientReferenceID: stripe.String(referenceId),
+ SuccessURL: stripe.String(setting.ServerAddress + "/log"),
+ CancelURL: stripe.String(setting.ServerAddress + "/topup"),
+ LineItems: []*stripe.CheckoutSessionLineItemParams{
+ {
+ Price: stripe.String(setting.StripePriceId),
+ Quantity: stripe.Int64(amount),
+ },
+ },
+ Mode: stripe.String(string(stripe.CheckoutSessionModePayment)),
+ }
+
+ if "" == customerId {
+ if "" != email {
+ params.CustomerEmail = stripe.String(email)
+ }
+
+ params.CustomerCreation = stripe.String(string(stripe.CheckoutSessionCustomerCreationAlways))
+ } else {
+ params.Customer = stripe.String(customerId)
+ }
+
+ result, err := session.New(params)
+ if err != nil {
+ return "", err
+ }
+
+ return result.URL, nil
+}
+
+func GetChargedAmount(count float64, user model.User) float64 {
+ topUpGroupRatio := common.GetTopupGroupRatio(user.Group)
+ if topUpGroupRatio == 0 {
+ topUpGroupRatio = 1
+ }
+
+ return count * topUpGroupRatio
+}
+
+func getStripePayMoney(amount float64, group string) float64 {
+ if !common.DisplayInCurrencyEnabled {
+ amount = amount / common.QuotaPerUnit
+ }
+ // Using float64 for monetary calculations is acceptable here due to the small amounts involved
+ topupGroupRatio := common.GetTopupGroupRatio(group)
+ if topupGroupRatio == 0 {
+ topupGroupRatio = 1
+ }
+ payMoney := amount * setting.StripeUnitPrice * topupGroupRatio
+ return payMoney
+}
+
+func getStripeMinTopup() int64 {
+ minTopup := setting.StripeMinTopUp
+ if !common.DisplayInCurrencyEnabled {
+ minTopup = minTopup * int(common.QuotaPerUnit)
+ }
+ return int64(minTopup)
+}
diff --git a/go.mod b/go.mod
index 9479ba55..94873c88 100644
--- a/go.mod
+++ b/go.mod
@@ -27,10 +27,13 @@ require (
github.com/samber/lo v1.39.0
github.com/shirou/gopsutil v3.21.11+incompatible
github.com/shopspring/decimal v1.4.0
+ github.com/stripe/stripe-go/v81 v81.4.0
+ github.com/thanhpk/randstr v1.0.6
github.com/tiktoken-go/tokenizer v0.6.2
golang.org/x/crypto v0.35.0
golang.org/x/image v0.23.0
golang.org/x/net v0.35.0
+ golang.org/x/sync v0.11.0
gorm.io/driver/mysql v1.4.3
gorm.io/driver/postgres v1.5.2
gorm.io/gorm v1.25.2
@@ -84,7 +87,6 @@ require (
github.com/yusufpapurcu/wmi v1.2.3 // indirect
golang.org/x/arch v0.12.0 // indirect
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 // indirect
- golang.org/x/sync v0.11.0 // indirect
golang.org/x/sys v0.30.0 // indirect
golang.org/x/text v0.22.0 // indirect
google.golang.org/protobuf v1.34.2 // indirect
diff --git a/go.sum b/go.sum
index 71dd83c2..74eecd4c 100644
--- a/go.sum
+++ b/go.sum
@@ -195,6 +195,10 @@ github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
+github.com/stripe/stripe-go/v81 v81.4.0 h1:AuD9XzdAvl193qUCSaLocf8H+nRopOouXhxqJUzCLbw=
+github.com/stripe/stripe-go/v81 v81.4.0/go.mod h1:C/F4jlmnGNacvYtBp/LUHCvVUJEZffFQCobkzwY1WOo=
+github.com/thanhpk/randstr v1.0.6 h1:psAOktJFD4vV9NEVb3qkhRSMvYh4ORRaj1+w/hn4B+o=
+github.com/thanhpk/randstr v1.0.6/go.mod h1:M/H2P1eNLZzlDwAzpkkkUvoyNNMbzRGhESZuEQk3r0U=
github.com/tiktoken-go/tokenizer v0.6.2 h1:t0GN2DvcUZSFWT/62YOgoqb10y7gSXBGs0A+4VCQK+g=
github.com/tiktoken-go/tokenizer v0.6.2/go.mod h1:6UCYI/DtOallbmL7sSy30p6YQv60qNyU/4aVigPOx6w=
github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU=
@@ -224,6 +228,7 @@ golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0/go.mod h1:/lliqkxwWAhPjf5oSO
golang.org/x/image v0.23.0 h1:HseQ7c2OpPKTPVzNjG5fwJsOTCiiwS4QdsYi5XU6H68=
golang.org/x/image v0.23.0/go.mod h1:wJJBTdLfCCf3tiHa1fNxpZmUI4mmoZvwMCPP0ddoNKY=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
+golang.org/x/net v0.0.0-20210520170846-37e1c6afe023/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8=
golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -232,6 +237,7 @@ golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
diff --git a/middleware/kling_adapter.go b/middleware/kling_adapter.go
index 8e2a3551..3d4943d2 100644
--- a/middleware/kling_adapter.go
+++ b/middleware/kling_adapter.go
@@ -18,7 +18,7 @@ func KlingRequestConvert() func(c *gin.Context) {
return
}
- model, _ := originalReq["model"].(string)
+ model, _ := originalReq["model_name"].(string)
prompt, _ := originalReq["prompt"].(string)
unifiedReq := map[string]interface{}{
@@ -36,7 +36,7 @@ func KlingRequestConvert() func(c *gin.Context) {
// Rewrite request body and path
c.Request.Body = io.NopCloser(bytes.NewBuffer(jsonData))
c.Request.URL.Path = "/v1/video/generations"
- if image := originalReq["image"]; image == "" {
+ if image, ok := originalReq["image"]; !ok || image == "" {
c.Set("action", constant.TaskActionTextGenerate)
}
diff --git a/model/log.go b/model/log.go
index 45923075..2070cd6f 100644
--- a/model/log.go
+++ b/model/log.go
@@ -27,7 +27,7 @@ type Log struct {
PromptTokens int `json:"prompt_tokens" gorm:"default:0"`
CompletionTokens int `json:"completion_tokens" gorm:"default:0"`
UseTime int `json:"use_time" gorm:"default:0"`
- IsStream bool `json:"is_stream" gorm:"default:false"`
+ IsStream bool `json:"is_stream"`
ChannelId int `json:"channel" gorm:"index"`
ChannelName string `json:"channel_name" gorm:"->"`
TokenId int `json:"token_id" gorm:"default:0;index"`
diff --git a/model/option.go b/model/option.go
index 9e58a81d..05b99b41 100644
--- a/model/option.go
+++ b/model/option.go
@@ -76,6 +76,11 @@ func InitOptionMap() {
common.OptionMap["Price"] = strconv.FormatFloat(setting.Price, 'f', -1, 64)
common.OptionMap["USDExchangeRate"] = strconv.FormatFloat(setting.USDExchangeRate, 'f', -1, 64)
common.OptionMap["MinTopUp"] = strconv.Itoa(setting.MinTopUp)
+ common.OptionMap["StripeMinTopUp"] = strconv.Itoa(setting.StripeMinTopUp)
+ common.OptionMap["StripeApiSecret"] = setting.StripeApiSecret
+ common.OptionMap["StripeWebhookSecret"] = setting.StripeWebhookSecret
+ common.OptionMap["StripePriceId"] = setting.StripePriceId
+ common.OptionMap["StripeUnitPrice"] = strconv.FormatFloat(setting.StripeUnitPrice, 'f', -1, 64)
common.OptionMap["TopupGroupRatio"] = common.TopupGroupRatio2JSONString()
common.OptionMap["Chats"] = setting.Chats2JsonString()
common.OptionMap["AutoGroups"] = setting.AutoGroups2JsonString()
@@ -311,6 +316,16 @@ func updateOptionMap(key string, value string) (err error) {
setting.USDExchangeRate, _ = strconv.ParseFloat(value, 64)
case "MinTopUp":
setting.MinTopUp, _ = strconv.Atoi(value)
+ case "StripeApiSecret":
+ setting.StripeApiSecret = value
+ case "StripeWebhookSecret":
+ setting.StripeWebhookSecret = value
+ case "StripePriceId":
+ setting.StripePriceId = value
+ case "StripeUnitPrice":
+ setting.StripeUnitPrice, _ = strconv.ParseFloat(value, 64)
+ case "StripeMinTopUp":
+ setting.StripeMinTopUp, _ = strconv.Atoi(value)
case "TopupGroupRatio":
err = common.UpdateTopupGroupRatioByJSONString(value)
case "GitHubClientId":
diff --git a/model/token.go b/model/token.go
index 7e68f185..e85a445e 100644
--- a/model/token.go
+++ b/model/token.go
@@ -20,8 +20,8 @@ type Token struct {
AccessedTime int64 `json:"accessed_time" gorm:"bigint"`
ExpiredTime int64 `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired
RemainQuota int `json:"remain_quota" gorm:"default:0"`
- UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"`
- ModelLimitsEnabled bool `json:"model_limits_enabled" gorm:"default:false"`
+ UnlimitedQuota bool `json:"unlimited_quota"`
+ ModelLimitsEnabled bool `json:"model_limits_enabled"`
ModelLimits string `json:"model_limits" gorm:"type:varchar(1024);default:''"`
AllowIps *string `json:"allow_ips" gorm:"default:''"`
UsedQuota int `json:"used_quota" gorm:"default:0"` // used quota
diff --git a/model/topup.go b/model/topup.go
index 507b8518..39d96721 100644
--- a/model/topup.go
+++ b/model/topup.go
@@ -1,13 +1,21 @@
package model
+import (
+ "errors"
+ "fmt"
+ "gorm.io/gorm"
+ "one-api/common"
+)
+
type TopUp struct {
- Id int `json:"id"`
- UserId int `json:"user_id" gorm:"index"`
- Amount int64 `json:"amount"`
- Money float64 `json:"money"`
- TradeNo string `json:"trade_no"`
- CreateTime int64 `json:"create_time"`
- Status string `json:"status"`
+ Id int `json:"id"`
+ UserId int `json:"user_id" gorm:"index"`
+ Amount int64 `json:"amount"`
+ Money float64 `json:"money"`
+ TradeNo string `json:"trade_no" gorm:"unique"`
+ CreateTime int64 `json:"create_time"`
+ CompleteTime int64 `json:"complete_time"`
+ Status string `json:"status"`
}
func (topUp *TopUp) Insert() error {
@@ -41,3 +49,51 @@ func GetTopUpByTradeNo(tradeNo string) *TopUp {
}
return topUp
}
+
+func Recharge(referenceId string, customerId string) (err error) {
+ if referenceId == "" {
+ return errors.New("未提供支付单号")
+ }
+
+ var quota float64
+ topUp := &TopUp{}
+
+ refCol := "`trade_no`"
+ if common.UsingPostgreSQL {
+ refCol = `"trade_no"`
+ }
+
+ err = DB.Transaction(func(tx *gorm.DB) error {
+ err := tx.Set("gorm:query_option", "FOR UPDATE").Where(refCol+" = ?", referenceId).First(topUp).Error
+ if err != nil {
+ return errors.New("充值订单不存在")
+ }
+
+ if topUp.Status != common.TopUpStatusPending {
+ return errors.New("充值订单状态错误")
+ }
+
+ topUp.CompleteTime = common.GetTimestamp()
+ topUp.Status = common.TopUpStatusSuccess
+ err = tx.Save(topUp).Error
+ if err != nil {
+ return err
+ }
+
+ quota = topUp.Money * common.QuotaPerUnit
+ err = tx.Model(&User{}).Where("id = ?", topUp.UserId).Updates(map[string]interface{}{"stripe_customer": customerId, "quota": gorm.Expr("quota + ?", quota)}).Error
+ if err != nil {
+ return err
+ }
+
+ return nil
+ })
+
+ if err != nil {
+ return errors.New("充值失败," + err.Error())
+ }
+
+ RecordLog(topUp.UserId, LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v,支付金额:%d", common.FormatQuota(int(quota)), topUp.Amount))
+
+ return nil
+}
diff --git a/model/user.go b/model/user.go
index 6bb5a867..6021f495 100644
--- a/model/user.go
+++ b/model/user.go
@@ -43,6 +43,7 @@ type User struct {
LinuxDOId string `json:"linux_do_id" gorm:"column:linux_do_id;index"`
Setting string `json:"setting" gorm:"type:text;column:setting"`
Remark string `json:"remark,omitempty" gorm:"type:varchar(255)" validate:"max=255"`
+ StripeCustomer string `json:"stripe_customer" gorm:"type:varchar(64);column:stripe_customer;index"`
}
func (user *User) ToBaseUser() *UserBase {
diff --git a/relay/channel/api_request.go b/relay/channel/api_request.go
index 97887266..ff7c63fa 100644
--- a/relay/channel/api_request.go
+++ b/relay/channel/api_request.go
@@ -203,6 +203,9 @@ func sendPingData(c *gin.Context, mutex *sync.Mutex) error {
}
}
+func DoRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http.Response, error) {
+ return doRequest(c, req, info)
+}
func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http.Response, error) {
var client *http.Client
var err error
diff --git a/relay/channel/jimeng/adaptor.go b/relay/channel/jimeng/adaptor.go
new file mode 100644
index 00000000..0b743879
--- /dev/null
+++ b/relay/channel/jimeng/adaptor.go
@@ -0,0 +1,136 @@
+package jimeng
+
+import (
+ "encoding/json"
+ "errors"
+ "fmt"
+ "github.com/gin-gonic/gin"
+ "io"
+ "net/http"
+ "one-api/dto"
+ "one-api/relay/channel"
+ "one-api/relay/channel/openai"
+ relaycommon "one-api/relay/common"
+ relayconstant "one-api/relay/constant"
+ "one-api/types"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
+}
+
+func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
+ return fmt.Sprintf("%s/?Action=CVProcess&Version=2022-08-31", info.BaseUrl), nil
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info *relaycommon.RelayInfo) error {
+ return errors.New("not implemented")
+}
+
+func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
+ if request == nil {
+ return nil, errors.New("request is nil")
+ }
+ return request, nil
+}
+
+type LogoInfo struct {
+ AddLogo bool `json:"add_logo,omitempty"`
+ Position int `json:"position,omitempty"`
+ Language int `json:"language,omitempty"`
+ Opacity float64 `json:"opacity,omitempty"`
+ LogoTextContent string `json:"logo_text_content,omitempty"`
+}
+
+type imageRequestPayload struct {
+ ReqKey string `json:"req_key"` // Service identifier, fixed value: jimeng_high_aes_general_v21_L
+ Prompt string `json:"prompt"` // Prompt for image generation, supports both Chinese and English
+ Seed int64 `json:"seed,omitempty"` // Random seed, default -1 (random)
+ Width int `json:"width,omitempty"` // Image width, default 512, range [256, 768]
+ Height int `json:"height,omitempty"` // Image height, default 512, range [256, 768]
+ UsePreLLM bool `json:"use_pre_llm,omitempty"` // Enable text expansion, default true
+ UseSR bool `json:"use_sr,omitempty"` // Enable super resolution, default true
+ ReturnURL bool `json:"return_url,omitempty"` // Whether to return image URL (valid for 24 hours)
+ LogoInfo LogoInfo `json:"logo_info,omitempty"` // Watermark information
+ ImageUrls []string `json:"image_urls,omitempty"` // Image URLs for input
+ BinaryData []string `json:"binary_data_base64,omitempty"` // Base64 encoded binary data
+}
+
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
+ payload := imageRequestPayload{
+ ReqKey: request.Model,
+ Prompt: request.Prompt,
+ }
+ if request.ResponseFormat == "" || request.ResponseFormat == "url" {
+ payload.ReturnURL = true // Default to returning image URLs
+ }
+
+ if len(request.ExtraFields) > 0 {
+ if err := json.Unmarshal(request.ExtraFields, &payload); err != nil {
+ return nil, fmt.Errorf("failed to unmarshal extra fields: %w", err)
+ }
+ }
+
+ return payload, nil
+}
+
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
+ fullRequestURL, err := a.GetRequestURL(info)
+ if err != nil {
+ return nil, fmt.Errorf("get request url failed: %w", err)
+ }
+ req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
+ if err != nil {
+ return nil, fmt.Errorf("new request failed: %w", err)
+ }
+ err = Sign(c, req, info.ApiKey)
+ if err != nil {
+ return nil, fmt.Errorf("setup request header failed: %w", err)
+ }
+ resp, err := channel.DoRequest(c, req, info)
+ if err != nil {
+ return nil, fmt.Errorf("do request failed: %w", err)
+ }
+ return resp, nil
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
+ if info.RelayMode == relayconstant.RelayModeImagesGenerations {
+ usage, err = jimengImageHandler(c, resp, info)
+ } else if info.IsStream {
+ usage, err = openai.OaiStreamHandler(c, info, resp)
+ } else {
+ usage, err = openai.OpenaiHandler(c, info, resp)
+ }
+ return
+}
+
+func (a *Adaptor) GetModelList() []string {
+ return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+ return ChannelName
+}
diff --git a/relay/channel/jimeng/constants.go b/relay/channel/jimeng/constants.go
new file mode 100644
index 00000000..0d1764e5
--- /dev/null
+++ b/relay/channel/jimeng/constants.go
@@ -0,0 +1,9 @@
+package jimeng
+
+const (
+ ChannelName = "jimeng"
+)
+
+var ModelList = []string{
+ "jimeng_high_aes_general_v21_L",
+}
diff --git a/relay/channel/jimeng/image.go b/relay/channel/jimeng/image.go
new file mode 100644
index 00000000..3c6a1d99
--- /dev/null
+++ b/relay/channel/jimeng/image.go
@@ -0,0 +1,89 @@
+package jimeng
+
+import (
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "one-api/common"
+ "one-api/dto"
+ relaycommon "one-api/relay/common"
+ "one-api/types"
+
+ "github.com/gin-gonic/gin"
+)
+
+type ImageResponse struct {
+ Code int `json:"code"`
+ Message string `json:"message"`
+ Data struct {
+ BinaryDataBase64 []string `json:"binary_data_base64"`
+ ImageUrls []string `json:"image_urls"`
+ RephraseResult string `json:"rephraser_result"`
+ RequestID string `json:"request_id"`
+ // Other fields are omitted for brevity
+ } `json:"data"`
+ RequestID string `json:"request_id"`
+ Status int `json:"status"`
+ TimeElapsed string `json:"time_elapsed"`
+}
+
+func responseJimeng2OpenAIImage(_ *gin.Context, response *ImageResponse, info *relaycommon.RelayInfo) *dto.ImageResponse {
+ imageResponse := dto.ImageResponse{
+ Created: info.StartTime.Unix(),
+ }
+
+ for _, base64Data := range response.Data.BinaryDataBase64 {
+ imageResponse.Data = append(imageResponse.Data, dto.ImageData{
+ B64Json: base64Data,
+ })
+ }
+ for _, imageUrl := range response.Data.ImageUrls {
+ imageResponse.Data = append(imageResponse.Data, dto.ImageData{
+ Url: imageUrl,
+ })
+ }
+
+ return &imageResponse
+}
+
+// jimengImageHandler handles the Jimeng image generation response
+func jimengImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *types.NewAPIError) {
+ var jimengResponse ImageResponse
+ responseBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
+ }
+ common.CloseResponseBodyGracefully(resp)
+
+ err = json.Unmarshal(responseBody, &jimengResponse)
+ if err != nil {
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
+ }
+
+ // Check if the response indicates an error
+ if jimengResponse.Code != 10000 {
+ return nil, types.WithOpenAIError(types.OpenAIError{
+ Message: jimengResponse.Message,
+ Type: "jimeng_error",
+ Param: "",
+ Code: fmt.Sprintf("%d", jimengResponse.Code),
+ }, resp.StatusCode)
+ }
+
+ // Convert Jimeng response to OpenAI format
+ fullTextResponse := responseJimeng2OpenAIImage(c, &jimengResponse, info)
+ jsonResponse, err := json.Marshal(fullTextResponse)
+ if err != nil {
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
+ }
+
+ c.Writer.Header().Set("Content-Type", "application/json")
+ c.Writer.WriteHeader(resp.StatusCode)
+ _, err = c.Writer.Write(jsonResponse)
+ if err != nil {
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
+ }
+
+ return &dto.Usage{}, nil
+}
diff --git a/relay/channel/jimeng/sign.go b/relay/channel/jimeng/sign.go
new file mode 100644
index 00000000..c9db6630
--- /dev/null
+++ b/relay/channel/jimeng/sign.go
@@ -0,0 +1,176 @@
+package jimeng
+
+import (
+ "bytes"
+ "crypto/hmac"
+ "crypto/sha256"
+ "encoding/hex"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "github.com/gin-gonic/gin"
+ "io"
+ "net/http"
+ "net/url"
+ "one-api/common"
+ "sort"
+ "strings"
+ "time"
+)
+
+// SignRequestForJimeng 对即梦 API 请求进行签名,支持 http.Request 或 header+url+body 方式
+//func SignRequestForJimeng(req *http.Request, accessKey, secretKey string) error {
+// var bodyBytes []byte
+// var err error
+//
+// if req.Body != nil {
+// bodyBytes, err = io.ReadAll(req.Body)
+// if err != nil {
+// return fmt.Errorf("read request body failed: %w", err)
+// }
+// _ = req.Body.Close()
+// req.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) // rewind
+// } else {
+// bodyBytes = []byte{}
+// }
+//
+// return signJimengHeaders(&req.Header, req.Method, req.URL, bodyBytes, accessKey, secretKey)
+//}
+
+const HexPayloadHashKey = "HexPayloadHash"
+
+func SetPayloadHash(c *gin.Context, req any) error {
+ body, err := json.Marshal(req)
+ if err != nil {
+ return err
+ }
+ common.LogInfo(c, fmt.Sprintf("SetPayloadHash body: %s", body))
+ payloadHash := sha256.Sum256(body)
+ hexPayloadHash := hex.EncodeToString(payloadHash[:])
+ c.Set(HexPayloadHashKey, hexPayloadHash)
+ return nil
+}
+func getPayloadHash(c *gin.Context) string {
+ return c.GetString(HexPayloadHashKey)
+}
+
+func Sign(c *gin.Context, req *http.Request, apiKey string) error {
+ header := req.Header
+
+ var bodyBytes []byte
+ var err error
+
+ if req.Body != nil {
+ bodyBytes, err = io.ReadAll(req.Body)
+ if err != nil {
+ return err
+ }
+ _ = req.Body.Close()
+ req.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) // Rewind
+ }
+
+ payloadHash := sha256.Sum256(bodyBytes)
+ hexPayloadHash := hex.EncodeToString(payloadHash[:])
+
+ method := c.Request.Method
+ u := req.URL
+ keyParts := strings.Split(apiKey, "|")
+ if len(keyParts) != 2 {
+ return errors.New("invalid api key format for jimeng: expected 'ak|sk'")
+ }
+ accessKey := strings.TrimSpace(keyParts[0])
+ secretKey := strings.TrimSpace(keyParts[1])
+ t := time.Now().UTC()
+ xDate := t.Format("20060102T150405Z")
+ shortDate := t.Format("20060102")
+
+ host := u.Host
+ header.Set("Host", host)
+ header.Set("X-Date", xDate)
+ header.Set("X-Content-Sha256", hexPayloadHash)
+
+ // Sort and encode query parameters to create canonical query string
+ queryParams := u.Query()
+ sortedKeys := make([]string, 0, len(queryParams))
+ for k := range queryParams {
+ sortedKeys = append(sortedKeys, k)
+ }
+ sort.Strings(sortedKeys)
+ var queryParts []string
+ for _, k := range sortedKeys {
+ values := queryParams[k]
+ sort.Strings(values)
+ for _, v := range values {
+ queryParts = append(queryParts, fmt.Sprintf("%s=%s", url.QueryEscape(k), url.QueryEscape(v)))
+ }
+ }
+ canonicalQueryString := strings.Join(queryParts, "&")
+
+ headersToSign := map[string]string{
+ "host": host,
+ "x-date": xDate,
+ "x-content-sha256": hexPayloadHash,
+ }
+ if header.Get("Content-Type") == "" {
+ header.Set("Content-Type", "application/json")
+ }
+ headersToSign["content-type"] = header.Get("Content-Type")
+
+ var signedHeaderKeys []string
+ for k := range headersToSign {
+ signedHeaderKeys = append(signedHeaderKeys, k)
+ }
+ sort.Strings(signedHeaderKeys)
+
+ var canonicalHeaders strings.Builder
+ for _, k := range signedHeaderKeys {
+ canonicalHeaders.WriteString(k)
+ canonicalHeaders.WriteString(":")
+ canonicalHeaders.WriteString(strings.TrimSpace(headersToSign[k]))
+ canonicalHeaders.WriteString("\n")
+ }
+ signedHeaders := strings.Join(signedHeaderKeys, ";")
+
+ canonicalRequest := fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s",
+ method,
+ u.Path,
+ canonicalQueryString,
+ canonicalHeaders.String(),
+ signedHeaders,
+ hexPayloadHash,
+ )
+
+ hashedCanonicalRequest := sha256.Sum256([]byte(canonicalRequest))
+ hexHashedCanonicalRequest := hex.EncodeToString(hashedCanonicalRequest[:])
+
+ region := "cn-north-1"
+ serviceName := "cv"
+ credentialScope := fmt.Sprintf("%s/%s/%s/request", shortDate, region, serviceName)
+ stringToSign := fmt.Sprintf("HMAC-SHA256\n%s\n%s\n%s",
+ xDate,
+ credentialScope,
+ hexHashedCanonicalRequest,
+ )
+
+ kDate := hmacSHA256([]byte(secretKey), []byte(shortDate))
+ kRegion := hmacSHA256(kDate, []byte(region))
+ kService := hmacSHA256(kRegion, []byte(serviceName))
+ kSigning := hmacSHA256(kService, []byte("request"))
+ signature := hex.EncodeToString(hmacSHA256(kSigning, []byte(stringToSign)))
+
+ authorization := fmt.Sprintf("HMAC-SHA256 Credential=%s/%s, SignedHeaders=%s, Signature=%s",
+ accessKey,
+ credentialScope,
+ signedHeaders,
+ signature,
+ )
+ header.Set("Authorization", authorization)
+ return nil
+}
+
+// hmacSHA256 计算 HMAC-SHA256
+func hmacSHA256(key []byte, data []byte) []byte {
+ h := hmac.New(sha256.New, key)
+ h.Write(data)
+ return h.Sum(nil)
+}
diff --git a/relay/image_handler.go b/relay/image_handler.go
index 44f44277..8349307f 100644
--- a/relay/image_handler.go
+++ b/relay/image_handler.go
@@ -145,22 +145,25 @@ func ImageHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
} else {
sizeRatio := 1.0
- // Size
- if imageRequest.Size == "256x256" {
- sizeRatio = 0.4
- } else if imageRequest.Size == "512x512" {
- sizeRatio = 0.45
- } else if imageRequest.Size == "1024x1024" {
- sizeRatio = 1
- } else if imageRequest.Size == "1024x1792" || imageRequest.Size == "1792x1024" {
- sizeRatio = 2
- }
-
qualityRatio := 1.0
- if imageRequest.Model == "dall-e-3" && imageRequest.Quality == "hd" {
- qualityRatio = 2.0
- if imageRequest.Size == "1024x1792" || imageRequest.Size == "1792x1024" {
- qualityRatio = 1.5
+
+ if strings.HasPrefix(imageRequest.Model, "dall-e") {
+ // Size
+ if imageRequest.Size == "256x256" {
+ sizeRatio = 0.4
+ } else if imageRequest.Size == "512x512" {
+ sizeRatio = 0.45
+ } else if imageRequest.Size == "1024x1024" {
+ sizeRatio = 1
+ } else if imageRequest.Size == "1024x1792" || imageRequest.Size == "1792x1024" {
+ sizeRatio = 2
+ }
+
+ if imageRequest.Model == "dall-e-3" && imageRequest.Quality == "hd" {
+ qualityRatio = 2.0
+ if imageRequest.Size == "1024x1792" || imageRequest.Size == "1792x1024" {
+ qualityRatio = 1.5
+ }
}
}
diff --git a/relay/relay_adaptor.go b/relay/relay_adaptor.go
index 00e59eac..2ce12a87 100644
--- a/relay/relay_adaptor.go
+++ b/relay/relay_adaptor.go
@@ -15,6 +15,7 @@ import (
"one-api/relay/channel/deepseek"
"one-api/relay/channel/dify"
"one-api/relay/channel/gemini"
+ "one-api/relay/channel/jimeng"
"one-api/relay/channel/jina"
"one-api/relay/channel/mistral"
"one-api/relay/channel/mokaai"
@@ -23,7 +24,7 @@ import (
"one-api/relay/channel/palm"
"one-api/relay/channel/perplexity"
"one-api/relay/channel/siliconflow"
- "one-api/relay/channel/task/jimeng"
+ taskjimeng "one-api/relay/channel/task/jimeng"
"one-api/relay/channel/task/kling"
"one-api/relay/channel/task/suno"
"one-api/relay/channel/tencent"
@@ -93,6 +94,8 @@ func GetAdaptor(apiType int) channel.Adaptor {
return &xai.Adaptor{}
case constant.APITypeCoze:
return &coze.Adaptor{}
+ case constant.APITypeJimeng:
+ return &jimeng.Adaptor{}
}
return nil
}
@@ -106,7 +109,7 @@ func GetTaskAdaptor(platform commonconstant.TaskPlatform) channel.TaskAdaptor {
case commonconstant.TaskPlatformKling:
return &kling.TaskAdaptor{}
case commonconstant.TaskPlatformJimeng:
- return &jimeng.TaskAdaptor{}
+ return &taskjimeng.TaskAdaptor{}
}
return nil
}
diff --git a/relay/relay_task.go b/relay/relay_task.go
index ce6b93ce..25f63d40 100644
--- a/relay/relay_task.go
+++ b/relay/relay_task.go
@@ -37,9 +37,9 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
return
}
- modelName := service.CoverTaskActionToModelName(platform, relayInfo.Action)
- if platform == constant.TaskPlatformKling {
- modelName = relayInfo.OriginModelName
+ modelName := relayInfo.OriginModelName
+ if modelName == "" {
+ modelName = service.CoverTaskActionToModelName(platform, relayInfo.Action)
}
modelPrice, success := ratio_setting.GetModelPrice(modelName, true)
if !success {
diff --git a/router/api-router.go b/router/api-router.go
index 4bd2faff..bc49803a 100644
--- a/router/api-router.go
+++ b/router/api-router.go
@@ -38,6 +38,8 @@ func SetApiRouter(router *gin.Engine) {
apiRouter.GET("/oauth/telegram/bind", middleware.CriticalRateLimit(), controller.TelegramBind)
apiRouter.GET("/ratio_config", middleware.CriticalRateLimit(), controller.GetRatioConfig)
+ apiRouter.POST("/stripe/webhook", controller.StripeWebhook)
+
userRoute := apiRouter.Group("/user")
{
userRoute.POST("/register", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.Register)
@@ -57,9 +59,11 @@ func SetApiRouter(router *gin.Engine) {
selfRoute.DELETE("/self", controller.DeleteSelf)
selfRoute.GET("/token", controller.GenerateAccessToken)
selfRoute.GET("/aff", controller.GetAffCode)
- selfRoute.POST("/topup", controller.TopUp)
- selfRoute.POST("/pay", controller.RequestEpay)
+ selfRoute.POST("/topup", middleware.CriticalRateLimit(), controller.TopUp)
+ selfRoute.POST("/pay", middleware.CriticalRateLimit(), controller.RequestEpay)
selfRoute.POST("/amount", controller.RequestAmount)
+ selfRoute.POST("/stripe/pay", middleware.CriticalRateLimit(), controller.RequestStripePay)
+ selfRoute.POST("/stripe/amount", controller.RequestStripeAmount)
selfRoute.POST("/aff_transfer", controller.TransferAffQuota)
selfRoute.PUT("/setting", controller.UpdateUserSetting)
}
diff --git a/setting/payment_stripe.go b/setting/payment_stripe.go
new file mode 100644
index 00000000..80d877df
--- /dev/null
+++ b/setting/payment_stripe.go
@@ -0,0 +1,7 @@
+package setting
+
+var StripeApiSecret = ""
+var StripeWebhookSecret = ""
+var StripePriceId = ""
+var StripeUnitPrice = 8.0
+var StripeMinTopUp = 1
diff --git a/setting/user_usable_group.go b/setting/user_usable_group.go
index fdf2f723..0ae132d0 100644
--- a/setting/user_usable_group.go
+++ b/setting/user_usable_group.go
@@ -3,14 +3,19 @@ package setting
import (
"encoding/json"
"one-api/common"
+ "sync"
)
var userUsableGroups = map[string]string{
"default": "默认分组",
"vip": "vip分组",
}
+var userUsableGroupsMutex sync.RWMutex
func GetUserUsableGroupsCopy() map[string]string {
+ userUsableGroupsMutex.RLock()
+ defer userUsableGroupsMutex.RUnlock()
+
copyUserUsableGroups := make(map[string]string)
for k, v := range userUsableGroups {
copyUserUsableGroups[k] = v
@@ -19,6 +24,9 @@ func GetUserUsableGroupsCopy() map[string]string {
}
func UserUsableGroups2JSONString() string {
+ userUsableGroupsMutex.RLock()
+ defer userUsableGroupsMutex.RUnlock()
+
jsonBytes, err := json.Marshal(userUsableGroups)
if err != nil {
common.SysError("error marshalling user groups: " + err.Error())
@@ -27,6 +35,9 @@ func UserUsableGroups2JSONString() string {
}
func UpdateUserUsableGroupsByJSONString(jsonStr string) error {
+ userUsableGroupsMutex.Lock()
+ defer userUsableGroupsMutex.Unlock()
+
userUsableGroups = make(map[string]string)
return json.Unmarshal([]byte(jsonStr), &userUsableGroups)
}
@@ -47,11 +58,17 @@ func GetUserUsableGroups(userGroup string) map[string]string {
}
func GroupInUserUsableGroups(groupName string) bool {
+ userUsableGroupsMutex.RLock()
+ defer userUsableGroupsMutex.RUnlock()
+
_, ok := userUsableGroups[groupName]
return ok
}
func GetUsableGroupDescription(groupName string) string {
+ userUsableGroupsMutex.RLock()
+ defer userUsableGroupsMutex.RUnlock()
+
if desc, ok := userUsableGroups[groupName]; ok {
return desc
}
diff --git a/web/src/components/settings/PaymentSetting.js b/web/src/components/settings/PaymentSetting.js
index a65a57b7..ed175a20 100644
--- a/web/src/components/settings/PaymentSetting.js
+++ b/web/src/components/settings/PaymentSetting.js
@@ -2,6 +2,7 @@ import React, { useEffect, useState } from 'react';
import { Card, Spin } from '@douyinfe/semi-ui';
import SettingsGeneralPayment from '../../pages/Setting/Payment/SettingsGeneralPayment.js';
import SettingsPaymentGateway from '../../pages/Setting/Payment/SettingsPaymentGateway.js';
+import SettingsPaymentGatewayStripe from '../../pages/Setting/Payment/SettingsPaymentGatewayStripe.js';
import { API, showError, toBoolean } from '../../helpers';
import { useTranslation } from 'react-i18next';
@@ -17,6 +18,12 @@ const PaymentSetting = () => {
TopupGroupRatio: '',
CustomCallbackAddress: '',
PayMethods: '',
+
+ StripeApiSecret: '',
+ StripeWebhookSecret: '',
+ StripePriceId: '',
+ StripeUnitPrice: 8.0,
+ StripeMinTopUp: 1,
});
let [loading, setLoading] = useState(false);
@@ -38,6 +45,8 @@ const PaymentSetting = () => {
break;
case 'Price':
case 'MinTopUp':
+ case 'StripeUnitPrice':
+ case 'StripeMinTopUp':
newInputs[item.key] = parseFloat(item.value);
break;
default:
@@ -80,6 +89,9 @@ const PaymentSetting = () => {
+ {t('充值数量')}:{stripeTopUpCount}
+
+ {t('实付金额')}:{renderStripeAmount()}
+ {t('是否确认充值?')}
+