diff --git a/service/image.go b/service/image.go index 77a0cc7a..252093f1 100644 --- a/service/image.go +++ b/service/image.go @@ -7,7 +7,9 @@ import ( "fmt" "image" "io" + "net/http" "one-api/common" + "one-api/constant" "strings" "golang.org/x/image/webp" @@ -23,7 +25,7 @@ func DecodeBase64ImageData(base64String string) (image.Config, string, string, e decodedData, err := base64.StdEncoding.DecodeString(base64String) if err != nil { fmt.Println("Error: Failed to decode base64 string") - return image.Config{}, "", "", err + return image.Config{}, "", "", fmt.Errorf("failed to decode base64 string: %s", err.Error()) } // 创建一个bytes.Buffer用于存储解码后的数据 @@ -61,20 +63,51 @@ func DecodeBase64FileData(base64String string) (string, string, error) { func GetImageFromUrl(url string) (mimeType string, data string, err error) { resp, err := DoDownloadRequest(url) if err != nil { - return "", "", err - } - if !strings.HasPrefix(resp.Header.Get("Content-Type"), "image/") { - return "", "", fmt.Errorf("invalid content type: %s, required image/*", resp.Header.Get("Content-Type")) + return "", "", fmt.Errorf("failed to download image: %w", err) } defer resp.Body.Close() - buffer := bytes.NewBuffer(nil) - _, err = buffer.ReadFrom(resp.Body) - if err != nil { - return + + // Check HTTP status code + if resp.StatusCode != http.StatusOK { + return "", "", fmt.Errorf("failed to download image: HTTP %d", resp.StatusCode) } - mimeType = resp.Header.Get("Content-Type") + + contentType := resp.Header.Get("Content-Type") + if contentType != "application/octet-stream" && !strings.HasPrefix(contentType, "image/") { + return "", "", fmt.Errorf("invalid content type: %s, required image/*", contentType) + } + maxImageSize := int64(constant.MaxFileDownloadMB * 1024 * 1024) + + // Check Content-Length if available + if resp.ContentLength > maxImageSize { + return "", "", fmt.Errorf("image size %d exceeds maximum allowed size of %d bytes", resp.ContentLength, maxImageSize) + } + + // Use LimitReader to prevent reading oversized images + limitReader := io.LimitReader(resp.Body, maxImageSize) + buffer := &bytes.Buffer{} + + written, err := io.Copy(buffer, limitReader) + if err != nil { + return "", "", fmt.Errorf("failed to read image data: %w", err) + } + if written >= maxImageSize { + return "", "", fmt.Errorf("image size exceeds maximum allowed size of %d bytes", maxImageSize) + } + data = base64.StdEncoding.EncodeToString(buffer.Bytes()) - return + mimeType = contentType + + // Handle application/octet-stream type + if mimeType == "application/octet-stream" { + _, format, _, err := DecodeBase64ImageData(data) + if err != nil { + return "", "", err + } + mimeType = "image/" + format + } + + return mimeType, data, nil } func DecodeUrlImageData(imageUrl string) (image.Config, string, error) { @@ -92,7 +125,7 @@ func DecodeUrlImageData(imageUrl string) (image.Config, string, error) { mimeType := response.Header.Get("Content-Type") - if !strings.HasPrefix(mimeType, "image/") { + if mimeType != "application/octet-stream" && !strings.HasPrefix(mimeType, "image/") { return image.Config{}, "", fmt.Errorf("invalid content type: %s, required image/*", mimeType) }