138 lines
3.5 KiB
Go
138 lines
3.5 KiB
Go
package repository
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"log"
|
|
"time"
|
|
|
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
|
"github.com/redis/go-redis/v9"
|
|
)
|
|
|
|
// Redis key namespaces used by apiKeyCache.
const (
	// apiKeyRateLimitKeyPrefix prefixes the per-user counter of API-key
	// creation attempts; the user ID is appended in decimal.
	apiKeyRateLimitKeyPrefix = "apikey:ratelimit:"

	// apiKeyAuthCachePrefix prefixes cached API-key auth entries; the raw key
	// string is appended verbatim.
	apiKeyAuthCachePrefix = "apikey:auth:"
)

// apiKeyRateLimitDuration is the TTL applied to the creation-attempt counter
// each time it is incremented.
const apiKeyRateLimitDuration = 24 * time.Hour

// authCacheInvalidateChannel is the Redis pub/sub channel on which auth cache
// invalidation messages are published and received.
const authCacheInvalidateChannel = "auth:cache:invalidate"
|
|
|
|
// apiKeyRateLimitKey generates the Redis key for API key creation rate limiting.
|
|
func apiKeyRateLimitKey(userID int64) string {
|
|
return fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
|
|
}
|
|
|
|
func apiKeyAuthCacheKey(key string) string {
|
|
return fmt.Sprintf("%s%s", apiKeyAuthCachePrefix, key)
|
|
}
|
|
|
|
type apiKeyCache struct {
|
|
rdb *redis.Client
|
|
}
|
|
|
|
func NewAPIKeyCache(rdb *redis.Client) service.APIKeyCache {
|
|
return &apiKeyCache{rdb: rdb}
|
|
}
|
|
|
|
func (c *apiKeyCache) GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) {
|
|
key := apiKeyRateLimitKey(userID)
|
|
count, err := c.rdb.Get(ctx, key).Int()
|
|
if errors.Is(err, redis.Nil) {
|
|
return 0, nil
|
|
}
|
|
return count, err
|
|
}
|
|
|
|
func (c *apiKeyCache) IncrementCreateAttemptCount(ctx context.Context, userID int64) error {
|
|
key := apiKeyRateLimitKey(userID)
|
|
pipe := c.rdb.Pipeline()
|
|
pipe.Incr(ctx, key)
|
|
pipe.Expire(ctx, key, apiKeyRateLimitDuration)
|
|
_, err := pipe.Exec(ctx)
|
|
return err
|
|
}
|
|
|
|
func (c *apiKeyCache) DeleteCreateAttemptCount(ctx context.Context, userID int64) error {
|
|
key := apiKeyRateLimitKey(userID)
|
|
return c.rdb.Del(ctx, key).Err()
|
|
}
|
|
|
|
func (c *apiKeyCache) IncrementDailyUsage(ctx context.Context, apiKey string) error {
|
|
return c.rdb.Incr(ctx, apiKey).Err()
|
|
}
|
|
|
|
func (c *apiKeyCache) SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error {
|
|
return c.rdb.Expire(ctx, apiKey, ttl).Err()
|
|
}
|
|
|
|
func (c *apiKeyCache) GetAuthCache(ctx context.Context, key string) (*service.APIKeyAuthCacheEntry, error) {
|
|
val, err := c.rdb.Get(ctx, apiKeyAuthCacheKey(key)).Bytes()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
var entry service.APIKeyAuthCacheEntry
|
|
if err := json.Unmarshal(val, &entry); err != nil {
|
|
return nil, err
|
|
}
|
|
return &entry, nil
|
|
}
|
|
|
|
func (c *apiKeyCache) SetAuthCache(ctx context.Context, key string, entry *service.APIKeyAuthCacheEntry, ttl time.Duration) error {
|
|
if entry == nil {
|
|
return nil
|
|
}
|
|
payload, err := json.Marshal(entry)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return c.rdb.Set(ctx, apiKeyAuthCacheKey(key), payload, ttl).Err()
|
|
}
|
|
|
|
func (c *apiKeyCache) DeleteAuthCache(ctx context.Context, key string) error {
|
|
return c.rdb.Del(ctx, apiKeyAuthCacheKey(key)).Err()
|
|
}
|
|
|
|
// PublishAuthCacheInvalidation publishes a cache invalidation message to all instances
|
|
func (c *apiKeyCache) PublishAuthCacheInvalidation(ctx context.Context, cacheKey string) error {
|
|
return c.rdb.Publish(ctx, authCacheInvalidateChannel, cacheKey).Err()
|
|
}
|
|
|
|
// SubscribeAuthCacheInvalidation subscribes to cache invalidation messages
|
|
func (c *apiKeyCache) SubscribeAuthCacheInvalidation(ctx context.Context, handler func(cacheKey string)) error {
|
|
pubsub := c.rdb.Subscribe(ctx, authCacheInvalidateChannel)
|
|
|
|
// Verify subscription is working
|
|
_, err := pubsub.Receive(ctx)
|
|
if err != nil {
|
|
_ = pubsub.Close()
|
|
return fmt.Errorf("subscribe to auth cache invalidation: %w", err)
|
|
}
|
|
|
|
go func() {
|
|
defer func() {
|
|
if err := pubsub.Close(); err != nil {
|
|
log.Printf("Warning: failed to close auth cache invalidation pubsub: %v", err)
|
|
}
|
|
}()
|
|
|
|
ch := pubsub.Channel()
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case msg, ok := <-ch:
|
|
if !ok {
|
|
return
|
|
}
|
|
if msg != nil {
|
|
handler(msg.Payload)
|
|
}
|
|
}
|
|
}
|
|
}()
|
|
|
|
return nil
|
|
}
|