first commit: one-api base code + SAAS plan document
Some checks failed
CI / Unit tests (push) Has been cancelled
CI / commit_lint (push) Has been cancelled

This commit is contained in:
huangzhenpc
2025-12-29 22:52:27 +08:00
commit cb7c48bfa7
564 changed files with 61468 additions and 0 deletions

29
common/blacklist/main.go Normal file
View File

@@ -0,0 +1,29 @@
package blacklist
import (
"fmt"
"sync"
)
var blackList sync.Map
func init() {
blackList = sync.Map{}
}
func userId2Key(id int) string {
return fmt.Sprintf("userid_%d", id)
}
func BanUser(id int) {
blackList.Store(userId2Key(id), true)
}
func UnbanUser(id int) {
blackList.Delete(userId2Key(id))
}
func IsUserBanned(id int) bool {
_, ok := blackList.Load(userId2Key(id))
return ok
}

60
common/client/init.go Normal file
View File

@@ -0,0 +1,60 @@
package client
import (
"fmt"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"net/http"
"net/url"
"time"
)
var HTTPClient *http.Client
var ImpatientHTTPClient *http.Client
var UserContentRequestHTTPClient *http.Client
func Init() {
if config.UserContentRequestProxy != "" {
logger.SysLog(fmt.Sprintf("using %s as proxy to fetch user content", config.UserContentRequestProxy))
proxyURL, err := url.Parse(config.UserContentRequestProxy)
if err != nil {
logger.FatalLog(fmt.Sprintf("USER_CONTENT_REQUEST_PROXY set but invalid: %s", config.UserContentRequestProxy))
}
transport := &http.Transport{
Proxy: http.ProxyURL(proxyURL),
}
UserContentRequestHTTPClient = &http.Client{
Transport: transport,
Timeout: time.Second * time.Duration(config.UserContentRequestTimeout),
}
} else {
UserContentRequestHTTPClient = &http.Client{}
}
var transport http.RoundTripper
if config.RelayProxy != "" {
logger.SysLog(fmt.Sprintf("using %s as api relay proxy", config.RelayProxy))
proxyURL, err := url.Parse(config.RelayProxy)
if err != nil {
logger.FatalLog(fmt.Sprintf("USER_CONTENT_REQUEST_PROXY set but invalid: %s", config.UserContentRequestProxy))
}
transport = &http.Transport{
Proxy: http.ProxyURL(proxyURL),
}
}
if config.RelayTimeout == 0 {
HTTPClient = &http.Client{
Transport: transport,
}
} else {
HTTPClient = &http.Client{
Timeout: time.Duration(config.RelayTimeout) * time.Second,
Transport: transport,
}
}
ImpatientHTTPClient = &http.Client{
Timeout: 5 * time.Second,
Transport: transport,
}
}

166
common/config/config.go Normal file
View File

@@ -0,0 +1,166 @@
package config
import (
"os"
"strconv"
"strings"
"sync"
"time"
"github.com/songquanpeng/one-api/common/env"
"github.com/google/uuid"
)
var SystemName = "One API"
var ServerAddress = "http://localhost:3000"
var Footer = ""
var Logo = ""
var TopUpLink = ""
var ChatLink = ""
var QuotaPerUnit = 500 * 1000.0 // $0.002 / 1K tokens
var DisplayInCurrencyEnabled = true
var DisplayTokenStatEnabled = true
// Any options with "Secret", "Token" in its key won't be return by GetOptions
var SessionSecret = uuid.New().String()
var OptionMap map[string]string
var OptionMapRWMutex sync.RWMutex
var ItemsPerPage = 10
var MaxRecentItems = 100
var PasswordLoginEnabled = true
var PasswordRegisterEnabled = true
var EmailVerificationEnabled = false
var GitHubOAuthEnabled = false
var OidcEnabled = false
var WeChatAuthEnabled = false
var TurnstileCheckEnabled = false
var RegisterEnabled = true
var EmailDomainRestrictionEnabled = false
var EmailDomainWhitelist = []string{
"gmail.com",
"163.com",
"126.com",
"qq.com",
"outlook.com",
"hotmail.com",
"icloud.com",
"yahoo.com",
"foxmail.com",
}
var DebugEnabled = strings.ToLower(os.Getenv("DEBUG")) == "true"
var DebugSQLEnabled = strings.ToLower(os.Getenv("DEBUG_SQL")) == "true"
var MemoryCacheEnabled = strings.ToLower(os.Getenv("MEMORY_CACHE_ENABLED")) == "true"
var LogConsumeEnabled = true
var SMTPServer = ""
var SMTPPort = 587
var SMTPAccount = ""
var SMTPFrom = ""
var SMTPToken = ""
var GitHubClientId = ""
var GitHubClientSecret = ""
var LarkClientId = ""
var LarkClientSecret = ""
var OidcClientId = ""
var OidcClientSecret = ""
var OidcWellKnown = ""
var OidcAuthorizationEndpoint = ""
var OidcTokenEndpoint = ""
var OidcUserinfoEndpoint = ""
var WeChatServerAddress = ""
var WeChatServerToken = ""
var WeChatAccountQRCodeImageURL = ""
var MessagePusherAddress = ""
var MessagePusherToken = ""
var TurnstileSiteKey = ""
var TurnstileSecretKey = ""
var QuotaForNewUser int64 = 0
var QuotaForInviter int64 = 0
var QuotaForInvitee int64 = 0
var ChannelDisableThreshold = 5.0
var AutomaticDisableChannelEnabled = false
var AutomaticEnableChannelEnabled = false
var QuotaRemindThreshold int64 = 1000
var PreConsumedQuota int64 = 500
var ApproximateTokenEnabled = false
var RetryTimes = 0
var RootUserEmail = ""
var IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
var RequestInterval = time.Duration(requestInterval) * time.Second
var SyncFrequency = env.Int("SYNC_FREQUENCY", 10*60) // unit is second
var BatchUpdateEnabled = false
var BatchUpdateInterval = env.Int("BATCH_UPDATE_INTERVAL", 5)
var RelayTimeout = env.Int("RELAY_TIMEOUT", 0) // unit is second
var GeminiSafetySetting = env.String("GEMINI_SAFETY_SETTING", "BLOCK_NONE")
var Theme = env.String("THEME", "default")
var ValidThemes = map[string]bool{
"default": true,
"berry": true,
"air": true,
}
// All duration's unit is seconds
// Shouldn't larger then RateLimitKeyExpirationDuration
var (
GlobalApiRateLimitNum = env.Int("GLOBAL_API_RATE_LIMIT", 480)
GlobalApiRateLimitDuration int64 = 3 * 60
GlobalWebRateLimitNum = env.Int("GLOBAL_WEB_RATE_LIMIT", 240)
GlobalWebRateLimitDuration int64 = 3 * 60
UploadRateLimitNum = 10
UploadRateLimitDuration int64 = 60
DownloadRateLimitNum = 10
DownloadRateLimitDuration int64 = 60
CriticalRateLimitNum = 20
CriticalRateLimitDuration int64 = 20 * 60
)
var RateLimitKeyExpirationDuration = 20 * time.Minute
var EnableMetric = env.Bool("ENABLE_METRIC", false)
var MetricQueueSize = env.Int("METRIC_QUEUE_SIZE", 10)
var MetricSuccessRateThreshold = env.Float64("METRIC_SUCCESS_RATE_THRESHOLD", 0.8)
var MetricSuccessChanSize = env.Int("METRIC_SUCCESS_CHAN_SIZE", 1024)
var MetricFailChanSize = env.Int("METRIC_FAIL_CHAN_SIZE", 128)
var InitialRootToken = os.Getenv("INITIAL_ROOT_TOKEN")
var InitialRootAccessToken = os.Getenv("INITIAL_ROOT_ACCESS_TOKEN")
var GeminiVersion = env.String("GEMINI_VERSION", "v1")
var OnlyOneLogFile = env.Bool("ONLY_ONE_LOG_FILE", false)
var RelayProxy = env.String("RELAY_PROXY", "")
var UserContentRequestProxy = env.String("USER_CONTENT_REQUEST_PROXY", "")
var UserContentRequestTimeout = env.Int("USER_CONTENT_REQUEST_TIMEOUT", 30)
var EnforceIncludeUsage = env.Bool("ENFORCE_INCLUDE_USAGE", false)
var TestPrompt = env.String("TEST_PROMPT", "Output only your specific model name with no additional text.")

6
common/constants.go Normal file
View File

@@ -0,0 +1,6 @@
package common
import "time"
var StartTime = time.Now().Unix() // unit: second
var Version = "v0.0.0" // this hard coding will be replaced automatically when building, no need to manually change

6
common/conv/any.go Normal file
View File

@@ -0,0 +1,6 @@
package conv
func AsString(v any) string {
str, _ := v.(string)
return str
}

14
common/crypto.go Normal file
View File

@@ -0,0 +1,14 @@
package common
import "golang.org/x/crypto/bcrypt"
func Password2Hash(password string) (string, error) {
passwordBytes := []byte(password)
hashedPassword, err := bcrypt.GenerateFromPassword(passwordBytes, bcrypt.DefaultCost)
return string(hashedPassword), err
}
func ValidatePasswordAndHash(password string, hash string) bool {
err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password))
return err == nil
}

24
common/ctxkey/key.go Normal file
View File

@@ -0,0 +1,24 @@
package ctxkey
const (
Config = "config"
Id = "id"
Username = "username"
Role = "role"
Status = "status"
Channel = "channel"
ChannelId = "channel_id"
SpecificChannelId = "specific_channel_id"
RequestModel = "request_model"
ConvertedRequest = "converted_request"
OriginalModel = "original_model"
Group = "group"
ModelMapping = "model_mapping"
ChannelName = "channel_name"
TokenId = "token_id"
TokenName = "token_name"
BaseURL = "base_url"
AvailableModels = "available_models"
KeyRequestBody = "key_request_body"
SystemPrompt = "system_prompt"
)

82
common/custom-event.go Normal file
View File

@@ -0,0 +1,82 @@
// Copyright 2014 Manu Martinez-Almeida. All rights reserved.
// Use of this source code is governed by a MIT style
// license that can be found in the LICENSE file.
package common
import (
"fmt"
"io"
"net/http"
"strings"
)
type stringWriter interface {
io.Writer
writeString(string) (int, error)
}
type stringWrapper struct {
io.Writer
}
func (w stringWrapper) writeString(str string) (int, error) {
return w.Writer.Write([]byte(str))
}
func checkWriter(writer io.Writer) stringWriter {
if w, ok := writer.(stringWriter); ok {
return w
} else {
return stringWrapper{writer}
}
}
// Server-Sent Events
// W3C Working Draft 29 October 2009
// http://www.w3.org/TR/2009/WD-eventsource-20091029/
var contentType = []string{"text/event-stream"}
var noCache = []string{"no-cache"}
var fieldReplacer = strings.NewReplacer(
"\n", "\\n",
"\r", "\\r")
var dataReplacer = strings.NewReplacer(
"\n", "\ndata:",
"\r", "\\r")
type CustomEvent struct {
Event string
Id string
Retry uint
Data interface{}
}
func encode(writer io.Writer, event CustomEvent) error {
w := checkWriter(writer)
return writeData(w, event.Data)
}
func writeData(w stringWriter, data interface{}) error {
dataReplacer.WriteString(w, fmt.Sprint(data))
if strings.HasPrefix(data.(string), "data") {
w.writeString("\n\n")
}
return nil
}
func (r CustomEvent) Render(w http.ResponseWriter) error {
r.WriteContentType(w)
return encode(w, r)
}
func (r CustomEvent) WriteContentType(w http.ResponseWriter) {
header := w.Header()
header["Content-Type"] = contentType
if _, exist := header["Cache-Control"]; !exist {
header["Cache-Control"] = noCache
}
}

12
common/database.go Normal file
View File

@@ -0,0 +1,12 @@
package common
import (
"github.com/songquanpeng/one-api/common/env"
)
var UsingSQLite = false
var UsingPostgreSQL = false
var UsingMySQL = false
var SQLitePath = "one-api.db"
var SQLiteBusyTimeout = env.Int("SQLITE_BUSY_TIMEOUT", 3000)

View File

@@ -0,0 +1,29 @@
package common
import (
"embed"
"github.com/gin-contrib/static"
"io/fs"
"net/http"
)
// Credit: https://github.com/gin-contrib/static/issues/19
type embedFileSystem struct {
http.FileSystem
}
func (e embedFileSystem) Exists(prefix string, path string) bool {
_, err := e.Open(path)
return err == nil
}
func EmbedFolder(fsEmbed embed.FS, targetPath string) static.ServeFileSystem {
efs, err := fs.Sub(fsEmbed, targetPath)
if err != nil {
panic(err)
}
return embedFileSystem{
FileSystem: http.FS(efs),
}
}

42
common/env/helper.go vendored Normal file
View File

@@ -0,0 +1,42 @@
package env
import (
"os"
"strconv"
)
func Bool(env string, defaultValue bool) bool {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
return os.Getenv(env) == "true"
}
func Int(env string, defaultValue int) int {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
num, err := strconv.Atoi(os.Getenv(env))
if err != nil {
return defaultValue
}
return num
}
func Float64(env string, defaultValue float64) float64 {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
num, err := strconv.ParseFloat(os.Getenv(env), 64)
if err != nil {
return defaultValue
}
return num
}
func String(env string, defaultValue string) string {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
return os.Getenv(env)
}

53
common/gin.go Normal file
View File

@@ -0,0 +1,53 @@
package common
import (
"bytes"
"encoding/json"
"io"
"strings"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/ctxkey"
)
func GetRequestBody(c *gin.Context) ([]byte, error) {
requestBody, _ := c.Get(ctxkey.KeyRequestBody)
if requestBody != nil {
return requestBody.([]byte), nil
}
requestBody, err := io.ReadAll(c.Request.Body)
if err != nil {
return nil, err
}
_ = c.Request.Body.Close()
c.Set(ctxkey.KeyRequestBody, requestBody)
return requestBody.([]byte), nil
}
func UnmarshalBodyReusable(c *gin.Context, v any) error {
requestBody, err := GetRequestBody(c)
if err != nil {
return err
}
contentType := c.Request.Header.Get("Content-Type")
if strings.HasPrefix(contentType, "application/json") {
err = json.Unmarshal(requestBody, &v)
} else {
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
err = c.ShouldBind(&v)
}
if err != nil {
return err
}
// Reset request body
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
return nil
}
func SetEventStreamHeaders(c *gin.Context) {
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("Transfer-Encoding", "chunked")
c.Writer.Header().Set("X-Accel-Buffering", "no")
}

174
common/helper/helper.go Normal file
View File

@@ -0,0 +1,174 @@
package helper
import (
"context"
"fmt"
"html/template"
"log"
"net"
"os/exec"
"runtime"
"strconv"
"strings"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/random"
)
func OpenBrowser(url string) {
var err error
switch runtime.GOOS {
case "linux":
err = exec.Command("xdg-open", url).Start()
case "windows":
err = exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start()
case "darwin":
err = exec.Command("open", url).Start()
}
if err != nil {
log.Println(err)
}
}
func GetIp() (ip string) {
ips, err := net.InterfaceAddrs()
if err != nil {
log.Println(err)
return ip
}
for _, a := range ips {
if ipNet, ok := a.(*net.IPNet); ok && !ipNet.IP.IsLoopback() {
if ipNet.IP.To4() != nil {
ip = ipNet.IP.String()
if strings.HasPrefix(ip, "10") {
return
}
if strings.HasPrefix(ip, "172") {
return
}
if strings.HasPrefix(ip, "192.168") {
return
}
ip = ""
}
}
}
return
}
var sizeKB = 1024
var sizeMB = sizeKB * 1024
var sizeGB = sizeMB * 1024
func Bytes2Size(num int64) string {
numStr := ""
unit := "B"
if num/int64(sizeGB) > 1 {
numStr = fmt.Sprintf("%.2f", float64(num)/float64(sizeGB))
unit = "GB"
} else if num/int64(sizeMB) > 1 {
numStr = fmt.Sprintf("%d", int(float64(num)/float64(sizeMB)))
unit = "MB"
} else if num/int64(sizeKB) > 1 {
numStr = fmt.Sprintf("%d", int(float64(num)/float64(sizeKB)))
unit = "KB"
} else {
numStr = fmt.Sprintf("%d", num)
}
return numStr + " " + unit
}
func Interface2String(inter interface{}) string {
switch inter := inter.(type) {
case string:
return inter
case int:
return fmt.Sprintf("%d", inter)
case float64:
return fmt.Sprintf("%f", inter)
}
return "Not Implemented"
}
func UnescapeHTML(x string) interface{} {
return template.HTML(x)
}
func IntMax(a int, b int) int {
if a >= b {
return a
} else {
return b
}
}
func GenRequestID() string {
return GetTimeString() + random.GetRandomNumberString(8)
}
func SetRequestID(ctx context.Context, id string) context.Context {
return context.WithValue(ctx, RequestIdKey, id)
}
func GetRequestID(ctx context.Context) string {
rawRequestId := ctx.Value(RequestIdKey)
if rawRequestId == nil {
return ""
}
return rawRequestId.(string)
}
func GetResponseID(c *gin.Context) string {
logID := c.GetString(RequestIdKey)
return fmt.Sprintf("chatcmpl-%s", logID)
}
func Max(a int, b int) int {
if a >= b {
return a
} else {
return b
}
}
func AssignOrDefault(value string, defaultValue string) string {
if len(value) != 0 {
return value
}
return defaultValue
}
func MessageWithRequestId(message string, id string) string {
return fmt.Sprintf("%s (request id: %s)", message, id)
}
func String2Int(str string) int {
num, err := strconv.Atoi(str)
if err != nil {
return 0
}
return num
}
func Float64PtrMax(p *float64, maxValue float64) *float64 {
if p == nil {
return nil
}
if *p > maxValue {
return &maxValue
}
return p
}
func Float64PtrMin(p *float64, minValue float64) *float64 {
if p == nil {
return nil
}
if *p < minValue {
return &minValue
}
return p
}

5
common/helper/key.go Normal file
View File

@@ -0,0 +1,5 @@
package helper
const (
RequestIdKey = "X-Oneapi-Request-Id"
)

20
common/helper/time.go Normal file
View File

@@ -0,0 +1,20 @@
package helper
import (
"fmt"
"time"
)
func GetTimestamp() int64 {
return time.Now().Unix()
}
func GetTimeString() string {
now := time.Now()
return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9)
}
// CalcElapsedTime return the elapsed time in milliseconds (ms)
func CalcElapsedTime(start time.Time) int64 {
return time.Now().Sub(start).Milliseconds()
}

72
common/i18n/i18n.go Normal file
View File

@@ -0,0 +1,72 @@
package i18n
import (
"embed"
"encoding/json"
"strings"
"github.com/gin-gonic/gin"
)
//go:embed locales/*.json
var localesFS embed.FS
var (
translations = make(map[string]map[string]string)
defaultLang = "en"
ContextKey = "i18n"
)
// Init loads all translation files from embedded filesystem
func Init() error {
entries, err := localesFS.ReadDir("locales")
if err != nil {
return err
}
for _, entry := range entries {
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".json") {
continue
}
langCode := strings.TrimSuffix(entry.Name(), ".json")
content, err := localesFS.ReadFile("locales/" + entry.Name())
if err != nil {
return err
}
var translation map[string]string
if err := json.Unmarshal(content, &translation); err != nil {
return err
}
translations[langCode] = translation
}
return nil
}
func GetLang(c *gin.Context) string {
rawLang, ok := c.Get(ContextKey)
if !ok {
return defaultLang
}
lang, _ := rawLang.(string)
if lang != "" {
return lang
}
return defaultLang
}
func Translate(c *gin.Context, message string) string {
lang := GetLang(c)
return translateHelper(lang, message)
}
func translateHelper(lang, message string) string {
if trans, ok := translations[lang]; ok {
if translated, exists := trans[message]; exists {
return translated
}
}
return message
}

View File

@@ -0,0 +1,5 @@
{
"invalid_input": "Invalid input, please check your input",
"send_email_failed": "failed to send email: ",
"invalid_parameter": "invalid parameter"
}

View File

@@ -0,0 +1,5 @@
{
"invalid_input": "无效的输入,请检查您的输入",
"send_email_failed": "发送邮件失败:",
"invalid_parameter": "无效的参数"
}

112
common/image/image.go Normal file
View File

@@ -0,0 +1,112 @@
package image
import (
"bytes"
"encoding/base64"
"github.com/songquanpeng/one-api/common/client"
"image"
_ "image/gif"
_ "image/jpeg"
_ "image/png"
"net/http"
"regexp"
"strings"
"sync"
_ "golang.org/x/image/webp"
)
// Regex to match data URL pattern
var dataURLPattern = regexp.MustCompile(`data:image/([^;]+);base64,(.*)`)
func IsImageUrl(url string) (bool, error) {
resp, err := client.UserContentRequestHTTPClient.Head(url)
if err != nil {
return false, err
}
if !strings.HasPrefix(resp.Header.Get("Content-Type"), "image/") {
return false, nil
}
return true, nil
}
func GetImageSizeFromUrl(url string) (width int, height int, err error) {
isImage, err := IsImageUrl(url)
if !isImage {
return
}
resp, err := client.UserContentRequestHTTPClient.Get(url)
if err != nil {
return
}
defer resp.Body.Close()
img, _, err := image.DecodeConfig(resp.Body)
if err != nil {
return
}
return img.Width, img.Height, nil
}
func GetImageFromUrl(url string) (mimeType string, data string, err error) {
// Check if the URL is a data URL
matches := dataURLPattern.FindStringSubmatch(url)
if len(matches) == 3 {
// URL is a data URL
mimeType = "image/" + matches[1]
data = matches[2]
return
}
isImage, err := IsImageUrl(url)
if !isImage {
return
}
resp, err := http.Get(url)
if err != nil {
return
}
defer resp.Body.Close()
buffer := bytes.NewBuffer(nil)
_, err = buffer.ReadFrom(resp.Body)
if err != nil {
return
}
mimeType = resp.Header.Get("Content-Type")
data = base64.StdEncoding.EncodeToString(buffer.Bytes())
return
}
var (
reg = regexp.MustCompile(`data:image/([^;]+);base64,`)
)
var readerPool = sync.Pool{
New: func() interface{} {
return &bytes.Reader{}
},
}
func GetImageSizeFromBase64(encoded string) (width int, height int, err error) {
decoded, err := base64.StdEncoding.DecodeString(reg.ReplaceAllString(encoded, ""))
if err != nil {
return 0, 0, err
}
reader := readerPool.Get().(*bytes.Reader)
defer readerPool.Put(reader)
reader.Reset(decoded)
img, _, err := image.DecodeConfig(reader)
if err != nil {
return 0, 0, err
}
return img.Width, img.Height, nil
}
func GetImageSize(image string) (width int, height int, err error) {
if strings.HasPrefix(image, "data:image/") {
return GetImageSizeFromBase64(image)
}
return GetImageSizeFromUrl(image)
}

177
common/image/image_test.go Normal file
View File

@@ -0,0 +1,177 @@
package image_test
import (
"encoding/base64"
"github.com/songquanpeng/one-api/common/client"
"image"
_ "image/gif"
_ "image/jpeg"
_ "image/png"
"io"
"net/http"
"strconv"
"strings"
"testing"
img "github.com/songquanpeng/one-api/common/image"
"github.com/stretchr/testify/assert"
_ "golang.org/x/image/webp"
)
type CountingReader struct {
reader io.Reader
BytesRead int
}
func (r *CountingReader) Read(p []byte) (n int, err error) {
n, err = r.reader.Read(p)
r.BytesRead += n
return n, err
}
var (
cases = []struct {
url string
format string
width int
height int
}{
{"https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", "jpeg", 2560, 1669},
{"https://upload.wikimedia.org/wikipedia/commons/9/97/Basshunter_live_performances.png", "png", 4500, 2592},
{"https://upload.wikimedia.org/wikipedia/commons/c/c6/TO_THE_ONE_SOMETHINGNESS.webp", "webp", 984, 985},
{"https://upload.wikimedia.org/wikipedia/commons/d/d0/01_Das_Sandberg-Modell.gif", "gif", 1917, 1533},
{"https://upload.wikimedia.org/wikipedia/commons/6/62/102Cervus.jpg", "jpeg", 270, 230},
}
)
func TestMain(m *testing.M) {
client.Init()
m.Run()
}
func TestDecode(t *testing.T) {
// Bytes read: varies sometimes
// jpeg: 1063892
// png: 294462
// webp: 99529
// gif: 956153
// jpeg#01: 32805
for _, c := range cases {
t.Run("Decode:"+c.format, func(t *testing.T) {
resp, err := http.Get(c.url)
assert.NoError(t, err)
defer resp.Body.Close()
reader := &CountingReader{reader: resp.Body}
img, format, err := image.Decode(reader)
assert.NoError(t, err)
size := img.Bounds().Size()
assert.Equal(t, c.format, format)
assert.Equal(t, c.width, size.X)
assert.Equal(t, c.height, size.Y)
t.Logf("Bytes read: %d", reader.BytesRead)
})
}
// Bytes read:
// jpeg: 4096
// png: 4096
// webp: 4096
// gif: 4096
// jpeg#01: 4096
for _, c := range cases {
t.Run("DecodeConfig:"+c.format, func(t *testing.T) {
resp, err := http.Get(c.url)
assert.NoError(t, err)
defer resp.Body.Close()
reader := &CountingReader{reader: resp.Body}
config, format, err := image.DecodeConfig(reader)
assert.NoError(t, err)
assert.Equal(t, c.format, format)
assert.Equal(t, c.width, config.Width)
assert.Equal(t, c.height, config.Height)
t.Logf("Bytes read: %d", reader.BytesRead)
})
}
}
func TestBase64(t *testing.T) {
// Bytes read:
// jpeg: 1063892
// png: 294462
// webp: 99072
// gif: 953856
// jpeg#01: 32805
for _, c := range cases {
t.Run("Decode:"+c.format, func(t *testing.T) {
resp, err := http.Get(c.url)
assert.NoError(t, err)
defer resp.Body.Close()
data, err := io.ReadAll(resp.Body)
assert.NoError(t, err)
encoded := base64.StdEncoding.EncodeToString(data)
body := base64.NewDecoder(base64.StdEncoding, strings.NewReader(encoded))
reader := &CountingReader{reader: body}
img, format, err := image.Decode(reader)
assert.NoError(t, err)
size := img.Bounds().Size()
assert.Equal(t, c.format, format)
assert.Equal(t, c.width, size.X)
assert.Equal(t, c.height, size.Y)
t.Logf("Bytes read: %d", reader.BytesRead)
})
}
// Bytes read:
// jpeg: 1536
// png: 768
// webp: 768
// gif: 1536
// jpeg#01: 3840
for _, c := range cases {
t.Run("DecodeConfig:"+c.format, func(t *testing.T) {
resp, err := http.Get(c.url)
assert.NoError(t, err)
defer resp.Body.Close()
data, err := io.ReadAll(resp.Body)
assert.NoError(t, err)
encoded := base64.StdEncoding.EncodeToString(data)
body := base64.NewDecoder(base64.StdEncoding, strings.NewReader(encoded))
reader := &CountingReader{reader: body}
config, format, err := image.DecodeConfig(reader)
assert.NoError(t, err)
assert.Equal(t, c.format, format)
assert.Equal(t, c.width, config.Width)
assert.Equal(t, c.height, config.Height)
t.Logf("Bytes read: %d", reader.BytesRead)
})
}
}
func TestGetImageSize(t *testing.T) {
for i, c := range cases {
t.Run("Decode:"+strconv.Itoa(i), func(t *testing.T) {
width, height, err := img.GetImageSize(c.url)
assert.NoError(t, err)
assert.Equal(t, c.width, width)
assert.Equal(t, c.height, height)
})
}
}
func TestGetImageSizeFromBase64(t *testing.T) {
for i, c := range cases {
t.Run("Decode:"+strconv.Itoa(i), func(t *testing.T) {
resp, err := http.Get(c.url)
assert.NoError(t, err)
defer resp.Body.Close()
data, err := io.ReadAll(resp.Body)
assert.NoError(t, err)
encoded := base64.StdEncoding.EncodeToString(data)
width, height, err := img.GetImageSizeFromBase64(encoded)
assert.NoError(t, err)
assert.Equal(t, c.width, width)
assert.Equal(t, c.height, height)
})
}
}

64
common/init.go Normal file
View File

@@ -0,0 +1,64 @@
package common
import (
"flag"
"fmt"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"log"
"os"
"path/filepath"
)
var (
Port = flag.Int("port", 3000, "the listening port")
PrintVersion = flag.Bool("version", false, "print version and exit")
PrintHelp = flag.Bool("help", false, "print help and exit")
LogDir = flag.String("log-dir", "./logs", "specify the log directory")
)
func printHelp() {
fmt.Println("One API " + Version + " - All in one API service for OpenAI API.")
fmt.Println("Copyright (C) 2023 JustSong. All rights reserved.")
fmt.Println("GitHub: https://github.com/songquanpeng/one-api")
fmt.Println("Usage: one-api [--port <port>] [--log-dir <log directory>] [--version] [--help]")
}
func Init() {
flag.Parse()
if *PrintVersion {
fmt.Println(Version)
os.Exit(0)
}
if *PrintHelp {
printHelp()
os.Exit(0)
}
if os.Getenv("SESSION_SECRET") != "" {
if os.Getenv("SESSION_SECRET") == "random_string" {
logger.SysError("SESSION_SECRET is set to an example value, please change it to a random string.")
} else {
config.SessionSecret = os.Getenv("SESSION_SECRET")
}
}
if os.Getenv("SQLITE_PATH") != "" {
SQLitePath = os.Getenv("SQLITE_PATH")
}
if *LogDir != "" {
var err error
*LogDir, err = filepath.Abs(*LogDir)
if err != nil {
log.Fatal(err)
}
if _, err := os.Stat(*LogDir); os.IsNotExist(err) {
err = os.Mkdir(*LogDir, 0777)
if err != nil {
log.Fatal(err)
}
}
logger.LogDir = *LogDir
}
}

View File

@@ -0,0 +1,3 @@
package logger
var LogDir string

160
common/logger/logger.go Normal file
View File

@@ -0,0 +1,160 @@
package logger
import (
"context"
"fmt"
"io"
"log"
"os"
"path/filepath"
"runtime"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
)
type loggerLevel string
const (
loggerDEBUG loggerLevel = "DEBUG"
loggerINFO loggerLevel = "INFO"
loggerWarn loggerLevel = "WARN"
loggerError loggerLevel = "ERROR"
loggerFatal loggerLevel = "FATAL"
)
var setupLogOnce sync.Once
func SetupLogger() {
setupLogOnce.Do(func() {
if LogDir != "" {
var logPath string
if config.OnlyOneLogFile {
logPath = filepath.Join(LogDir, "oneapi.log")
} else {
logPath = filepath.Join(LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102")))
}
fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil {
log.Fatal("failed to open log file")
}
gin.DefaultWriter = io.MultiWriter(os.Stdout, fd)
gin.DefaultErrorWriter = io.MultiWriter(os.Stderr, fd)
}
})
}
func SysLog(s string) {
logHelper(nil, loggerINFO, s)
}
func SysLogf(format string, a ...any) {
logHelper(nil, loggerINFO, fmt.Sprintf(format, a...))
}
func SysWarn(s string) {
logHelper(nil, loggerWarn, s)
}
func SysWarnf(format string, a ...any) {
logHelper(nil, loggerWarn, fmt.Sprintf(format, a...))
}
func SysError(s string) {
logHelper(nil, loggerError, s)
}
func SysErrorf(format string, a ...any) {
logHelper(nil, loggerError, fmt.Sprintf(format, a...))
}
func Debug(ctx context.Context, msg string) {
if !config.DebugEnabled {
return
}
logHelper(ctx, loggerDEBUG, msg)
}
func Info(ctx context.Context, msg string) {
logHelper(ctx, loggerINFO, msg)
}
func Warn(ctx context.Context, msg string) {
logHelper(ctx, loggerWarn, msg)
}
func Error(ctx context.Context, msg string) {
logHelper(ctx, loggerError, msg)
}
func Debugf(ctx context.Context, format string, a ...any) {
if !config.DebugEnabled {
return
}
logHelper(ctx, loggerDEBUG, fmt.Sprintf(format, a...))
}
func Infof(ctx context.Context, format string, a ...any) {
logHelper(ctx, loggerINFO, fmt.Sprintf(format, a...))
}
func Warnf(ctx context.Context, format string, a ...any) {
logHelper(ctx, loggerWarn, fmt.Sprintf(format, a...))
}
func Errorf(ctx context.Context, format string, a ...any) {
logHelper(ctx, loggerError, fmt.Sprintf(format, a...))
}
func FatalLog(s string) {
logHelper(nil, loggerFatal, s)
}
func FatalLogf(format string, a ...any) {
logHelper(nil, loggerFatal, fmt.Sprintf(format, a...))
}
func logHelper(ctx context.Context, level loggerLevel, msg string) {
writer := gin.DefaultErrorWriter
if level == loggerINFO {
writer = gin.DefaultWriter
}
var requestId string
if ctx != nil {
rawRequestId := helper.GetRequestID(ctx)
if rawRequestId != "" {
requestId = fmt.Sprintf(" | %s", rawRequestId)
}
}
lineInfo, funcName := getLineInfo()
now := time.Now()
_, _ = fmt.Fprintf(writer, "[%s] %v%s%s %s%s \n", level, now.Format("2006/01/02 - 15:04:05"), requestId, lineInfo, funcName, msg)
SetupLogger()
if level == loggerFatal {
os.Exit(1)
}
}
func getLineInfo() (string, string) {
funcName := "[unknown] "
pc, file, line, ok := runtime.Caller(3)
if ok {
if fn := runtime.FuncForPC(pc); fn != nil {
parts := strings.Split(fn.Name(), ".")
funcName = "[" + parts[len(parts)-1] + "] "
}
} else {
file = "unknown"
line = 0
}
parts := strings.Split(file, "one-api/")
if len(parts) > 1 {
file = parts[1]
}
return fmt.Sprintf(" | %s:%d", file, line), funcName
}

111
common/message/email.go Normal file
View File

@@ -0,0 +1,111 @@
package message
import (
"crypto/rand"
"crypto/tls"
"encoding/base64"
"fmt"
"net"
"net/smtp"
"strings"
"time"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
)
func shouldAuth() bool {
return config.SMTPAccount != "" || config.SMTPToken != ""
}
func SendEmail(subject string, receiver string, content string) error {
if receiver == "" {
return fmt.Errorf("receiver is empty")
}
if config.SMTPFrom == "" { // for compatibility
config.SMTPFrom = config.SMTPAccount
}
encodedSubject := fmt.Sprintf("=?UTF-8?B?%s?=", base64.StdEncoding.EncodeToString([]byte(subject)))
// Extract domain from SMTPFrom
parts := strings.Split(config.SMTPFrom, "@")
var domain string
if len(parts) > 1 {
domain = parts[1]
}
// Generate a unique Message-ID
buf := make([]byte, 16)
_, err := rand.Read(buf)
if err != nil {
return err
}
messageId := fmt.Sprintf("<%x@%s>", buf, domain)
mail := []byte(fmt.Sprintf("To: %s\r\n"+
"From: %s<%s>\r\n"+
"Subject: %s\r\n"+
"Message-ID: %s\r\n"+ // add Message-ID header to avoid being treated as spam, RFC 5322
"Date: %s\r\n"+
"Content-Type: text/html; charset=UTF-8\r\n\r\n%s\r\n",
receiver, config.SystemName, config.SMTPFrom, encodedSubject, messageId, time.Now().Format(time.RFC1123Z), content))
auth := smtp.PlainAuth("", config.SMTPAccount, config.SMTPToken, config.SMTPServer)
addr := fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort)
to := strings.Split(receiver, ";")
if config.SMTPPort == 465 || !shouldAuth() {
// need advanced client
var conn net.Conn
var err error
if config.SMTPPort == 465 {
tlsConfig := &tls.Config{
InsecureSkipVerify: true,
ServerName: config.SMTPServer,
}
conn, err = tls.Dial("tcp", fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort), tlsConfig)
} else {
conn, err = net.Dial("tcp", fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort))
}
if err != nil {
return err
}
client, err := smtp.NewClient(conn, config.SMTPServer)
if err != nil {
return err
}
defer client.Close()
if shouldAuth() {
if err = client.Auth(auth); err != nil {
return err
}
}
if err = client.Mail(config.SMTPFrom); err != nil {
return err
}
receiverEmails := strings.Split(receiver, ";")
for _, receiver := range receiverEmails {
if err = client.Rcpt(receiver); err != nil {
return err
}
}
w, err := client.Data()
if err != nil {
return err
}
_, err = w.Write(mail)
if err != nil {
return err
}
err = w.Close()
if err != nil {
return err
}
return nil
}
err = smtp.SendMail(addr, auth, config.SMTPAccount, to, mail)
if err != nil && strings.Contains(err.Error(), "short response") { // 部分提供商返回该错误,但实际上邮件已经发送成功
logger.SysWarnf("short response from SMTP server, return nil instead of error: %s", err.Error())
return nil
}
return err
}

22
common/message/main.go Normal file
View File

@@ -0,0 +1,22 @@
package message
import (
"fmt"
"github.com/songquanpeng/one-api/common/config"
)
const (
ByAll = "all"
ByEmail = "email"
ByMessagePusher = "message_pusher"
)
func Notify(by string, title string, description string, content string) error {
if by == ByEmail {
return SendEmail(title, config.RootUserEmail, content)
}
if by == ByMessagePusher {
return SendMessage(title, description, content)
}
return fmt.Errorf("unknown notify method: %s", by)
}

View File

@@ -0,0 +1,53 @@
package message
import (
"bytes"
"encoding/json"
"errors"
"github.com/songquanpeng/one-api/common/config"
"net/http"
)
type request struct {
Title string `json:"title"`
Description string `json:"description"`
Content string `json:"content"`
URL string `json:"url"`
Channel string `json:"channel"`
Token string `json:"token"`
}
type response struct {
Success bool `json:"success"`
Message string `json:"message"`
}
func SendMessage(title string, description string, content string) error {
if config.MessagePusherAddress == "" {
return errors.New("message pusher address is not set")
}
req := request{
Title: title,
Description: description,
Content: content,
Token: config.MessagePusherToken,
}
data, err := json.Marshal(req)
if err != nil {
return err
}
resp, err := http.Post(config.MessagePusherAddress,
"application/json", bytes.NewBuffer(data))
if err != nil {
return err
}
var res response
err = json.NewDecoder(resp.Body).Decode(&res)
if err != nil {
return err
}
if !res.Success {
return errors.New(res.Message)
}
return nil
}

View File

@@ -0,0 +1,34 @@
package message
import (
"fmt"
"github.com/songquanpeng/one-api/common/config"
)
// EmailTemplate 生成美观的 HTML 邮件内容
func EmailTemplate(title, content string) string {
return fmt.Sprintf(`
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
</head>
<body style="margin: 0; padding: 20px; font-family: Arial, sans-serif; line-height: 1.6; background-color: #f4f4f4;">
<div style="max-width: 600px; margin: 20px auto; padding: 30px; background-color: #ffffff; border-radius: 8px; box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);">
<div style="text-align: center; margin-bottom: 30px;">
<h2 style="color: #333; margin: 0; font-size: 24px;">%s</h2>
</div>
<div style="color: #555; font-size: 16px;">
%s
</div>
<div style="margin-top: 40px; padding-top: 20px; border-top: 1px solid #eee; color: #888; font-size: 14px; text-align: center;">
<p style="margin: 5px 0;">此邮件由系统自动发送,请勿直接回复</p>
<p style="margin: 5px 0;">%s</p>
</div>
</div>
</body>
</html>
`, title, content, config.SystemName)
}

52
common/network/ip.go Normal file
View File

@@ -0,0 +1,52 @@
package network
import (
"context"
"fmt"
"github.com/songquanpeng/one-api/common/logger"
"net"
"strings"
)
func splitSubnets(subnets string) []string {
res := strings.Split(subnets, ",")
for i := 0; i < len(res); i++ {
res[i] = strings.TrimSpace(res[i])
}
return res
}
func isValidSubnet(subnet string) error {
_, _, err := net.ParseCIDR(subnet)
if err != nil {
return fmt.Errorf("failed to parse subnet: %w", err)
}
return nil
}
func isIpInSubnet(ctx context.Context, ip string, subnet string) bool {
_, ipNet, err := net.ParseCIDR(subnet)
if err != nil {
logger.Errorf(ctx, "failed to parse subnet: %s", err.Error())
return false
}
return ipNet.Contains(net.ParseIP(ip))
}
func IsValidSubnets(subnets string) error {
for _, subnet := range splitSubnets(subnets) {
if err := isValidSubnet(subnet); err != nil {
return err
}
}
return nil
}
func IsIpInSubnets(ctx context.Context, ip string, subnets string) bool {
for _, subnet := range splitSubnets(subnets) {
if isIpInSubnet(ctx, ip, subnet) {
return true
}
}
return false
}

19
common/network/ip_test.go Normal file
View File

@@ -0,0 +1,19 @@
package network
import (
"context"
"testing"
. "github.com/smartystreets/goconvey/convey"
)
func TestIsIpInSubnet(t *testing.T) {
ctx := context.Background()
ip1 := "192.168.0.5"
ip2 := "125.216.250.89"
subnet := "192.168.0.0/24"
Convey("TestIsIpInSubnet", t, func() {
So(isIpInSubnet(ctx, ip1, subnet), ShouldBeTrue)
So(isIpInSubnet(ctx, ip2, subnet), ShouldBeFalse)
})
}

61
common/random/main.go Normal file
View File

@@ -0,0 +1,61 @@
package random
import (
"github.com/google/uuid"
"math/rand"
"strings"
"time"
)
func GetUUID() string {
code := uuid.New().String()
code = strings.Replace(code, "-", "", -1)
return code
}
const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
const keyNumbers = "0123456789"
func init() {
rand.Seed(time.Now().UnixNano())
}
func GenerateKey() string {
rand.Seed(time.Now().UnixNano())
key := make([]byte, 48)
for i := 0; i < 16; i++ {
key[i] = keyChars[rand.Intn(len(keyChars))]
}
uuid_ := GetUUID()
for i := 0; i < 32; i++ {
c := uuid_[i]
if i%2 == 0 && c >= 'a' && c <= 'z' {
c = c - 'a' + 'A'
}
key[i+16] = c
}
return string(key)
}
func GetRandomString(length int) string {
rand.Seed(time.Now().UnixNano())
key := make([]byte, length)
for i := 0; i < length; i++ {
key[i] = keyChars[rand.Intn(len(keyChars))]
}
return string(key)
}
func GetRandomNumberString(length int) string {
rand.Seed(time.Now().UnixNano())
key := make([]byte, length)
for i := 0; i < length; i++ {
key[i] = keyNumbers[rand.Intn(len(keyNumbers))]
}
return string(key)
}
// RandRange returns a random number between min and max (max is not included)
func RandRange(min, max int) int {
return min + rand.Intn(max-min)
}

70
common/rate-limit.go Normal file
View File

@@ -0,0 +1,70 @@
package common
import (
"sync"
"time"
)
type InMemoryRateLimiter struct {
store map[string]*[]int64
mutex sync.Mutex
expirationDuration time.Duration
}
func (l *InMemoryRateLimiter) Init(expirationDuration time.Duration) {
if l.store == nil {
l.mutex.Lock()
if l.store == nil {
l.store = make(map[string]*[]int64)
l.expirationDuration = expirationDuration
if expirationDuration > 0 {
go l.clearExpiredItems()
}
}
l.mutex.Unlock()
}
}
func (l *InMemoryRateLimiter) clearExpiredItems() {
for {
time.Sleep(l.expirationDuration)
l.mutex.Lock()
now := time.Now().Unix()
for key := range l.store {
queue := l.store[key]
size := len(*queue)
if size == 0 || now-(*queue)[size-1] > int64(l.expirationDuration.Seconds()) {
delete(l.store, key)
}
}
l.mutex.Unlock()
}
}
// Request parameter duration's unit is seconds
func (l *InMemoryRateLimiter) Request(key string, maxRequestNum int, duration int64) bool {
l.mutex.Lock()
defer l.mutex.Unlock()
// [old <-- new]
queue, ok := l.store[key]
now := time.Now().Unix()
if ok {
if len(*queue) < maxRequestNum {
*queue = append(*queue, now)
return true
} else {
if now-(*queue)[0] >= duration {
*queue = (*queue)[1:]
*queue = append(*queue, now)
return true
} else {
return false
}
}
} else {
s := make([]int64, 0, maxRequestNum)
l.store[key] = &s
*(l.store[key]) = append(*(l.store[key]), now)
}
return true
}

81
common/redis.go Normal file
View File

@@ -0,0 +1,81 @@
package common
import (
"context"
"os"
"strings"
"time"
"github.com/go-redis/redis/v8"
"github.com/songquanpeng/one-api/common/logger"
)
var RDB redis.Cmdable
var RedisEnabled = true
// InitRedisClient This function is called after init()
func InitRedisClient() (err error) {
if os.Getenv("REDIS_CONN_STRING") == "" {
RedisEnabled = false
logger.SysLog("REDIS_CONN_STRING not set, Redis is not enabled")
return nil
}
if os.Getenv("SYNC_FREQUENCY") == "" {
RedisEnabled = false
logger.SysLog("SYNC_FREQUENCY not set, Redis is disabled")
return nil
}
redisConnString := os.Getenv("REDIS_CONN_STRING")
if os.Getenv("REDIS_MASTER_NAME") == "" {
logger.SysLog("Redis is enabled")
opt, err := redis.ParseURL(redisConnString)
if err != nil {
logger.FatalLog("failed to parse Redis connection string: " + err.Error())
}
RDB = redis.NewClient(opt)
} else {
// cluster mode
logger.SysLog("Redis cluster mode enabled")
RDB = redis.NewUniversalClient(&redis.UniversalOptions{
Addrs: strings.Split(redisConnString, ","),
Password: os.Getenv("REDIS_PASSWORD"),
MasterName: os.Getenv("REDIS_MASTER_NAME"),
})
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_, err = RDB.Ping(ctx).Result()
if err != nil {
logger.FatalLog("Redis ping test failed: " + err.Error())
}
return err
}
func ParseRedisOption() *redis.Options {
opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING"))
if err != nil {
logger.FatalLog("failed to parse Redis connection string: " + err.Error())
}
return opt
}
func RedisSet(key string, value string, expiration time.Duration) error {
ctx := context.Background()
return RDB.Set(ctx, key, value, expiration).Err()
}
func RedisGet(key string) (string, error) {
ctx := context.Background()
return RDB.Get(ctx, key).Result()
}
func RedisDel(key string) error {
ctx := context.Background()
return RDB.Del(ctx, key).Err()
}
func RedisDecrease(key string, value int64) error {
ctx := context.Background()
return RDB.DecrBy(ctx, key, value).Err()
}

30
common/render/render.go Normal file
View File

@@ -0,0 +1,30 @@
package render
import (
"encoding/json"
"fmt"
"strings"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
)
func StringData(c *gin.Context, str string) {
str = strings.TrimPrefix(str, "data: ")
str = strings.TrimSuffix(str, "\r")
c.Render(-1, common.CustomEvent{Data: "data: " + str})
c.Writer.Flush()
}
func ObjectData(c *gin.Context, object interface{}) error {
jsonData, err := json.Marshal(object)
if err != nil {
return fmt.Errorf("error marshalling object: %w", err)
}
StringData(c, string(jsonData))
return nil
}
func Done(c *gin.Context) {
StringData(c, "[DONE]")
}

14
common/utils.go Normal file
View File

@@ -0,0 +1,14 @@
package common
import (
"fmt"
"github.com/songquanpeng/one-api/common/config"
)
func LogQuota(quota int64) string {
if config.DisplayInCurrencyEnabled {
return fmt.Sprintf("%.6f 额度", float64(quota)/config.QuotaPerUnit)
} else {
return fmt.Sprintf("%d 点额度", quota)
}
}

13
common/utils/array.go Normal file
View File

@@ -0,0 +1,13 @@
package utils
func DeDuplication(slice []string) []string {
m := make(map[string]bool)
for _, v := range slice {
m[v] = true
}
result := make([]string, 0, len(m))
for v := range m {
result = append(result, v)
}
return result
}

9
common/validate.go Normal file
View File

@@ -0,0 +1,9 @@
package common
import "github.com/go-playground/validator/v10"
var Validate *validator.Validate
func init() {
Validate = validator.New()
}

77
common/verification.go Normal file
View File

@@ -0,0 +1,77 @@
package common
import (
"github.com/google/uuid"
"strings"
"sync"
"time"
)
type verificationValue struct {
code string
time time.Time
}
const (
EmailVerificationPurpose = "v"
PasswordResetPurpose = "r"
)
var verificationMutex sync.Mutex
var verificationMap map[string]verificationValue
var verificationMapMaxSize = 10
var VerificationValidMinutes = 10
func GenerateVerificationCode(length int) string {
code := uuid.New().String()
code = strings.Replace(code, "-", "", -1)
if length == 0 {
return code
}
return code[:length]
}
func RegisterVerificationCodeWithKey(key string, code string, purpose string) {
verificationMutex.Lock()
defer verificationMutex.Unlock()
verificationMap[purpose+key] = verificationValue{
code: code,
time: time.Now(),
}
if len(verificationMap) > verificationMapMaxSize {
removeExpiredPairs()
}
}
func VerifyCodeWithKey(key string, code string, purpose string) bool {
verificationMutex.Lock()
defer verificationMutex.Unlock()
value, okay := verificationMap[purpose+key]
now := time.Now()
if !okay || int(now.Sub(value.time).Seconds()) >= VerificationValidMinutes*60 {
return false
}
return code == value.code
}
func DeleteKey(key string, purpose string) {
verificationMutex.Lock()
defer verificationMutex.Unlock()
delete(verificationMap, purpose+key)
}
// no lock inside, so the caller must lock the verificationMap before calling!
func removeExpiredPairs() {
now := time.Now()
for key := range verificationMap {
if int(now.Sub(verificationMap[key].time).Seconds()) >= VerificationValidMinutes*60 {
delete(verificationMap, key)
}
}
}
func init() {
verificationMutex.Lock()
defer verificationMutex.Unlock()
verificationMap = make(map[string]verificationValue)
}