131 lines
4.4 KiB
Go
131 lines
4.4 KiB
Go
package handler
|
|
|
|
import (
|
|
"context"
|
|
"net/http"
|
|
"strings"
|
|
|
|
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
|
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
|
"github.com/gin-gonic/gin"
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
func (h *GatewayHandler) checkContentModeration(c *gin.Context, reqLog *zap.Logger, apiKey *service.APIKey, subject middleware2.AuthSubject, protocol string, model string, body []byte) *service.ContentModerationDecision {
|
|
if h == nil || h.contentModerationService == nil {
|
|
return nil
|
|
}
|
|
return runContentModeration(c, reqLog, h.contentModerationService, apiKey, subject, protocol, model, body)
|
|
}
|
|
|
|
func contentModerationStatus(decision *service.ContentModerationDecision) int {
|
|
if decision == nil || decision.StatusCode < 400 || decision.StatusCode > 599 {
|
|
return http.StatusForbidden
|
|
}
|
|
return decision.StatusCode
|
|
}
|
|
|
|
func contentModerationErrorCode(decision *service.ContentModerationDecision) string {
|
|
return "content_policy_violation"
|
|
}
|
|
|
|
func (h *OpenAIGatewayHandler) checkContentModeration(c *gin.Context, reqLog *zap.Logger, apiKey *service.APIKey, subject middleware2.AuthSubject, protocol string, model string, body []byte) *service.ContentModerationDecision {
|
|
if h == nil || h.contentModerationService == nil {
|
|
return nil
|
|
}
|
|
return runContentModeration(c, reqLog, h.contentModerationService, apiKey, subject, protocol, model, body)
|
|
}
|
|
|
|
func runContentModeration(c *gin.Context, reqLog *zap.Logger, svc *service.ContentModerationService, apiKey *service.APIKey, subject middleware2.AuthSubject, protocol string, model string, body []byte) *service.ContentModerationDecision {
|
|
if svc == nil || c == nil || c.Request == nil {
|
|
return nil
|
|
}
|
|
input := buildContentModerationInput(c, apiKey, subject, protocol, model, body)
|
|
if reqLog != nil {
|
|
reqLog.Info("content_moderation.gateway_check_start",
|
|
zap.String("request_id", input.RequestID),
|
|
zap.Int64("user_id", input.UserID),
|
|
zap.Int64("api_key_id", input.APIKeyID),
|
|
zap.String("api_key_name", input.APIKeyName),
|
|
zap.Int64p("group_id", input.GroupID),
|
|
zap.String("group_name", input.GroupName),
|
|
zap.String("endpoint", input.Endpoint),
|
|
zap.String("provider", input.Provider),
|
|
zap.String("protocol", input.Protocol),
|
|
zap.String("model", input.Model),
|
|
zap.Int("body_bytes", len(body)),
|
|
)
|
|
}
|
|
decision, err := svc.Check(c.Request.Context(), input)
|
|
if err != nil {
|
|
if reqLog != nil {
|
|
reqLog.Warn("content_moderation.check_failed", zap.Error(err))
|
|
}
|
|
return nil
|
|
}
|
|
if reqLog != nil && decision != nil {
|
|
reqLog.Info("content_moderation.gateway_check_done",
|
|
zap.String("request_id", input.RequestID),
|
|
zap.Bool("allowed", decision.Allowed),
|
|
zap.Bool("blocked", decision.Blocked),
|
|
zap.Bool("flagged", decision.Flagged),
|
|
zap.String("action", decision.Action),
|
|
zap.Int("status_code", decision.StatusCode),
|
|
zap.String("highest_category", decision.HighestCategory),
|
|
zap.Float64("highest_score", decision.HighestScore),
|
|
)
|
|
}
|
|
return decision
|
|
}
|
|
|
|
func buildContentModerationInput(c *gin.Context, apiKey *service.APIKey, subject middleware2.AuthSubject, protocol string, model string, body []byte) service.ContentModerationCheckInput {
|
|
input := service.ContentModerationCheckInput{
|
|
RequestID: contentModerationRequestID(c.Request.Context()),
|
|
UserID: subject.UserID,
|
|
Endpoint: GetInboundEndpoint(c),
|
|
Provider: contentModerationProvider(apiKey),
|
|
Model: strings.TrimSpace(model),
|
|
Protocol: protocol,
|
|
Body: body,
|
|
}
|
|
if forcedPlatform, ok := middleware2.GetForcePlatformFromContext(c); ok {
|
|
input.Provider = strings.TrimSpace(forcedPlatform)
|
|
}
|
|
if apiKey != nil {
|
|
input.APIKeyID = apiKey.ID
|
|
input.APIKeyName = apiKey.Name
|
|
if apiKey.User != nil {
|
|
input.UserEmail = apiKey.User.Email
|
|
}
|
|
if apiKey.GroupID != nil {
|
|
groupID := *apiKey.GroupID
|
|
input.GroupID = &groupID
|
|
}
|
|
if apiKey.Group != nil {
|
|
input.GroupName = apiKey.Group.Name
|
|
}
|
|
}
|
|
if input.Endpoint == "" && c.Request != nil && c.Request.URL != nil {
|
|
input.Endpoint = c.Request.URL.Path
|
|
}
|
|
return input
|
|
}
|
|
|
|
func contentModerationProvider(apiKey *service.APIKey) string {
|
|
if apiKey == nil || apiKey.Group == nil {
|
|
return ""
|
|
}
|
|
return strings.TrimSpace(apiKey.Group.Platform)
|
|
}
|
|
|
|
func contentModerationRequestID(ctx context.Context) string {
|
|
if ctx == nil {
|
|
return ""
|
|
}
|
|
if requestID, ok := ctx.Value(ctxkey.RequestID).(string); ok {
|
|
return strings.TrimSpace(requestID)
|
|
}
|
|
return ""
|
|
}
|