mirror of
https://mirror.skon.top/github.com/router-for-me/CLIProxyAPI
synced 2026-04-30 16:20:23 +08:00
feat(api): implement protocol multiplexer and Redis queue for usage integration
- Added `protocol_multiplexer.go`, enabling support for both HTTP and Redis protocols on a single listener. - Introduced `redis_queue_protocol.go` to handle Redis-compatible RESP commands for queue management. - Integrated `redisqueue` package, supporting in-memory queuing with expiration pruning. - Updated server initialization to manage a shared listener and multiplex connections. - Adjusted `Handler` to adopt `AuthenticateManagementKey` for modular key validation, supporting both HTTP and Redis flows.
This commit is contained in:
32
internal/api/buffered_conn.go
Normal file
32
internal/api/buffered_conn.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/tls"
|
||||
"net"
|
||||
)
|
||||
|
||||
type bufferedConn struct {
|
||||
net.Conn
|
||||
reader *bufio.Reader
|
||||
}
|
||||
|
||||
func (c *bufferedConn) Read(p []byte) (int, error) {
|
||||
if c == nil {
|
||||
return 0, net.ErrClosed
|
||||
}
|
||||
if c.reader == nil {
|
||||
return c.Conn.Read(p)
|
||||
}
|
||||
return c.reader.Read(p)
|
||||
}
|
||||
|
||||
func (c *bufferedConn) ConnectionState() tls.ConnectionState {
|
||||
if c == nil || c.Conn == nil {
|
||||
return tls.ConnectionState{}
|
||||
}
|
||||
if stater, ok := c.Conn.(interface{ ConnectionState() tls.ConnectionState }); ok {
|
||||
return stater.ConnectionState()
|
||||
}
|
||||
return tls.ConnectionState{}
|
||||
}
|
||||
@@ -152,9 +152,6 @@ func (h *Handler) SetPostAuthHook(hook coreauth.PostAuthHook) {
|
||||
// All requests (local and remote) require a valid management key.
|
||||
// Additionally, remote access requires allow-remote-management=true.
|
||||
func (h *Handler) Middleware() gin.HandlerFunc {
|
||||
const maxFailures = 5
|
||||
const banDuration = 30 * time.Minute
|
||||
|
||||
return func(c *gin.Context) {
|
||||
c.Header("X-CPA-VERSION", buildinfo.Version)
|
||||
c.Header("X-CPA-COMMIT", buildinfo.Commit)
|
||||
@@ -162,64 +159,6 @@ func (h *Handler) Middleware() gin.HandlerFunc {
|
||||
|
||||
clientIP := c.ClientIP()
|
||||
localClient := clientIP == "127.0.0.1" || clientIP == "::1"
|
||||
cfg := h.cfg
|
||||
var (
|
||||
allowRemote bool
|
||||
secretHash string
|
||||
)
|
||||
if cfg != nil {
|
||||
allowRemote = cfg.RemoteManagement.AllowRemote
|
||||
secretHash = cfg.RemoteManagement.SecretKey
|
||||
}
|
||||
if h.allowRemoteOverride {
|
||||
allowRemote = true
|
||||
}
|
||||
envSecret := h.envSecret
|
||||
|
||||
fail := func() {}
|
||||
if !localClient {
|
||||
h.attemptsMu.Lock()
|
||||
ai := h.failedAttempts[clientIP]
|
||||
if ai != nil {
|
||||
if !ai.blockedUntil.IsZero() {
|
||||
if time.Now().Before(ai.blockedUntil) {
|
||||
remaining := time.Until(ai.blockedUntil).Round(time.Second)
|
||||
h.attemptsMu.Unlock()
|
||||
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": fmt.Sprintf("IP banned due to too many failed attempts. Try again in %s", remaining)})
|
||||
return
|
||||
}
|
||||
// Ban expired, reset state
|
||||
ai.blockedUntil = time.Time{}
|
||||
ai.count = 0
|
||||
}
|
||||
}
|
||||
h.attemptsMu.Unlock()
|
||||
|
||||
if !allowRemote {
|
||||
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "remote management disabled"})
|
||||
return
|
||||
}
|
||||
|
||||
fail = func() {
|
||||
h.attemptsMu.Lock()
|
||||
aip := h.failedAttempts[clientIP]
|
||||
if aip == nil {
|
||||
aip = &attemptInfo{}
|
||||
h.failedAttempts[clientIP] = aip
|
||||
}
|
||||
aip.count++
|
||||
aip.lastActivity = time.Now()
|
||||
if aip.count >= maxFailures {
|
||||
aip.blockedUntil = time.Now().Add(banDuration)
|
||||
aip.count = 0
|
||||
}
|
||||
h.attemptsMu.Unlock()
|
||||
}
|
||||
}
|
||||
if secretHash == "" && envSecret == "" {
|
||||
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "remote management key not set"})
|
||||
return
|
||||
}
|
||||
|
||||
// Accept either Authorization: Bearer <key> or X-Management-Key
|
||||
var provided string
|
||||
@@ -235,44 +174,98 @@ func (h *Handler) Middleware() gin.HandlerFunc {
|
||||
provided = c.GetHeader("X-Management-Key")
|
||||
}
|
||||
|
||||
if provided == "" {
|
||||
if !localClient {
|
||||
fail()
|
||||
}
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "missing management key"})
|
||||
allowed, statusCode, errMsg := h.AuthenticateManagementKey(clientIP, localClient, provided)
|
||||
if !allowed {
|
||||
c.AbortWithStatusJSON(statusCode, gin.H{"error": errMsg})
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
if localClient {
|
||||
if lp := h.localPassword; lp != "" {
|
||||
if subtle.ConstantTimeCompare([]byte(provided), []byte(lp)) == 1 {
|
||||
c.Next()
|
||||
return
|
||||
// AuthenticateManagementKey verifies the provided management key for the given client.
|
||||
// It mirrors the behaviour of Middleware() so non-HTTP callers can reuse the same logic.
|
||||
func (h *Handler) AuthenticateManagementKey(clientIP string, localClient bool, provided string) (bool, int, string) {
|
||||
const maxFailures = 5
|
||||
const banDuration = 30 * time.Minute
|
||||
|
||||
if h == nil {
|
||||
return false, http.StatusForbidden, "remote management disabled"
|
||||
}
|
||||
|
||||
cfg := h.cfg
|
||||
var (
|
||||
allowRemote bool
|
||||
secretHash string
|
||||
)
|
||||
if cfg != nil {
|
||||
allowRemote = cfg.RemoteManagement.AllowRemote
|
||||
secretHash = cfg.RemoteManagement.SecretKey
|
||||
}
|
||||
if h.allowRemoteOverride {
|
||||
allowRemote = true
|
||||
}
|
||||
envSecret := h.envSecret
|
||||
|
||||
fail := func() {}
|
||||
if !localClient {
|
||||
h.attemptsMu.Lock()
|
||||
ai := h.failedAttempts[clientIP]
|
||||
if ai != nil {
|
||||
if !ai.blockedUntil.IsZero() {
|
||||
if time.Now().Before(ai.blockedUntil) {
|
||||
remaining := time.Until(ai.blockedUntil).Round(time.Second)
|
||||
h.attemptsMu.Unlock()
|
||||
return false, http.StatusForbidden, fmt.Sprintf("IP banned due to too many failed attempts. Try again in %s", remaining)
|
||||
}
|
||||
// Ban expired, reset state
|
||||
ai.blockedUntil = time.Time{}
|
||||
ai.count = 0
|
||||
}
|
||||
}
|
||||
h.attemptsMu.Unlock()
|
||||
|
||||
if envSecret != "" && subtle.ConstantTimeCompare([]byte(provided), []byte(envSecret)) == 1 {
|
||||
if !localClient {
|
||||
h.attemptsMu.Lock()
|
||||
if ai := h.failedAttempts[clientIP]; ai != nil {
|
||||
ai.count = 0
|
||||
ai.blockedUntil = time.Time{}
|
||||
}
|
||||
h.attemptsMu.Unlock()
|
||||
}
|
||||
c.Next()
|
||||
return
|
||||
if !allowRemote {
|
||||
return false, http.StatusForbidden, "remote management disabled"
|
||||
}
|
||||
|
||||
if secretHash == "" || bcrypt.CompareHashAndPassword([]byte(secretHash), []byte(provided)) != nil {
|
||||
if !localClient {
|
||||
fail()
|
||||
fail = func() {
|
||||
h.attemptsMu.Lock()
|
||||
aip := h.failedAttempts[clientIP]
|
||||
if aip == nil {
|
||||
aip = &attemptInfo{}
|
||||
h.failedAttempts[clientIP] = aip
|
||||
}
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "invalid management key"})
|
||||
return
|
||||
aip.count++
|
||||
aip.lastActivity = time.Now()
|
||||
if aip.count >= maxFailures {
|
||||
aip.blockedUntil = time.Now().Add(banDuration)
|
||||
aip.count = 0
|
||||
}
|
||||
h.attemptsMu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
if secretHash == "" && envSecret == "" {
|
||||
return false, http.StatusForbidden, "remote management key not set"
|
||||
}
|
||||
|
||||
if provided == "" {
|
||||
if !localClient {
|
||||
fail()
|
||||
}
|
||||
return false, http.StatusUnauthorized, "missing management key"
|
||||
}
|
||||
|
||||
if localClient {
|
||||
if lp := h.localPassword; lp != "" {
|
||||
if subtle.ConstantTimeCompare([]byte(provided), []byte(lp)) == 1 {
|
||||
return true, 0, ""
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if envSecret != "" && subtle.ConstantTimeCompare([]byte(provided), []byte(envSecret)) == 1 {
|
||||
if !localClient {
|
||||
h.attemptsMu.Lock()
|
||||
if ai := h.failedAttempts[clientIP]; ai != nil {
|
||||
@@ -281,9 +274,26 @@ func (h *Handler) Middleware() gin.HandlerFunc {
|
||||
}
|
||||
h.attemptsMu.Unlock()
|
||||
}
|
||||
|
||||
c.Next()
|
||||
return true, 0, ""
|
||||
}
|
||||
|
||||
if secretHash == "" || bcrypt.CompareHashAndPassword([]byte(secretHash), []byte(provided)) != nil {
|
||||
if !localClient {
|
||||
fail()
|
||||
}
|
||||
return false, http.StatusUnauthorized, "invalid management key"
|
||||
}
|
||||
|
||||
if !localClient {
|
||||
h.attemptsMu.Lock()
|
||||
if ai := h.failedAttempts[clientIP]; ai != nil {
|
||||
ai.count = 0
|
||||
ai.blockedUntil = time.Time{}
|
||||
}
|
||||
h.attemptsMu.Unlock()
|
||||
}
|
||||
|
||||
return true, 0, ""
|
||||
}
|
||||
|
||||
// persist saves the current in-memory config to disk.
|
||||
|
||||
68
internal/api/mux_listener.go
Normal file
68
internal/api/mux_listener.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type muxListener struct {
|
||||
addr net.Addr
|
||||
connCh chan net.Conn
|
||||
closeCh chan struct{}
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
func newMuxListener(addr net.Addr, buffer int) *muxListener {
|
||||
if buffer <= 0 {
|
||||
buffer = 1
|
||||
}
|
||||
return &muxListener{
|
||||
addr: addr,
|
||||
connCh: make(chan net.Conn, buffer),
|
||||
closeCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (l *muxListener) Put(conn net.Conn) error {
|
||||
if conn == nil {
|
||||
return nil
|
||||
}
|
||||
select {
|
||||
case <-l.closeCh:
|
||||
return net.ErrClosed
|
||||
case l.connCh <- conn:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (l *muxListener) Accept() (net.Conn, error) {
|
||||
select {
|
||||
case <-l.closeCh:
|
||||
return nil, net.ErrClosed
|
||||
case conn := <-l.connCh:
|
||||
if conn == nil {
|
||||
return nil, net.ErrClosed
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (l *muxListener) Close() error {
|
||||
if l == nil {
|
||||
return nil
|
||||
}
|
||||
l.once.Do(func() {
|
||||
close(l.closeCh)
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *muxListener) Addr() net.Addr {
|
||||
if l == nil {
|
||||
return &net.TCPAddr{}
|
||||
}
|
||||
if l.addr == nil {
|
||||
return &net.TCPAddr{}
|
||||
}
|
||||
return l.addr
|
||||
}
|
||||
109
internal/api/protocol_multiplexer.go
Normal file
109
internal/api/protocol_multiplexer.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func normalizeHTTPServeError(err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
return nil
|
||||
}
|
||||
if errors.Is(err, http.ErrServerClosed) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func normalizeListenerError(err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Server) acceptMuxConnections(listener net.Listener, httpListener *muxListener) error {
|
||||
if s == nil || listener == nil {
|
||||
return net.ErrClosed
|
||||
}
|
||||
|
||||
for {
|
||||
conn, errAccept := listener.Accept()
|
||||
if errAccept != nil {
|
||||
return errAccept
|
||||
}
|
||||
if conn == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
tlsConn, ok := conn.(*tls.Conn)
|
||||
if ok {
|
||||
if errHandshake := tlsConn.Handshake(); errHandshake != nil {
|
||||
if errClose := conn.Close(); errClose != nil {
|
||||
log.Errorf("failed to close connection after TLS handshake error: %v", errClose)
|
||||
}
|
||||
continue
|
||||
}
|
||||
proto := strings.TrimSpace(tlsConn.ConnectionState().NegotiatedProtocol)
|
||||
if proto == "h2" || proto == "http/1.1" {
|
||||
if httpListener == nil {
|
||||
if errClose := conn.Close(); errClose != nil {
|
||||
log.Errorf("failed to close connection: %v", errClose)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if errPut := httpListener.Put(tlsConn); errPut != nil {
|
||||
if errClose := conn.Close(); errClose != nil {
|
||||
log.Errorf("failed to close connection after HTTP routing failure: %v", errClose)
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
reader := bufio.NewReader(conn)
|
||||
prefix, errPeek := reader.Peek(1)
|
||||
if errPeek != nil {
|
||||
if errClose := conn.Close(); errClose != nil {
|
||||
log.Errorf("failed to close connection after protocol peek failure: %v", errClose)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if isRedisRESPPrefix(prefix[0]) {
|
||||
if !s.managementRoutesEnabled.Load() {
|
||||
if errClose := conn.Close(); errClose != nil {
|
||||
log.Errorf("failed to close redis connection while management is disabled: %v", errClose)
|
||||
}
|
||||
continue
|
||||
}
|
||||
go s.handleRedisConnection(conn, reader)
|
||||
continue
|
||||
}
|
||||
|
||||
if httpListener == nil {
|
||||
if errClose := conn.Close(); errClose != nil {
|
||||
log.Errorf("failed to close connection without HTTP listener: %v", errClose)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if errPut := httpListener.Put(&bufferedConn{Conn: conn, reader: reader}); errPut != nil {
|
||||
if errClose := conn.Close(); errClose != nil {
|
||||
log.Errorf("failed to close connection after HTTP routing failure: %v", errClose)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
317
internal/api/redis_queue_protocol.go
Normal file
317
internal/api/redis_queue_protocol.go
Normal file
@@ -0,0 +1,317 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/redisqueue"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func isRedisRESPPrefix(prefix byte) bool {
|
||||
switch prefix {
|
||||
case '*', '$', '+', '-', ':':
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleRedisConnection(conn net.Conn, reader *bufio.Reader) {
|
||||
if s == nil || conn == nil || reader == nil {
|
||||
return
|
||||
}
|
||||
|
||||
clientIP, localClient := resolveRemoteIP(conn.RemoteAddr())
|
||||
authed := false
|
||||
writer := bufio.NewWriter(conn)
|
||||
defer func() {
|
||||
if errClose := conn.Close(); errClose != nil {
|
||||
log.Errorf("redis connection close error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
flush := func() bool {
|
||||
if errFlush := writer.Flush(); errFlush != nil {
|
||||
log.Errorf("redis protocol flush error: %v", errFlush)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
for {
|
||||
if !s.managementRoutesEnabled.Load() {
|
||||
return
|
||||
}
|
||||
|
||||
args, err := readRESPArray(reader)
|
||||
if err != nil {
|
||||
if !errors.Is(err, io.EOF) {
|
||||
_ = writeRedisError(writer, "ERR "+err.Error())
|
||||
_ = writer.Flush()
|
||||
}
|
||||
return
|
||||
}
|
||||
if len(args) == 0 {
|
||||
_ = writeRedisError(writer, "ERR empty command")
|
||||
if !flush() {
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
cmd := strings.ToUpper(strings.TrimSpace(args[0]))
|
||||
switch cmd {
|
||||
case "AUTH":
|
||||
password, ok := parseAuthPassword(args)
|
||||
if !ok {
|
||||
_ = writeRedisError(writer, "ERR wrong number of arguments for 'auth' command")
|
||||
if !flush() {
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
if s.mgmt == nil {
|
||||
_ = writeRedisError(writer, "ERR remote management disabled")
|
||||
if !flush() {
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
allowed, _, errMsg := s.mgmt.AuthenticateManagementKey(clientIP, localClient, password)
|
||||
if !allowed {
|
||||
_ = writeRedisError(writer, "ERR "+errMsg)
|
||||
if !flush() {
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
authed = true
|
||||
_ = writeRedisSimpleString(writer, "OK")
|
||||
if !flush() {
|
||||
return
|
||||
}
|
||||
case "LPOP", "RPOP":
|
||||
if !authed {
|
||||
_ = writeRedisError(writer, "NOAUTH Authentication required.")
|
||||
if !flush() {
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
count, hasCount, ok := parsePopCount(args)
|
||||
if !ok {
|
||||
_ = writeRedisError(writer, "ERR wrong number of arguments for '"+strings.ToLower(cmd)+"' command")
|
||||
if !flush() {
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
if count <= 0 {
|
||||
_ = writeRedisError(writer, "ERR value is not an integer or out of range")
|
||||
if !flush() {
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
items := redisqueue.PopOldest(count)
|
||||
if hasCount {
|
||||
_ = writeRedisArrayOfBulkStrings(writer, items)
|
||||
if !flush() {
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
if len(items) == 0 {
|
||||
_ = writeRedisNilBulkString(writer)
|
||||
if !flush() {
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
_ = writeRedisBulkString(writer, items[0])
|
||||
if !flush() {
|
||||
return
|
||||
}
|
||||
default:
|
||||
_ = writeRedisError(writer, fmt.Sprintf("ERR unknown command '%s'", strings.ToLower(cmd)))
|
||||
if !flush() {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func resolveRemoteIP(addr net.Addr) (ip string, localClient bool) {
|
||||
if addr == nil {
|
||||
return "", false
|
||||
}
|
||||
host := addr.String()
|
||||
if h, _, err := net.SplitHostPort(host); err == nil {
|
||||
host = h
|
||||
}
|
||||
host = strings.TrimSpace(host)
|
||||
localClient = host == "127.0.0.1" || host == "::1"
|
||||
return host, localClient
|
||||
}
|
||||
|
||||
func parseAuthPassword(args []string) (string, bool) {
|
||||
switch len(args) {
|
||||
case 2:
|
||||
return args[1], true
|
||||
case 3:
|
||||
// Support AUTH <username> <password> by ignoring username for compatibility.
|
||||
return args[2], true
|
||||
default:
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
|
||||
func parsePopCount(args []string) (count int, hasCount bool, ok bool) {
|
||||
if len(args) != 2 && len(args) != 3 {
|
||||
return 0, false, false
|
||||
}
|
||||
if len(args) == 2 {
|
||||
return 1, false, true
|
||||
}
|
||||
parsed, err := strconv.Atoi(strings.TrimSpace(args[2]))
|
||||
if err != nil {
|
||||
return 0, true, true
|
||||
}
|
||||
return parsed, true, true
|
||||
}
|
||||
|
||||
func readRESPArray(reader *bufio.Reader) ([]string, error) {
|
||||
prefix, err := reader.ReadByte()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if prefix != '*' {
|
||||
return nil, fmt.Errorf("protocol error")
|
||||
}
|
||||
line, err := readRESPLine(reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
count, err := strconv.Atoi(line)
|
||||
if err != nil || count < 0 {
|
||||
return nil, fmt.Errorf("protocol error")
|
||||
}
|
||||
args := make([]string, 0, count)
|
||||
for i := 0; i < count; i++ {
|
||||
value, err := readRESPString(reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
args = append(args, value)
|
||||
}
|
||||
return args, nil
|
||||
}
|
||||
|
||||
func readRESPString(reader *bufio.Reader) (string, error) {
|
||||
prefix, err := reader.ReadByte()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
switch prefix {
|
||||
case '$':
|
||||
return readRESPBulkString(reader)
|
||||
case '+', ':':
|
||||
return readRESPLine(reader)
|
||||
default:
|
||||
return "", fmt.Errorf("protocol error")
|
||||
}
|
||||
}
|
||||
|
||||
func readRESPBulkString(reader *bufio.Reader) (string, error) {
|
||||
line, err := readRESPLine(reader)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
length, err := strconv.Atoi(line)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("protocol error")
|
||||
}
|
||||
if length < 0 {
|
||||
return "", nil
|
||||
}
|
||||
buf := make([]byte, length+2)
|
||||
if _, err := io.ReadFull(reader, buf); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if length+2 < 2 || buf[length] != '\r' || buf[length+1] != '\n' {
|
||||
return "", fmt.Errorf("protocol error")
|
||||
}
|
||||
return string(buf[:length]), nil
|
||||
}
|
||||
|
||||
func readRESPLine(reader *bufio.Reader) (string, error) {
|
||||
line, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
line = strings.TrimSuffix(line, "\n")
|
||||
line = strings.TrimSuffix(line, "\r")
|
||||
return line, nil
|
||||
}
|
||||
|
||||
func writeRedisSimpleString(writer *bufio.Writer, value string) error {
|
||||
if writer == nil {
|
||||
return net.ErrClosed
|
||||
}
|
||||
_, err := writer.WriteString("+" + value + "\r\n")
|
||||
return err
|
||||
}
|
||||
|
||||
func writeRedisError(writer *bufio.Writer, message string) error {
|
||||
if writer == nil {
|
||||
return net.ErrClosed
|
||||
}
|
||||
_, err := writer.WriteString("-" + message + "\r\n")
|
||||
return err
|
||||
}
|
||||
|
||||
func writeRedisNilBulkString(writer *bufio.Writer) error {
|
||||
if writer == nil {
|
||||
return net.ErrClosed
|
||||
}
|
||||
_, err := writer.WriteString("$-1\r\n")
|
||||
return err
|
||||
}
|
||||
|
||||
func writeRedisBulkString(writer *bufio.Writer, payload []byte) error {
|
||||
if writer == nil {
|
||||
return net.ErrClosed
|
||||
}
|
||||
if payload == nil {
|
||||
return writeRedisNilBulkString(writer)
|
||||
}
|
||||
if _, err := writer.WriteString("$" + strconv.Itoa(len(payload)) + "\r\n"); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := writer.Write(payload); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := writer.WriteString("\r\n")
|
||||
return err
|
||||
}
|
||||
|
||||
func writeRedisArrayOfBulkStrings(writer *bufio.Writer, items [][]byte) error {
|
||||
if writer == nil {
|
||||
return net.ErrClosed
|
||||
}
|
||||
if _, err := writer.WriteString("*" + strconv.Itoa(len(items)) + "\r\n"); err != nil {
|
||||
return err
|
||||
}
|
||||
for i := range items {
|
||||
if err := writeRedisBulkString(writer, items[i]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
304
internal/api/redis_queue_protocol_integration_test.go
Normal file
304
internal/api/redis_queue_protocol_integration_test.go
Normal file
@@ -0,0 +1,304 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/redisqueue"
|
||||
)
|
||||
|
||||
func startRedisMuxListener(t *testing.T, server *Server) (addr string, stop func()) {
|
||||
t.Helper()
|
||||
|
||||
listener, errListen := net.Listen("tcp", "127.0.0.1:0")
|
||||
if errListen != nil {
|
||||
t.Fatalf("failed to listen: %v", errListen)
|
||||
}
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- server.acceptMuxConnections(listener, nil)
|
||||
}()
|
||||
|
||||
stop = func() {
|
||||
_ = listener.Close()
|
||||
select {
|
||||
case err := <-errCh:
|
||||
if err != nil && !errors.Is(err, net.ErrClosed) {
|
||||
t.Errorf("accept loop returned unexpected error: %v", err)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Errorf("timeout waiting for accept loop to exit")
|
||||
}
|
||||
}
|
||||
|
||||
return listener.Addr().String(), stop
|
||||
}
|
||||
|
||||
func writeTestRESPCommand(conn net.Conn, args ...string) error {
|
||||
if conn == nil {
|
||||
return net.ErrClosed
|
||||
}
|
||||
if len(args) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
fmt.Fprintf(&buf, "*%d\r\n", len(args))
|
||||
for _, arg := range args {
|
||||
fmt.Fprintf(&buf, "$%d\r\n%s\r\n", len(arg), arg)
|
||||
}
|
||||
_, err := conn.Write(buf.Bytes())
|
||||
return err
|
||||
}
|
||||
|
||||
func readTestRESPLine(r *bufio.Reader) (string, error) {
|
||||
line, err := r.ReadString('\n')
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if !strings.HasSuffix(line, "\r\n") {
|
||||
return "", fmt.Errorf("invalid RESP line terminator: %q", line)
|
||||
}
|
||||
return strings.TrimSuffix(line, "\r\n"), nil
|
||||
}
|
||||
|
||||
func readTestRESPSimpleString(r *bufio.Reader) (string, error) {
|
||||
prefix, err := r.ReadByte()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if prefix != '+' {
|
||||
return "", fmt.Errorf("expected simple string prefix '+', got %q", prefix)
|
||||
}
|
||||
return readTestRESPLine(r)
|
||||
}
|
||||
|
||||
func readTestRESPError(r *bufio.Reader) (string, error) {
|
||||
prefix, err := r.ReadByte()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if prefix != '-' {
|
||||
return "", fmt.Errorf("expected error prefix '-', got %q", prefix)
|
||||
}
|
||||
return readTestRESPLine(r)
|
||||
}
|
||||
|
||||
func readTestRESPBulkString(r *bufio.Reader) ([]byte, error) {
|
||||
prefix, err := r.ReadByte()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if prefix != '$' {
|
||||
return nil, fmt.Errorf("expected bulk string prefix '$', got %q", prefix)
|
||||
}
|
||||
|
||||
line, err := readTestRESPLine(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
length, err := strconv.Atoi(line)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid bulk string length %q: %v", line, err)
|
||||
}
|
||||
if length == -1 {
|
||||
return nil, nil
|
||||
}
|
||||
if length < -1 {
|
||||
return nil, fmt.Errorf("invalid bulk string length %d", length)
|
||||
}
|
||||
|
||||
payload := make([]byte, length+2)
|
||||
if _, err := io.ReadFull(r, payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if payload[length] != '\r' || payload[length+1] != '\n' {
|
||||
return nil, fmt.Errorf("invalid bulk string terminator")
|
||||
}
|
||||
return payload[:length], nil
|
||||
}
|
||||
|
||||
func readRESPArrayOfBulkStrings(r *bufio.Reader) ([][]byte, error) {
|
||||
prefix, err := r.ReadByte()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if prefix != '*' {
|
||||
return nil, fmt.Errorf("expected array prefix '*', got %q", prefix)
|
||||
}
|
||||
|
||||
line, err := readTestRESPLine(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
count, err := strconv.Atoi(line)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid array length %q: %v", line, err)
|
||||
}
|
||||
if count < 0 {
|
||||
return nil, fmt.Errorf("invalid array length %d", count)
|
||||
}
|
||||
|
||||
out := make([][]byte, 0, count)
|
||||
for i := 0; i < count; i++ {
|
||||
item, err := readTestRESPBulkString(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out = append(out, item)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func TestRedisProtocol_ManagementDisabled_RejectsConnection(t *testing.T) {
|
||||
t.Setenv("MANAGEMENT_PASSWORD", "")
|
||||
redisqueue.SetEnabled(false)
|
||||
|
||||
server := newTestServer(t)
|
||||
if server.managementRoutesEnabled.Load() {
|
||||
t.Fatalf("expected managementRoutesEnabled to be false")
|
||||
}
|
||||
|
||||
addr, stop := startRedisMuxListener(t, server)
|
||||
t.Cleanup(stop)
|
||||
|
||||
conn, errDial := net.DialTimeout("tcp", addr, time.Second)
|
||||
if errDial != nil {
|
||||
t.Fatalf("failed to dial redis listener: %v", errDial)
|
||||
}
|
||||
t.Cleanup(func() { _ = conn.Close() })
|
||||
|
||||
_ = conn.SetDeadline(time.Now().Add(2 * time.Second))
|
||||
if errWrite := writeTestRESPCommand(conn, "PING"); errWrite != nil {
|
||||
t.Fatalf("failed to write RESP command: %v", errWrite)
|
||||
}
|
||||
|
||||
buf := make([]byte, 1)
|
||||
_, errRead := conn.Read(buf)
|
||||
if errRead == nil {
|
||||
t.Fatalf("expected connection to be closed when management is disabled")
|
||||
}
|
||||
if ne, ok := errRead.(net.Error); ok && ne.Timeout() {
|
||||
t.Fatalf("expected connection to be closed when management is disabled, got timeout: %v", errRead)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedisProtocol_AUTH_And_PopContracts(t *testing.T) {
|
||||
const managementPassword = "test-management-password"
|
||||
|
||||
t.Setenv("MANAGEMENT_PASSWORD", managementPassword)
|
||||
redisqueue.SetEnabled(false)
|
||||
t.Cleanup(func() { redisqueue.SetEnabled(false) })
|
||||
|
||||
server := newTestServer(t)
|
||||
if !server.managementRoutesEnabled.Load() {
|
||||
t.Fatalf("expected managementRoutesEnabled to be true")
|
||||
}
|
||||
|
||||
addr, stop := startRedisMuxListener(t, server)
|
||||
t.Cleanup(stop)
|
||||
|
||||
conn, errDial := net.DialTimeout("tcp", addr, time.Second)
|
||||
if errDial != nil {
|
||||
t.Fatalf("failed to dial redis listener: %v", errDial)
|
||||
}
|
||||
t.Cleanup(func() { _ = conn.Close() })
|
||||
|
||||
reader := bufio.NewReader(conn)
|
||||
|
||||
_ = conn.SetDeadline(time.Now().Add(5 * time.Second))
|
||||
|
||||
if errWrite := writeTestRESPCommand(conn, "AUTH", "test-key"); errWrite != nil {
|
||||
t.Fatalf("failed to write AUTH command: %v", errWrite)
|
||||
}
|
||||
if msg, err := readTestRESPError(reader); err != nil {
|
||||
t.Fatalf("failed to read AUTH error: %v", err)
|
||||
} else if msg != "ERR invalid management key" {
|
||||
t.Fatalf("unexpected AUTH error: %q", msg)
|
||||
}
|
||||
|
||||
if errWrite := writeTestRESPCommand(conn, "LPOP", "queue"); errWrite != nil {
|
||||
t.Fatalf("failed to write LPOP command: %v", errWrite)
|
||||
}
|
||||
if msg, err := readTestRESPError(reader); err != nil {
|
||||
t.Fatalf("failed to read LPOP NOAUTH error: %v", err)
|
||||
} else if msg != "NOAUTH Authentication required." {
|
||||
t.Fatalf("unexpected LPOP NOAUTH error: %q", msg)
|
||||
}
|
||||
|
||||
if errWrite := writeTestRESPCommand(conn, "AUTH", managementPassword); errWrite != nil {
|
||||
t.Fatalf("failed to write AUTH command: %v", errWrite)
|
||||
}
|
||||
if msg, err := readTestRESPSimpleString(reader); err != nil {
|
||||
t.Fatalf("failed to read AUTH response: %v", err)
|
||||
} else if msg != "OK" {
|
||||
t.Fatalf("unexpected AUTH response: %q", msg)
|
||||
}
|
||||
|
||||
if !redisqueue.Enabled() {
|
||||
t.Fatalf("expected redisqueue to be enabled")
|
||||
}
|
||||
redisqueue.Enqueue([]byte("a"))
|
||||
redisqueue.Enqueue([]byte("b"))
|
||||
redisqueue.Enqueue([]byte("c"))
|
||||
|
||||
if errWrite := writeTestRESPCommand(conn, "RPOP", "queue"); errWrite != nil {
|
||||
t.Fatalf("failed to write RPOP command: %v", errWrite)
|
||||
}
|
||||
if item, err := readTestRESPBulkString(reader); err != nil {
|
||||
t.Fatalf("failed to read RPOP response: %v", err)
|
||||
} else if string(item) != "a" {
|
||||
t.Fatalf("unexpected RPOP item: %q", string(item))
|
||||
}
|
||||
|
||||
if errWrite := writeTestRESPCommand(conn, "LPOP", "queue"); errWrite != nil {
|
||||
t.Fatalf("failed to write LPOP command: %v", errWrite)
|
||||
}
|
||||
if item, err := readTestRESPBulkString(reader); err != nil {
|
||||
t.Fatalf("failed to read LPOP response: %v", err)
|
||||
} else if string(item) != "b" {
|
||||
t.Fatalf("unexpected LPOP item: %q", string(item))
|
||||
}
|
||||
|
||||
if errWrite := writeTestRESPCommand(conn, "RPOP", "queue", "10"); errWrite != nil {
|
||||
t.Fatalf("failed to write RPOP count command: %v", errWrite)
|
||||
}
|
||||
items, errItems := readRESPArrayOfBulkStrings(reader)
|
||||
if errItems != nil {
|
||||
t.Fatalf("failed to read RPOP count response: %v", errItems)
|
||||
}
|
||||
if len(items) != 1 || string(items[0]) != "c" {
|
||||
t.Fatalf("unexpected RPOP count items: %#v", items)
|
||||
}
|
||||
|
||||
if errWrite := writeTestRESPCommand(conn, "LPOP", "queue"); errWrite != nil {
|
||||
t.Fatalf("failed to write LPOP empty command: %v", errWrite)
|
||||
}
|
||||
item, errItem := readTestRESPBulkString(reader)
|
||||
if errItem != nil {
|
||||
t.Fatalf("failed to read LPOP empty response: %v", errItem)
|
||||
}
|
||||
if item != nil {
|
||||
t.Fatalf("expected nil bulk string for empty queue, got %q", string(item))
|
||||
}
|
||||
|
||||
if errWrite := writeTestRESPCommand(conn, "RPOP", "queue", "2"); errWrite != nil {
|
||||
t.Fatalf("failed to write RPOP empty count command: %v", errWrite)
|
||||
}
|
||||
emptyItems, errEmpty := readRESPArrayOfBulkStrings(reader)
|
||||
if errEmpty != nil {
|
||||
t.Fatalf("failed to read RPOP empty count response: %v", errEmpty)
|
||||
}
|
||||
if len(emptyItems) != 0 {
|
||||
t.Fatalf("expected empty array for empty queue with count, got %#v", emptyItems)
|
||||
}
|
||||
}
|
||||
@@ -7,8 +7,10 @@ package api
|
||||
import (
|
||||
"context"
|
||||
"crypto/subtle"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -28,6 +30,7 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/managementasset"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/redisqueue"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
|
||||
@@ -38,6 +41,7 @@ import (
|
||||
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/net/http2"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
@@ -127,6 +131,12 @@ type Server struct {
|
||||
// server is the underlying HTTP server.
|
||||
server *http.Server
|
||||
|
||||
// muxBaseListener is the shared TCP listener used to serve both HTTP and Redis protocol traffic.
|
||||
muxBaseListener net.Listener
|
||||
|
||||
// muxHTTPListener receives HTTP connections selected by the multiplexer.
|
||||
muxHTTPListener *muxListener
|
||||
|
||||
// handlers contains the API handlers for processing requests.
|
||||
handlers *handlers.BaseAPIHandler
|
||||
|
||||
@@ -299,6 +309,7 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
|
||||
// or when a local management password is provided (e.g. TUI mode).
|
||||
hasManagementSecret := cfg.RemoteManagement.SecretKey != "" || envManagementSecret || s.localPassword != ""
|
||||
s.managementRoutesEnabled.Store(hasManagementSecret)
|
||||
redisqueue.SetEnabled(hasManagementSecret)
|
||||
if hasManagementSecret {
|
||||
s.registerManagementRoutes()
|
||||
}
|
||||
@@ -797,26 +808,98 @@ func (s *Server) Start() error {
|
||||
return fmt.Errorf("failed to start HTTP server: server not initialized")
|
||||
}
|
||||
|
||||
addr := s.server.Addr
|
||||
listener, errListen := net.Listen("tcp", addr)
|
||||
if errListen != nil {
|
||||
return fmt.Errorf("failed to start HTTP server: %v", errListen)
|
||||
}
|
||||
|
||||
useTLS := s.cfg != nil && s.cfg.TLS.Enable
|
||||
if useTLS {
|
||||
cert := strings.TrimSpace(s.cfg.TLS.Cert)
|
||||
key := strings.TrimSpace(s.cfg.TLS.Key)
|
||||
if cert == "" || key == "" {
|
||||
certPath := strings.TrimSpace(s.cfg.TLS.Cert)
|
||||
keyPath := strings.TrimSpace(s.cfg.TLS.Key)
|
||||
if certPath == "" || keyPath == "" {
|
||||
if errClose := listener.Close(); errClose != nil {
|
||||
log.Errorf("failed to close listener after TLS validation failure: %v", errClose)
|
||||
}
|
||||
return fmt.Errorf("failed to start HTTPS server: tls.cert or tls.key is empty")
|
||||
}
|
||||
log.Debugf("Starting API server on %s with TLS", s.server.Addr)
|
||||
if errServeTLS := s.server.ListenAndServeTLS(cert, key); errServeTLS != nil && !errors.Is(errServeTLS, http.ErrServerClosed) {
|
||||
return fmt.Errorf("failed to start HTTPS server: %v", errServeTLS)
|
||||
certPair, errLoad := tls.LoadX509KeyPair(certPath, keyPath)
|
||||
if errLoad != nil {
|
||||
if errClose := listener.Close(); errClose != nil {
|
||||
log.Errorf("failed to close listener after TLS key pair load failure: %v", errClose)
|
||||
}
|
||||
return fmt.Errorf("failed to start HTTPS server: %v", errLoad)
|
||||
}
|
||||
|
||||
tlsConfig := &tls.Config{
|
||||
Certificates: []tls.Certificate{certPair},
|
||||
NextProtos: []string{"h2", "http/1.1"},
|
||||
}
|
||||
s.server.TLSConfig = tlsConfig
|
||||
if errHTTP2 := http2.ConfigureServer(s.server, &http2.Server{}); errHTTP2 != nil {
|
||||
log.Warnf("failed to configure HTTP/2: %v", errHTTP2)
|
||||
}
|
||||
listener = tls.NewListener(listener, tlsConfig)
|
||||
log.Debugf("Starting API server on %s with TLS", addr)
|
||||
} else {
|
||||
log.Debugf("Starting API server on %s", addr)
|
||||
}
|
||||
|
||||
httpListener := newMuxListener(listener.Addr(), 1024)
|
||||
s.muxBaseListener = listener
|
||||
s.muxHTTPListener = httpListener
|
||||
|
||||
httpErrCh := make(chan error, 1)
|
||||
acceptErrCh := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
httpErrCh <- s.server.Serve(httpListener)
|
||||
}()
|
||||
go func() {
|
||||
acceptErrCh <- s.acceptMuxConnections(listener, httpListener)
|
||||
}()
|
||||
|
||||
select {
|
||||
case errServe := <-httpErrCh:
|
||||
if s.muxBaseListener != nil {
|
||||
if errClose := s.muxBaseListener.Close(); errClose != nil && !errors.Is(errClose, net.ErrClosed) {
|
||||
log.Debugf("failed to close shared listener after HTTP serve exit: %v", errClose)
|
||||
}
|
||||
}
|
||||
if s.muxHTTPListener != nil {
|
||||
_ = s.muxHTTPListener.Close()
|
||||
}
|
||||
errAccept := <-acceptErrCh
|
||||
errServe = normalizeHTTPServeError(errServe)
|
||||
errAccept = normalizeListenerError(errAccept)
|
||||
if errServe != nil {
|
||||
return fmt.Errorf("failed to start HTTP server: %v", errServe)
|
||||
}
|
||||
if errAccept != nil {
|
||||
return fmt.Errorf("failed to start HTTP server: %v", errAccept)
|
||||
}
|
||||
return nil
|
||||
case errAccept := <-acceptErrCh:
|
||||
if s.muxHTTPListener != nil {
|
||||
_ = s.muxHTTPListener.Close()
|
||||
}
|
||||
if s.muxBaseListener != nil {
|
||||
if errClose := s.muxBaseListener.Close(); errClose != nil && !errors.Is(errClose, net.ErrClosed) {
|
||||
log.Debugf("failed to close shared listener after accept loop exit: %v", errClose)
|
||||
}
|
||||
}
|
||||
errServe := <-httpErrCh
|
||||
errServe = normalizeHTTPServeError(errServe)
|
||||
errAccept = normalizeListenerError(errAccept)
|
||||
if errAccept != nil {
|
||||
return fmt.Errorf("failed to start HTTP server: %v", errAccept)
|
||||
}
|
||||
if errServe != nil {
|
||||
return fmt.Errorf("failed to start HTTP server: %v", errServe)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Debugf("Starting API server on %s", s.server.Addr)
|
||||
if errServe := s.server.ListenAndServe(); errServe != nil && !errors.Is(errServe, http.ErrServerClosed) {
|
||||
return fmt.Errorf("failed to start HTTP server: %v", errServe)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop gracefully shuts down the API server without interrupting any
|
||||
@@ -837,6 +920,15 @@ func (s *Server) Stop(ctx context.Context) error {
|
||||
}
|
||||
}
|
||||
|
||||
if s.muxHTTPListener != nil {
|
||||
_ = s.muxHTTPListener.Close()
|
||||
}
|
||||
if s.muxBaseListener != nil {
|
||||
if errClose := s.muxBaseListener.Close(); errClose != nil && !errors.Is(errClose, net.ErrClosed) {
|
||||
log.Debugf("failed to close shared listener: %v", errClose)
|
||||
}
|
||||
}
|
||||
|
||||
// Shutdown the HTTP server.
|
||||
if err := s.server.Shutdown(ctx); err != nil {
|
||||
return fmt.Errorf("failed to shutdown HTTP server: %v", err)
|
||||
@@ -963,6 +1055,7 @@ func (s *Server) UpdateClients(cfg *config.Config) {
|
||||
s.managementRoutesEnabled.Store(!newSecretEmpty)
|
||||
}
|
||||
}
|
||||
redisqueue.SetEnabled(s.managementRoutesEnabled.Load())
|
||||
|
||||
s.applyAccessConfig(oldCfg, cfg)
|
||||
s.cfg = cfg
|
||||
|
||||
145
internal/redisqueue/plugin.go
Normal file
145
internal/redisqueue/plugin.go
Normal file
@@ -0,0 +1,145 @@
|
||||
package redisqueue
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
internallogging "github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||
internalusage "github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
|
||||
coreusage "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage"
|
||||
)
|
||||
|
||||
func init() {
|
||||
coreusage.RegisterPlugin(&usageQueuePlugin{})
|
||||
}
|
||||
|
||||
type usageQueuePlugin struct{}
|
||||
|
||||
func (p *usageQueuePlugin) HandleUsage(ctx context.Context, record coreusage.Record) {
|
||||
if p == nil {
|
||||
return
|
||||
}
|
||||
if !Enabled() || !internalusage.StatisticsEnabled() {
|
||||
return
|
||||
}
|
||||
|
||||
timestamp := record.RequestedAt
|
||||
if timestamp.IsZero() {
|
||||
timestamp = time.Now()
|
||||
}
|
||||
|
||||
modelName := strings.TrimSpace(record.Model)
|
||||
if modelName == "" {
|
||||
modelName = "unknown"
|
||||
}
|
||||
provider := strings.TrimSpace(record.Provider)
|
||||
if provider == "" {
|
||||
provider = "unknown"
|
||||
}
|
||||
authType := strings.TrimSpace(record.AuthType)
|
||||
if authType == "" {
|
||||
authType = "unknown"
|
||||
}
|
||||
apiKey := strings.TrimSpace(record.APIKey)
|
||||
requestID := strings.TrimSpace(internallogging.GetRequestID(ctx))
|
||||
if requestID == "" {
|
||||
if ginCtx, ok := ctx.Value("gin").(*gin.Context); ok && ginCtx != nil {
|
||||
requestID = strings.TrimSpace(internallogging.GetGinRequestID(ginCtx))
|
||||
}
|
||||
}
|
||||
|
||||
tokens := internalusage.TokenStats{
|
||||
InputTokens: record.Detail.InputTokens,
|
||||
OutputTokens: record.Detail.OutputTokens,
|
||||
ReasoningTokens: record.Detail.ReasoningTokens,
|
||||
CachedTokens: record.Detail.CachedTokens,
|
||||
TotalTokens: record.Detail.TotalTokens,
|
||||
}
|
||||
if tokens.TotalTokens == 0 {
|
||||
tokens.TotalTokens = tokens.InputTokens + tokens.OutputTokens + tokens.ReasoningTokens
|
||||
}
|
||||
if tokens.TotalTokens == 0 {
|
||||
tokens.TotalTokens = tokens.InputTokens + tokens.OutputTokens + tokens.ReasoningTokens + tokens.CachedTokens
|
||||
}
|
||||
|
||||
failed := record.Failed
|
||||
if !failed {
|
||||
failed = !resolveSuccess(ctx)
|
||||
}
|
||||
|
||||
detail := internalusage.RequestDetail{
|
||||
Timestamp: timestamp,
|
||||
LatencyMs: record.Latency.Milliseconds(),
|
||||
Source: record.Source,
|
||||
AuthIndex: record.AuthIndex,
|
||||
Tokens: tokens,
|
||||
Failed: failed,
|
||||
}
|
||||
|
||||
payload, err := json.Marshal(queuedUsageDetail{
|
||||
RequestDetail: detail,
|
||||
Provider: provider,
|
||||
Model: modelName,
|
||||
Endpoint: resolveEndpoint(ctx),
|
||||
AuthType: authType,
|
||||
APIKey: apiKey,
|
||||
RequestID: requestID,
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
Enqueue(payload)
|
||||
}
|
||||
|
||||
type queuedUsageDetail struct {
|
||||
internalusage.RequestDetail
|
||||
Provider string `json:"provider"`
|
||||
Model string `json:"model"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
AuthType string `json:"auth_type"`
|
||||
APIKey string `json:"api_key"`
|
||||
RequestID string `json:"request_id"`
|
||||
}
|
||||
|
||||
func resolveSuccess(ctx context.Context) bool {
|
||||
if ctx == nil {
|
||||
return true
|
||||
}
|
||||
ginCtx, ok := ctx.Value("gin").(*gin.Context)
|
||||
if !ok || ginCtx == nil {
|
||||
return true
|
||||
}
|
||||
status := ginCtx.Writer.Status()
|
||||
if status == 0 {
|
||||
return true
|
||||
}
|
||||
return status < http.StatusBadRequest
|
||||
}
|
||||
|
||||
func resolveEndpoint(ctx context.Context) string {
|
||||
if ctx == nil {
|
||||
return ""
|
||||
}
|
||||
ginCtx, ok := ctx.Value("gin").(*gin.Context)
|
||||
if !ok || ginCtx == nil || ginCtx.Request == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
path := strings.TrimSpace(ginCtx.FullPath())
|
||||
if path == "" && ginCtx.Request.URL != nil {
|
||||
path = strings.TrimSpace(ginCtx.Request.URL.Path)
|
||||
}
|
||||
if path == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
method := strings.TrimSpace(ginCtx.Request.Method)
|
||||
if method == "" {
|
||||
return path
|
||||
}
|
||||
return method + " " + path
|
||||
}
|
||||
160
internal/redisqueue/plugin_test.go
Normal file
160
internal/redisqueue/plugin_test.go
Normal file
@@ -0,0 +1,160 @@
|
||||
package redisqueue
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
internallogging "github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||
internalusage "github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
|
||||
coreusage "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage"
|
||||
)
|
||||
|
||||
func TestUsageQueuePluginPayloadIncludesStableFieldsAndSuccess(t *testing.T) {
|
||||
withEnabledQueue(t, func() {
|
||||
ginCtx := newTestGinContext(t, http.MethodPost, "/v1/chat/completions", http.StatusOK)
|
||||
internallogging.SetGinRequestID(ginCtx, "gin-request-id-ignored")
|
||||
ctx := context.WithValue(internallogging.WithRequestID(context.Background(), "ctx-request-id"), "gin", ginCtx)
|
||||
|
||||
plugin := &usageQueuePlugin{}
|
||||
plugin.HandleUsage(ctx, coreusage.Record{
|
||||
Provider: "openai",
|
||||
Model: "gpt-5.4",
|
||||
APIKey: "test-key",
|
||||
AuthIndex: "0",
|
||||
AuthType: "apikey",
|
||||
Source: "user@example.com",
|
||||
RequestedAt: time.Date(2026, 4, 25, 0, 0, 0, 0, time.UTC),
|
||||
Latency: 1500 * time.Millisecond,
|
||||
Detail: coreusage.Detail{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
TotalTokens: 30,
|
||||
},
|
||||
})
|
||||
|
||||
payload := popSinglePayload(t)
|
||||
requireStringField(t, payload, "provider", "openai")
|
||||
requireStringField(t, payload, "model", "gpt-5.4")
|
||||
requireStringField(t, payload, "endpoint", "POST /v1/chat/completions")
|
||||
requireStringField(t, payload, "auth_type", "apikey")
|
||||
requireStringField(t, payload, "request_id", "ctx-request-id")
|
||||
requireBoolField(t, payload, "failed", false)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUsageQueuePluginPayloadIncludesStableFieldsAndFailureAndGinRequestID(t *testing.T) {
|
||||
withEnabledQueue(t, func() {
|
||||
ginCtx := newTestGinContext(t, http.MethodGet, "/v1/responses", http.StatusInternalServerError)
|
||||
internallogging.SetGinRequestID(ginCtx, "gin-request-id")
|
||||
ctx := context.WithValue(context.Background(), "gin", ginCtx)
|
||||
|
||||
plugin := &usageQueuePlugin{}
|
||||
plugin.HandleUsage(ctx, coreusage.Record{
|
||||
Provider: "openai",
|
||||
Model: "gpt-5.4-mini",
|
||||
APIKey: "test-key",
|
||||
AuthIndex: "0",
|
||||
AuthType: "apikey",
|
||||
Source: "user@example.com",
|
||||
RequestedAt: time.Date(2026, 4, 25, 0, 0, 0, 0, time.UTC),
|
||||
Latency: 2500 * time.Millisecond,
|
||||
Detail: coreusage.Detail{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
TotalTokens: 30,
|
||||
},
|
||||
})
|
||||
|
||||
payload := popSinglePayload(t)
|
||||
requireStringField(t, payload, "provider", "openai")
|
||||
requireStringField(t, payload, "model", "gpt-5.4-mini")
|
||||
requireStringField(t, payload, "endpoint", "GET /v1/responses")
|
||||
requireStringField(t, payload, "auth_type", "apikey")
|
||||
requireStringField(t, payload, "request_id", "gin-request-id")
|
||||
requireBoolField(t, payload, "failed", true)
|
||||
})
|
||||
}
|
||||
|
||||
func withEnabledQueue(t *testing.T, fn func()) {
|
||||
t.Helper()
|
||||
|
||||
prevQueueEnabled := Enabled()
|
||||
prevStatsEnabled := internalusage.StatisticsEnabled()
|
||||
|
||||
SetEnabled(false)
|
||||
SetEnabled(true)
|
||||
internalusage.SetStatisticsEnabled(true)
|
||||
|
||||
defer func() {
|
||||
SetEnabled(false)
|
||||
SetEnabled(prevQueueEnabled)
|
||||
internalusage.SetStatisticsEnabled(prevStatsEnabled)
|
||||
}()
|
||||
|
||||
fn()
|
||||
}
|
||||
|
||||
func newTestGinContext(t *testing.T, method, path string, status int) *gin.Context {
|
||||
t.Helper()
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
recorder := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(recorder)
|
||||
ginCtx.Request = httptest.NewRequest(method, "http://example.com"+path, nil)
|
||||
if status != 0 {
|
||||
ginCtx.Status(status)
|
||||
}
|
||||
return ginCtx
|
||||
}
|
||||
|
||||
func popSinglePayload(t *testing.T) map[string]json.RawMessage {
|
||||
t.Helper()
|
||||
|
||||
items := PopOldest(10)
|
||||
if len(items) != 1 {
|
||||
t.Fatalf("PopOldest() items = %d, want 1", len(items))
|
||||
}
|
||||
|
||||
var payload map[string]json.RawMessage
|
||||
if err := json.Unmarshal(items[0], &payload); err != nil {
|
||||
t.Fatalf("unmarshal payload: %v", err)
|
||||
}
|
||||
return payload
|
||||
}
|
||||
|
||||
func requireStringField(t *testing.T, payload map[string]json.RawMessage, key, want string) {
|
||||
t.Helper()
|
||||
|
||||
raw, ok := payload[key]
|
||||
if !ok {
|
||||
t.Fatalf("payload missing %q", key)
|
||||
}
|
||||
var got string
|
||||
if err := json.Unmarshal(raw, &got); err != nil {
|
||||
t.Fatalf("unmarshal %q: %v", key, err)
|
||||
}
|
||||
if got != want {
|
||||
t.Fatalf("%s = %q, want %q", key, got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func requireBoolField(t *testing.T, payload map[string]json.RawMessage, key string, want bool) {
|
||||
t.Helper()
|
||||
|
||||
raw, ok := payload[key]
|
||||
if !ok {
|
||||
t.Fatalf("payload missing %q", key)
|
||||
}
|
||||
var got bool
|
||||
if err := json.Unmarshal(raw, &got); err != nil {
|
||||
t.Fatalf("unmarshal %q: %v", key, err)
|
||||
}
|
||||
if got != want {
|
||||
t.Fatalf("%s = %t, want %t", key, got, want)
|
||||
}
|
||||
}
|
||||
133
internal/redisqueue/queue.go
Normal file
133
internal/redisqueue/queue.go
Normal file
@@ -0,0 +1,133 @@
|
||||
package redisqueue
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
const retentionWindow = time.Minute
|
||||
|
||||
type queueItem struct {
|
||||
enqueuedAt time.Time
|
||||
payload []byte
|
||||
}
|
||||
|
||||
type queue struct {
|
||||
mu sync.Mutex
|
||||
items []queueItem
|
||||
head int
|
||||
}
|
||||
|
||||
var (
|
||||
enabled atomic.Bool
|
||||
global queue
|
||||
)
|
||||
|
||||
func SetEnabled(value bool) {
|
||||
enabled.Store(value)
|
||||
if !value {
|
||||
global.clear()
|
||||
}
|
||||
}
|
||||
|
||||
func Enabled() bool {
|
||||
return enabled.Load()
|
||||
}
|
||||
|
||||
func Enqueue(payload []byte) {
|
||||
if !Enabled() {
|
||||
return
|
||||
}
|
||||
if len(payload) == 0 {
|
||||
return
|
||||
}
|
||||
global.enqueue(payload)
|
||||
}
|
||||
|
||||
func PopOldest(count int) [][]byte {
|
||||
if !Enabled() {
|
||||
return nil
|
||||
}
|
||||
if count <= 0 {
|
||||
return nil
|
||||
}
|
||||
return global.popOldest(count)
|
||||
}
|
||||
|
||||
func (q *queue) clear() {
|
||||
q.mu.Lock()
|
||||
defer q.mu.Unlock()
|
||||
q.items = nil
|
||||
q.head = 0
|
||||
}
|
||||
|
||||
func (q *queue) enqueue(payload []byte) {
|
||||
now := time.Now()
|
||||
|
||||
q.mu.Lock()
|
||||
defer q.mu.Unlock()
|
||||
|
||||
q.pruneLocked(now)
|
||||
q.items = append(q.items, queueItem{
|
||||
enqueuedAt: now,
|
||||
payload: append([]byte(nil), payload...),
|
||||
})
|
||||
q.maybeCompactLocked()
|
||||
}
|
||||
|
||||
func (q *queue) popOldest(count int) [][]byte {
|
||||
now := time.Now()
|
||||
|
||||
q.mu.Lock()
|
||||
defer q.mu.Unlock()
|
||||
|
||||
q.pruneLocked(now)
|
||||
available := len(q.items) - q.head
|
||||
if available <= 0 {
|
||||
q.items = nil
|
||||
q.head = 0
|
||||
return nil
|
||||
}
|
||||
if count > available {
|
||||
count = available
|
||||
}
|
||||
|
||||
out := make([][]byte, 0, count)
|
||||
for i := 0; i < count; i++ {
|
||||
item := q.items[q.head+i]
|
||||
out = append(out, item.payload)
|
||||
}
|
||||
q.head += count
|
||||
q.maybeCompactLocked()
|
||||
return out
|
||||
}
|
||||
|
||||
func (q *queue) pruneLocked(now time.Time) {
|
||||
if q.head >= len(q.items) {
|
||||
q.items = nil
|
||||
q.head = 0
|
||||
return
|
||||
}
|
||||
|
||||
cutoff := now.Add(-retentionWindow)
|
||||
for q.head < len(q.items) && q.items[q.head].enqueuedAt.Before(cutoff) {
|
||||
q.head++
|
||||
}
|
||||
}
|
||||
|
||||
func (q *queue) maybeCompactLocked() {
|
||||
if q.head == 0 {
|
||||
return
|
||||
}
|
||||
if q.head >= len(q.items) {
|
||||
q.items = nil
|
||||
q.head = 0
|
||||
return
|
||||
}
|
||||
if q.head < 1024 && q.head*2 < len(q.items) {
|
||||
return
|
||||
}
|
||||
q.items = append([]queueItem(nil), q.items[q.head:]...)
|
||||
q.head = 0
|
||||
}
|
||||
@@ -20,6 +20,7 @@ type UsageReporter struct {
|
||||
model string
|
||||
authID string
|
||||
authIndex string
|
||||
authType string
|
||||
apiKey string
|
||||
source string
|
||||
requestedAt time.Time
|
||||
@@ -34,6 +35,7 @@ func NewUsageReporter(ctx context.Context, provider, model string, auth *cliprox
|
||||
requestedAt: time.Now(),
|
||||
apiKey: apiKey,
|
||||
source: resolveUsageSource(auth, apiKey),
|
||||
authType: resolveUsageAuthType(auth),
|
||||
}
|
||||
if auth != nil {
|
||||
reporter.authID = auth.ID
|
||||
@@ -98,6 +100,7 @@ func (r *UsageReporter) buildRecord(detail usage.Detail, failed bool) usage.Reco
|
||||
APIKey: r.apiKey,
|
||||
AuthID: r.authID,
|
||||
AuthIndex: r.authIndex,
|
||||
AuthType: r.authType,
|
||||
RequestedAt: r.requestedAt,
|
||||
Latency: r.latency(),
|
||||
Failed: failed,
|
||||
@@ -181,6 +184,18 @@ func resolveUsageSource(auth *cliproxyauth.Auth, ctxAPIKey string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func resolveUsageAuthType(auth *cliproxyauth.Auth) string {
|
||||
if auth == nil {
|
||||
return ""
|
||||
}
|
||||
kind, _ := auth.AccountInfo()
|
||||
kind = strings.TrimSpace(kind)
|
||||
if kind == "api_key" {
|
||||
return "apikey"
|
||||
}
|
||||
return kind
|
||||
}
|
||||
|
||||
func ParseCodexUsage(data []byte) (usage.Detail, bool) {
|
||||
usageNode := gjson.ParseBytes(data).Get("response.usage")
|
||||
if !usageNode.Exists() {
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/api"
|
||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/redisqueue"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor"
|
||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
|
||||
|
||||
@@ -15,6 +15,7 @@ type Record struct {
|
||||
APIKey string
|
||||
AuthID string
|
||||
AuthIndex string
|
||||
AuthType string
|
||||
Source string
|
||||
RequestedAt time.Time
|
||||
Latency time.Duration
|
||||
|
||||
Reference in New Issue
Block a user