feat(api): implement protocol multiplexer and Redis queue for usage integration

- Added `protocol_multiplexer.go`, enabling support for both HTTP and Redis protocols on a single listener.
- Introduced `redis_queue_protocol.go` to handle Redis-compatible RESP commands for queue management.
- Integrated `redisqueue` package, supporting in-memory queuing with expiration pruning.
- Updated server initialization to manage a shared listener and multiplex connections.
- Adjusted `Handler` to adopt `AuthenticateManagementKey` for modular key validation, supporting both HTTP and Redis flows.
This commit is contained in:
Luis Pater
2026-04-25 16:12:35 +08:00
parent be0fe6fab3
commit 28d78273e4
13 changed files with 1490 additions and 102 deletions

View File

@@ -0,0 +1,32 @@
package api
import (
"bufio"
"crypto/tls"
"net"
)
type bufferedConn struct {
net.Conn
reader *bufio.Reader
}
func (c *bufferedConn) Read(p []byte) (int, error) {
if c == nil {
return 0, net.ErrClosed
}
if c.reader == nil {
return c.Conn.Read(p)
}
return c.reader.Read(p)
}
func (c *bufferedConn) ConnectionState() tls.ConnectionState {
if c == nil || c.Conn == nil {
return tls.ConnectionState{}
}
if stater, ok := c.Conn.(interface{ ConnectionState() tls.ConnectionState }); ok {
return stater.ConnectionState()
}
return tls.ConnectionState{}
}

View File

@@ -152,9 +152,6 @@ func (h *Handler) SetPostAuthHook(hook coreauth.PostAuthHook) {
// All requests (local and remote) require a valid management key.
// Additionally, remote access requires allow-remote-management=true.
func (h *Handler) Middleware() gin.HandlerFunc {
const maxFailures = 5
const banDuration = 30 * time.Minute
return func(c *gin.Context) {
c.Header("X-CPA-VERSION", buildinfo.Version)
c.Header("X-CPA-COMMIT", buildinfo.Commit)
@@ -162,64 +159,6 @@ func (h *Handler) Middleware() gin.HandlerFunc {
clientIP := c.ClientIP()
localClient := clientIP == "127.0.0.1" || clientIP == "::1"
cfg := h.cfg
var (
allowRemote bool
secretHash string
)
if cfg != nil {
allowRemote = cfg.RemoteManagement.AllowRemote
secretHash = cfg.RemoteManagement.SecretKey
}
if h.allowRemoteOverride {
allowRemote = true
}
envSecret := h.envSecret
fail := func() {}
if !localClient {
h.attemptsMu.Lock()
ai := h.failedAttempts[clientIP]
if ai != nil {
if !ai.blockedUntil.IsZero() {
if time.Now().Before(ai.blockedUntil) {
remaining := time.Until(ai.blockedUntil).Round(time.Second)
h.attemptsMu.Unlock()
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": fmt.Sprintf("IP banned due to too many failed attempts. Try again in %s", remaining)})
return
}
// Ban expired, reset state
ai.blockedUntil = time.Time{}
ai.count = 0
}
}
h.attemptsMu.Unlock()
if !allowRemote {
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "remote management disabled"})
return
}
fail = func() {
h.attemptsMu.Lock()
aip := h.failedAttempts[clientIP]
if aip == nil {
aip = &attemptInfo{}
h.failedAttempts[clientIP] = aip
}
aip.count++
aip.lastActivity = time.Now()
if aip.count >= maxFailures {
aip.blockedUntil = time.Now().Add(banDuration)
aip.count = 0
}
h.attemptsMu.Unlock()
}
}
if secretHash == "" && envSecret == "" {
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "remote management key not set"})
return
}
// Accept either Authorization: Bearer <key> or X-Management-Key
var provided string
@@ -235,44 +174,98 @@ func (h *Handler) Middleware() gin.HandlerFunc {
provided = c.GetHeader("X-Management-Key")
}
if provided == "" {
if !localClient {
fail()
}
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "missing management key"})
allowed, statusCode, errMsg := h.AuthenticateManagementKey(clientIP, localClient, provided)
if !allowed {
c.AbortWithStatusJSON(statusCode, gin.H{"error": errMsg})
return
}
c.Next()
}
}
if localClient {
if lp := h.localPassword; lp != "" {
if subtle.ConstantTimeCompare([]byte(provided), []byte(lp)) == 1 {
c.Next()
return
// AuthenticateManagementKey verifies the provided management key for the given client.
// It mirrors the behaviour of Middleware() so non-HTTP callers can reuse the same logic.
func (h *Handler) AuthenticateManagementKey(clientIP string, localClient bool, provided string) (bool, int, string) {
const maxFailures = 5
const banDuration = 30 * time.Minute
if h == nil {
return false, http.StatusForbidden, "remote management disabled"
}
cfg := h.cfg
var (
allowRemote bool
secretHash string
)
if cfg != nil {
allowRemote = cfg.RemoteManagement.AllowRemote
secretHash = cfg.RemoteManagement.SecretKey
}
if h.allowRemoteOverride {
allowRemote = true
}
envSecret := h.envSecret
fail := func() {}
if !localClient {
h.attemptsMu.Lock()
ai := h.failedAttempts[clientIP]
if ai != nil {
if !ai.blockedUntil.IsZero() {
if time.Now().Before(ai.blockedUntil) {
remaining := time.Until(ai.blockedUntil).Round(time.Second)
h.attemptsMu.Unlock()
return false, http.StatusForbidden, fmt.Sprintf("IP banned due to too many failed attempts. Try again in %s", remaining)
}
// Ban expired, reset state
ai.blockedUntil = time.Time{}
ai.count = 0
}
}
h.attemptsMu.Unlock()
if envSecret != "" && subtle.ConstantTimeCompare([]byte(provided), []byte(envSecret)) == 1 {
if !localClient {
h.attemptsMu.Lock()
if ai := h.failedAttempts[clientIP]; ai != nil {
ai.count = 0
ai.blockedUntil = time.Time{}
}
h.attemptsMu.Unlock()
}
c.Next()
return
if !allowRemote {
return false, http.StatusForbidden, "remote management disabled"
}
if secretHash == "" || bcrypt.CompareHashAndPassword([]byte(secretHash), []byte(provided)) != nil {
if !localClient {
fail()
fail = func() {
h.attemptsMu.Lock()
aip := h.failedAttempts[clientIP]
if aip == nil {
aip = &attemptInfo{}
h.failedAttempts[clientIP] = aip
}
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "invalid management key"})
return
aip.count++
aip.lastActivity = time.Now()
if aip.count >= maxFailures {
aip.blockedUntil = time.Now().Add(banDuration)
aip.count = 0
}
h.attemptsMu.Unlock()
}
}
if secretHash == "" && envSecret == "" {
return false, http.StatusForbidden, "remote management key not set"
}
if provided == "" {
if !localClient {
fail()
}
return false, http.StatusUnauthorized, "missing management key"
}
if localClient {
if lp := h.localPassword; lp != "" {
if subtle.ConstantTimeCompare([]byte(provided), []byte(lp)) == 1 {
return true, 0, ""
}
}
}
if envSecret != "" && subtle.ConstantTimeCompare([]byte(provided), []byte(envSecret)) == 1 {
if !localClient {
h.attemptsMu.Lock()
if ai := h.failedAttempts[clientIP]; ai != nil {
@@ -281,9 +274,26 @@ func (h *Handler) Middleware() gin.HandlerFunc {
}
h.attemptsMu.Unlock()
}
c.Next()
return true, 0, ""
}
if secretHash == "" || bcrypt.CompareHashAndPassword([]byte(secretHash), []byte(provided)) != nil {
if !localClient {
fail()
}
return false, http.StatusUnauthorized, "invalid management key"
}
if !localClient {
h.attemptsMu.Lock()
if ai := h.failedAttempts[clientIP]; ai != nil {
ai.count = 0
ai.blockedUntil = time.Time{}
}
h.attemptsMu.Unlock()
}
return true, 0, ""
}
// persist saves the current in-memory config to disk.

View File

@@ -0,0 +1,68 @@
package api
import (
"net"
"sync"
)
type muxListener struct {
addr net.Addr
connCh chan net.Conn
closeCh chan struct{}
once sync.Once
}
func newMuxListener(addr net.Addr, buffer int) *muxListener {
if buffer <= 0 {
buffer = 1
}
return &muxListener{
addr: addr,
connCh: make(chan net.Conn, buffer),
closeCh: make(chan struct{}),
}
}
func (l *muxListener) Put(conn net.Conn) error {
if conn == nil {
return nil
}
select {
case <-l.closeCh:
return net.ErrClosed
case l.connCh <- conn:
return nil
}
}
func (l *muxListener) Accept() (net.Conn, error) {
select {
case <-l.closeCh:
return nil, net.ErrClosed
case conn := <-l.connCh:
if conn == nil {
return nil, net.ErrClosed
}
return conn, nil
}
}
func (l *muxListener) Close() error {
if l == nil {
return nil
}
l.once.Do(func() {
close(l.closeCh)
})
return nil
}
func (l *muxListener) Addr() net.Addr {
if l == nil {
return &net.TCPAddr{}
}
if l.addr == nil {
return &net.TCPAddr{}
}
return l.addr
}

View File

@@ -0,0 +1,109 @@
package api
import (
"bufio"
"crypto/tls"
"errors"
"net"
"net/http"
"strings"
log "github.com/sirupsen/logrus"
)
func normalizeHTTPServeError(err error) error {
if err == nil {
return nil
}
if errors.Is(err, net.ErrClosed) {
return nil
}
if errors.Is(err, http.ErrServerClosed) {
return nil
}
return err
}
func normalizeListenerError(err error) error {
if err == nil {
return nil
}
if errors.Is(err, net.ErrClosed) {
return nil
}
return err
}
func (s *Server) acceptMuxConnections(listener net.Listener, httpListener *muxListener) error {
if s == nil || listener == nil {
return net.ErrClosed
}
for {
conn, errAccept := listener.Accept()
if errAccept != nil {
return errAccept
}
if conn == nil {
continue
}
tlsConn, ok := conn.(*tls.Conn)
if ok {
if errHandshake := tlsConn.Handshake(); errHandshake != nil {
if errClose := conn.Close(); errClose != nil {
log.Errorf("failed to close connection after TLS handshake error: %v", errClose)
}
continue
}
proto := strings.TrimSpace(tlsConn.ConnectionState().NegotiatedProtocol)
if proto == "h2" || proto == "http/1.1" {
if httpListener == nil {
if errClose := conn.Close(); errClose != nil {
log.Errorf("failed to close connection: %v", errClose)
}
continue
}
if errPut := httpListener.Put(tlsConn); errPut != nil {
if errClose := conn.Close(); errClose != nil {
log.Errorf("failed to close connection after HTTP routing failure: %v", errClose)
}
}
continue
}
}
reader := bufio.NewReader(conn)
prefix, errPeek := reader.Peek(1)
if errPeek != nil {
if errClose := conn.Close(); errClose != nil {
log.Errorf("failed to close connection after protocol peek failure: %v", errClose)
}
continue
}
if isRedisRESPPrefix(prefix[0]) {
if !s.managementRoutesEnabled.Load() {
if errClose := conn.Close(); errClose != nil {
log.Errorf("failed to close redis connection while management is disabled: %v", errClose)
}
continue
}
go s.handleRedisConnection(conn, reader)
continue
}
if httpListener == nil {
if errClose := conn.Close(); errClose != nil {
log.Errorf("failed to close connection without HTTP listener: %v", errClose)
}
continue
}
if errPut := httpListener.Put(&bufferedConn{Conn: conn, reader: reader}); errPut != nil {
if errClose := conn.Close(); errClose != nil {
log.Errorf("failed to close connection after HTTP routing failure: %v", errClose)
}
}
}
}

View File

@@ -0,0 +1,317 @@
package api
import (
"bufio"
"errors"
"fmt"
"io"
"net"
"strconv"
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/redisqueue"
log "github.com/sirupsen/logrus"
)
func isRedisRESPPrefix(prefix byte) bool {
switch prefix {
case '*', '$', '+', '-', ':':
return true
default:
return false
}
}
func (s *Server) handleRedisConnection(conn net.Conn, reader *bufio.Reader) {
if s == nil || conn == nil || reader == nil {
return
}
clientIP, localClient := resolveRemoteIP(conn.RemoteAddr())
authed := false
writer := bufio.NewWriter(conn)
defer func() {
if errClose := conn.Close(); errClose != nil {
log.Errorf("redis connection close error: %v", errClose)
}
}()
flush := func() bool {
if errFlush := writer.Flush(); errFlush != nil {
log.Errorf("redis protocol flush error: %v", errFlush)
return false
}
return true
}
for {
if !s.managementRoutesEnabled.Load() {
return
}
args, err := readRESPArray(reader)
if err != nil {
if !errors.Is(err, io.EOF) {
_ = writeRedisError(writer, "ERR "+err.Error())
_ = writer.Flush()
}
return
}
if len(args) == 0 {
_ = writeRedisError(writer, "ERR empty command")
if !flush() {
return
}
continue
}
cmd := strings.ToUpper(strings.TrimSpace(args[0]))
switch cmd {
case "AUTH":
password, ok := parseAuthPassword(args)
if !ok {
_ = writeRedisError(writer, "ERR wrong number of arguments for 'auth' command")
if !flush() {
return
}
continue
}
if s.mgmt == nil {
_ = writeRedisError(writer, "ERR remote management disabled")
if !flush() {
return
}
continue
}
allowed, _, errMsg := s.mgmt.AuthenticateManagementKey(clientIP, localClient, password)
if !allowed {
_ = writeRedisError(writer, "ERR "+errMsg)
if !flush() {
return
}
continue
}
authed = true
_ = writeRedisSimpleString(writer, "OK")
if !flush() {
return
}
case "LPOP", "RPOP":
if !authed {
_ = writeRedisError(writer, "NOAUTH Authentication required.")
if !flush() {
return
}
continue
}
count, hasCount, ok := parsePopCount(args)
if !ok {
_ = writeRedisError(writer, "ERR wrong number of arguments for '"+strings.ToLower(cmd)+"' command")
if !flush() {
return
}
continue
}
if count <= 0 {
_ = writeRedisError(writer, "ERR value is not an integer or out of range")
if !flush() {
return
}
continue
}
items := redisqueue.PopOldest(count)
if hasCount {
_ = writeRedisArrayOfBulkStrings(writer, items)
if !flush() {
return
}
continue
}
if len(items) == 0 {
_ = writeRedisNilBulkString(writer)
if !flush() {
return
}
continue
}
_ = writeRedisBulkString(writer, items[0])
if !flush() {
return
}
default:
_ = writeRedisError(writer, fmt.Sprintf("ERR unknown command '%s'", strings.ToLower(cmd)))
if !flush() {
return
}
}
}
}
func resolveRemoteIP(addr net.Addr) (ip string, localClient bool) {
if addr == nil {
return "", false
}
host := addr.String()
if h, _, err := net.SplitHostPort(host); err == nil {
host = h
}
host = strings.TrimSpace(host)
localClient = host == "127.0.0.1" || host == "::1"
return host, localClient
}
func parseAuthPassword(args []string) (string, bool) {
switch len(args) {
case 2:
return args[1], true
case 3:
// Support AUTH <username> <password> by ignoring username for compatibility.
return args[2], true
default:
return "", false
}
}
func parsePopCount(args []string) (count int, hasCount bool, ok bool) {
if len(args) != 2 && len(args) != 3 {
return 0, false, false
}
if len(args) == 2 {
return 1, false, true
}
parsed, err := strconv.Atoi(strings.TrimSpace(args[2]))
if err != nil {
return 0, true, true
}
return parsed, true, true
}
func readRESPArray(reader *bufio.Reader) ([]string, error) {
prefix, err := reader.ReadByte()
if err != nil {
return nil, err
}
if prefix != '*' {
return nil, fmt.Errorf("protocol error")
}
line, err := readRESPLine(reader)
if err != nil {
return nil, err
}
count, err := strconv.Atoi(line)
if err != nil || count < 0 {
return nil, fmt.Errorf("protocol error")
}
args := make([]string, 0, count)
for i := 0; i < count; i++ {
value, err := readRESPString(reader)
if err != nil {
return nil, err
}
args = append(args, value)
}
return args, nil
}
func readRESPString(reader *bufio.Reader) (string, error) {
prefix, err := reader.ReadByte()
if err != nil {
return "", err
}
switch prefix {
case '$':
return readRESPBulkString(reader)
case '+', ':':
return readRESPLine(reader)
default:
return "", fmt.Errorf("protocol error")
}
}
func readRESPBulkString(reader *bufio.Reader) (string, error) {
line, err := readRESPLine(reader)
if err != nil {
return "", err
}
length, err := strconv.Atoi(line)
if err != nil {
return "", fmt.Errorf("protocol error")
}
if length < 0 {
return "", nil
}
buf := make([]byte, length+2)
if _, err := io.ReadFull(reader, buf); err != nil {
return "", err
}
if length+2 < 2 || buf[length] != '\r' || buf[length+1] != '\n' {
return "", fmt.Errorf("protocol error")
}
return string(buf[:length]), nil
}
func readRESPLine(reader *bufio.Reader) (string, error) {
line, err := reader.ReadString('\n')
if err != nil {
return "", err
}
line = strings.TrimSuffix(line, "\n")
line = strings.TrimSuffix(line, "\r")
return line, nil
}
func writeRedisSimpleString(writer *bufio.Writer, value string) error {
if writer == nil {
return net.ErrClosed
}
_, err := writer.WriteString("+" + value + "\r\n")
return err
}
func writeRedisError(writer *bufio.Writer, message string) error {
if writer == nil {
return net.ErrClosed
}
_, err := writer.WriteString("-" + message + "\r\n")
return err
}
func writeRedisNilBulkString(writer *bufio.Writer) error {
if writer == nil {
return net.ErrClosed
}
_, err := writer.WriteString("$-1\r\n")
return err
}
func writeRedisBulkString(writer *bufio.Writer, payload []byte) error {
if writer == nil {
return net.ErrClosed
}
if payload == nil {
return writeRedisNilBulkString(writer)
}
if _, err := writer.WriteString("$" + strconv.Itoa(len(payload)) + "\r\n"); err != nil {
return err
}
if _, err := writer.Write(payload); err != nil {
return err
}
_, err := writer.WriteString("\r\n")
return err
}
func writeRedisArrayOfBulkStrings(writer *bufio.Writer, items [][]byte) error {
if writer == nil {
return net.ErrClosed
}
if _, err := writer.WriteString("*" + strconv.Itoa(len(items)) + "\r\n"); err != nil {
return err
}
for i := range items {
if err := writeRedisBulkString(writer, items[i]); err != nil {
return err
}
}
return nil
}

View File

@@ -0,0 +1,304 @@
package api
import (
"bufio"
"bytes"
"errors"
"fmt"
"io"
"net"
"strconv"
"strings"
"testing"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/redisqueue"
)
func startRedisMuxListener(t *testing.T, server *Server) (addr string, stop func()) {
t.Helper()
listener, errListen := net.Listen("tcp", "127.0.0.1:0")
if errListen != nil {
t.Fatalf("failed to listen: %v", errListen)
}
errCh := make(chan error, 1)
go func() {
errCh <- server.acceptMuxConnections(listener, nil)
}()
stop = func() {
_ = listener.Close()
select {
case err := <-errCh:
if err != nil && !errors.Is(err, net.ErrClosed) {
t.Errorf("accept loop returned unexpected error: %v", err)
}
case <-time.After(2 * time.Second):
t.Errorf("timeout waiting for accept loop to exit")
}
}
return listener.Addr().String(), stop
}
func writeTestRESPCommand(conn net.Conn, args ...string) error {
if conn == nil {
return net.ErrClosed
}
if len(args) == 0 {
return nil
}
var buf bytes.Buffer
fmt.Fprintf(&buf, "*%d\r\n", len(args))
for _, arg := range args {
fmt.Fprintf(&buf, "$%d\r\n%s\r\n", len(arg), arg)
}
_, err := conn.Write(buf.Bytes())
return err
}
func readTestRESPLine(r *bufio.Reader) (string, error) {
line, err := r.ReadString('\n')
if err != nil {
return "", err
}
if !strings.HasSuffix(line, "\r\n") {
return "", fmt.Errorf("invalid RESP line terminator: %q", line)
}
return strings.TrimSuffix(line, "\r\n"), nil
}
func readTestRESPSimpleString(r *bufio.Reader) (string, error) {
prefix, err := r.ReadByte()
if err != nil {
return "", err
}
if prefix != '+' {
return "", fmt.Errorf("expected simple string prefix '+', got %q", prefix)
}
return readTestRESPLine(r)
}
func readTestRESPError(r *bufio.Reader) (string, error) {
prefix, err := r.ReadByte()
if err != nil {
return "", err
}
if prefix != '-' {
return "", fmt.Errorf("expected error prefix '-', got %q", prefix)
}
return readTestRESPLine(r)
}
func readTestRESPBulkString(r *bufio.Reader) ([]byte, error) {
prefix, err := r.ReadByte()
if err != nil {
return nil, err
}
if prefix != '$' {
return nil, fmt.Errorf("expected bulk string prefix '$', got %q", prefix)
}
line, err := readTestRESPLine(r)
if err != nil {
return nil, err
}
length, err := strconv.Atoi(line)
if err != nil {
return nil, fmt.Errorf("invalid bulk string length %q: %v", line, err)
}
if length == -1 {
return nil, nil
}
if length < -1 {
return nil, fmt.Errorf("invalid bulk string length %d", length)
}
payload := make([]byte, length+2)
if _, err := io.ReadFull(r, payload); err != nil {
return nil, err
}
if payload[length] != '\r' || payload[length+1] != '\n' {
return nil, fmt.Errorf("invalid bulk string terminator")
}
return payload[:length], nil
}
func readRESPArrayOfBulkStrings(r *bufio.Reader) ([][]byte, error) {
prefix, err := r.ReadByte()
if err != nil {
return nil, err
}
if prefix != '*' {
return nil, fmt.Errorf("expected array prefix '*', got %q", prefix)
}
line, err := readTestRESPLine(r)
if err != nil {
return nil, err
}
count, err := strconv.Atoi(line)
if err != nil {
return nil, fmt.Errorf("invalid array length %q: %v", line, err)
}
if count < 0 {
return nil, fmt.Errorf("invalid array length %d", count)
}
out := make([][]byte, 0, count)
for i := 0; i < count; i++ {
item, err := readTestRESPBulkString(r)
if err != nil {
return nil, err
}
out = append(out, item)
}
return out, nil
}
func TestRedisProtocol_ManagementDisabled_RejectsConnection(t *testing.T) {
t.Setenv("MANAGEMENT_PASSWORD", "")
redisqueue.SetEnabled(false)
server := newTestServer(t)
if server.managementRoutesEnabled.Load() {
t.Fatalf("expected managementRoutesEnabled to be false")
}
addr, stop := startRedisMuxListener(t, server)
t.Cleanup(stop)
conn, errDial := net.DialTimeout("tcp", addr, time.Second)
if errDial != nil {
t.Fatalf("failed to dial redis listener: %v", errDial)
}
t.Cleanup(func() { _ = conn.Close() })
_ = conn.SetDeadline(time.Now().Add(2 * time.Second))
if errWrite := writeTestRESPCommand(conn, "PING"); errWrite != nil {
t.Fatalf("failed to write RESP command: %v", errWrite)
}
buf := make([]byte, 1)
_, errRead := conn.Read(buf)
if errRead == nil {
t.Fatalf("expected connection to be closed when management is disabled")
}
if ne, ok := errRead.(net.Error); ok && ne.Timeout() {
t.Fatalf("expected connection to be closed when management is disabled, got timeout: %v", errRead)
}
}
func TestRedisProtocol_AUTH_And_PopContracts(t *testing.T) {
const managementPassword = "test-management-password"
t.Setenv("MANAGEMENT_PASSWORD", managementPassword)
redisqueue.SetEnabled(false)
t.Cleanup(func() { redisqueue.SetEnabled(false) })
server := newTestServer(t)
if !server.managementRoutesEnabled.Load() {
t.Fatalf("expected managementRoutesEnabled to be true")
}
addr, stop := startRedisMuxListener(t, server)
t.Cleanup(stop)
conn, errDial := net.DialTimeout("tcp", addr, time.Second)
if errDial != nil {
t.Fatalf("failed to dial redis listener: %v", errDial)
}
t.Cleanup(func() { _ = conn.Close() })
reader := bufio.NewReader(conn)
_ = conn.SetDeadline(time.Now().Add(5 * time.Second))
if errWrite := writeTestRESPCommand(conn, "AUTH", "test-key"); errWrite != nil {
t.Fatalf("failed to write AUTH command: %v", errWrite)
}
if msg, err := readTestRESPError(reader); err != nil {
t.Fatalf("failed to read AUTH error: %v", err)
} else if msg != "ERR invalid management key" {
t.Fatalf("unexpected AUTH error: %q", msg)
}
if errWrite := writeTestRESPCommand(conn, "LPOP", "queue"); errWrite != nil {
t.Fatalf("failed to write LPOP command: %v", errWrite)
}
if msg, err := readTestRESPError(reader); err != nil {
t.Fatalf("failed to read LPOP NOAUTH error: %v", err)
} else if msg != "NOAUTH Authentication required." {
t.Fatalf("unexpected LPOP NOAUTH error: %q", msg)
}
if errWrite := writeTestRESPCommand(conn, "AUTH", managementPassword); errWrite != nil {
t.Fatalf("failed to write AUTH command: %v", errWrite)
}
if msg, err := readTestRESPSimpleString(reader); err != nil {
t.Fatalf("failed to read AUTH response: %v", err)
} else if msg != "OK" {
t.Fatalf("unexpected AUTH response: %q", msg)
}
if !redisqueue.Enabled() {
t.Fatalf("expected redisqueue to be enabled")
}
redisqueue.Enqueue([]byte("a"))
redisqueue.Enqueue([]byte("b"))
redisqueue.Enqueue([]byte("c"))
if errWrite := writeTestRESPCommand(conn, "RPOP", "queue"); errWrite != nil {
t.Fatalf("failed to write RPOP command: %v", errWrite)
}
if item, err := readTestRESPBulkString(reader); err != nil {
t.Fatalf("failed to read RPOP response: %v", err)
} else if string(item) != "a" {
t.Fatalf("unexpected RPOP item: %q", string(item))
}
if errWrite := writeTestRESPCommand(conn, "LPOP", "queue"); errWrite != nil {
t.Fatalf("failed to write LPOP command: %v", errWrite)
}
if item, err := readTestRESPBulkString(reader); err != nil {
t.Fatalf("failed to read LPOP response: %v", err)
} else if string(item) != "b" {
t.Fatalf("unexpected LPOP item: %q", string(item))
}
if errWrite := writeTestRESPCommand(conn, "RPOP", "queue", "10"); errWrite != nil {
t.Fatalf("failed to write RPOP count command: %v", errWrite)
}
items, errItems := readRESPArrayOfBulkStrings(reader)
if errItems != nil {
t.Fatalf("failed to read RPOP count response: %v", errItems)
}
if len(items) != 1 || string(items[0]) != "c" {
t.Fatalf("unexpected RPOP count items: %#v", items)
}
if errWrite := writeTestRESPCommand(conn, "LPOP", "queue"); errWrite != nil {
t.Fatalf("failed to write LPOP empty command: %v", errWrite)
}
item, errItem := readTestRESPBulkString(reader)
if errItem != nil {
t.Fatalf("failed to read LPOP empty response: %v", errItem)
}
if item != nil {
t.Fatalf("expected nil bulk string for empty queue, got %q", string(item))
}
if errWrite := writeTestRESPCommand(conn, "RPOP", "queue", "2"); errWrite != nil {
t.Fatalf("failed to write RPOP empty count command: %v", errWrite)
}
emptyItems, errEmpty := readRESPArrayOfBulkStrings(reader)
if errEmpty != nil {
t.Fatalf("failed to read RPOP empty count response: %v", errEmpty)
}
if len(emptyItems) != 0 {
t.Fatalf("expected empty array for empty queue with count, got %#v", emptyItems)
}
}

View File

@@ -7,8 +7,10 @@ package api
import (
"context"
"crypto/subtle"
"crypto/tls"
"errors"
"fmt"
"net"
"net/http"
"os"
"path/filepath"
@@ -28,6 +30,7 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
"github.com/router-for-me/CLIProxyAPI/v6/internal/managementasset"
"github.com/router-for-me/CLIProxyAPI/v6/internal/redisqueue"
"github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
@@ -38,6 +41,7 @@ import (
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
log "github.com/sirupsen/logrus"
"golang.org/x/net/http2"
"gopkg.in/yaml.v3"
)
@@ -127,6 +131,12 @@ type Server struct {
// server is the underlying HTTP server.
server *http.Server
// muxBaseListener is the shared TCP listener used to serve both HTTP and Redis protocol traffic.
muxBaseListener net.Listener
// muxHTTPListener receives HTTP connections selected by the multiplexer.
muxHTTPListener *muxListener
// handlers contains the API handlers for processing requests.
handlers *handlers.BaseAPIHandler
@@ -299,6 +309,7 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
// or when a local management password is provided (e.g. TUI mode).
hasManagementSecret := cfg.RemoteManagement.SecretKey != "" || envManagementSecret || s.localPassword != ""
s.managementRoutesEnabled.Store(hasManagementSecret)
redisqueue.SetEnabled(hasManagementSecret)
if hasManagementSecret {
s.registerManagementRoutes()
}
@@ -797,26 +808,98 @@ func (s *Server) Start() error {
return fmt.Errorf("failed to start HTTP server: server not initialized")
}
addr := s.server.Addr
listener, errListen := net.Listen("tcp", addr)
if errListen != nil {
return fmt.Errorf("failed to start HTTP server: %v", errListen)
}
useTLS := s.cfg != nil && s.cfg.TLS.Enable
if useTLS {
cert := strings.TrimSpace(s.cfg.TLS.Cert)
key := strings.TrimSpace(s.cfg.TLS.Key)
if cert == "" || key == "" {
certPath := strings.TrimSpace(s.cfg.TLS.Cert)
keyPath := strings.TrimSpace(s.cfg.TLS.Key)
if certPath == "" || keyPath == "" {
if errClose := listener.Close(); errClose != nil {
log.Errorf("failed to close listener after TLS validation failure: %v", errClose)
}
return fmt.Errorf("failed to start HTTPS server: tls.cert or tls.key is empty")
}
log.Debugf("Starting API server on %s with TLS", s.server.Addr)
if errServeTLS := s.server.ListenAndServeTLS(cert, key); errServeTLS != nil && !errors.Is(errServeTLS, http.ErrServerClosed) {
return fmt.Errorf("failed to start HTTPS server: %v", errServeTLS)
certPair, errLoad := tls.LoadX509KeyPair(certPath, keyPath)
if errLoad != nil {
if errClose := listener.Close(); errClose != nil {
log.Errorf("failed to close listener after TLS key pair load failure: %v", errClose)
}
return fmt.Errorf("failed to start HTTPS server: %v", errLoad)
}
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{certPair},
NextProtos: []string{"h2", "http/1.1"},
}
s.server.TLSConfig = tlsConfig
if errHTTP2 := http2.ConfigureServer(s.server, &http2.Server{}); errHTTP2 != nil {
log.Warnf("failed to configure HTTP/2: %v", errHTTP2)
}
listener = tls.NewListener(listener, tlsConfig)
log.Debugf("Starting API server on %s with TLS", addr)
} else {
log.Debugf("Starting API server on %s", addr)
}
httpListener := newMuxListener(listener.Addr(), 1024)
s.muxBaseListener = listener
s.muxHTTPListener = httpListener
httpErrCh := make(chan error, 1)
acceptErrCh := make(chan error, 1)
go func() {
httpErrCh <- s.server.Serve(httpListener)
}()
go func() {
acceptErrCh <- s.acceptMuxConnections(listener, httpListener)
}()
select {
case errServe := <-httpErrCh:
if s.muxBaseListener != nil {
if errClose := s.muxBaseListener.Close(); errClose != nil && !errors.Is(errClose, net.ErrClosed) {
log.Debugf("failed to close shared listener after HTTP serve exit: %v", errClose)
}
}
if s.muxHTTPListener != nil {
_ = s.muxHTTPListener.Close()
}
errAccept := <-acceptErrCh
errServe = normalizeHTTPServeError(errServe)
errAccept = normalizeListenerError(errAccept)
if errServe != nil {
return fmt.Errorf("failed to start HTTP server: %v", errServe)
}
if errAccept != nil {
return fmt.Errorf("failed to start HTTP server: %v", errAccept)
}
return nil
case errAccept := <-acceptErrCh:
if s.muxHTTPListener != nil {
_ = s.muxHTTPListener.Close()
}
if s.muxBaseListener != nil {
if errClose := s.muxBaseListener.Close(); errClose != nil && !errors.Is(errClose, net.ErrClosed) {
log.Debugf("failed to close shared listener after accept loop exit: %v", errClose)
}
}
errServe := <-httpErrCh
errServe = normalizeHTTPServeError(errServe)
errAccept = normalizeListenerError(errAccept)
if errAccept != nil {
return fmt.Errorf("failed to start HTTP server: %v", errAccept)
}
if errServe != nil {
return fmt.Errorf("failed to start HTTP server: %v", errServe)
}
return nil
}
log.Debugf("Starting API server on %s", s.server.Addr)
if errServe := s.server.ListenAndServe(); errServe != nil && !errors.Is(errServe, http.ErrServerClosed) {
return fmt.Errorf("failed to start HTTP server: %v", errServe)
}
return nil
}
// Stop gracefully shuts down the API server without interrupting any
@@ -837,6 +920,15 @@ func (s *Server) Stop(ctx context.Context) error {
}
}
if s.muxHTTPListener != nil {
_ = s.muxHTTPListener.Close()
}
if s.muxBaseListener != nil {
if errClose := s.muxBaseListener.Close(); errClose != nil && !errors.Is(errClose, net.ErrClosed) {
log.Debugf("failed to close shared listener: %v", errClose)
}
}
// Shutdown the HTTP server.
if err := s.server.Shutdown(ctx); err != nil {
return fmt.Errorf("failed to shutdown HTTP server: %v", err)
@@ -963,6 +1055,7 @@ func (s *Server) UpdateClients(cfg *config.Config) {
s.managementRoutesEnabled.Store(!newSecretEmpty)
}
}
redisqueue.SetEnabled(s.managementRoutesEnabled.Load())
s.applyAccessConfig(oldCfg, cfg)
s.cfg = cfg

View File

@@ -0,0 +1,145 @@
package redisqueue
import (
"context"
"encoding/json"
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
internallogging "github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
internalusage "github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
coreusage "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage"
)
func init() {
coreusage.RegisterPlugin(&usageQueuePlugin{})
}
type usageQueuePlugin struct{}
func (p *usageQueuePlugin) HandleUsage(ctx context.Context, record coreusage.Record) {
if p == nil {
return
}
if !Enabled() || !internalusage.StatisticsEnabled() {
return
}
timestamp := record.RequestedAt
if timestamp.IsZero() {
timestamp = time.Now()
}
modelName := strings.TrimSpace(record.Model)
if modelName == "" {
modelName = "unknown"
}
provider := strings.TrimSpace(record.Provider)
if provider == "" {
provider = "unknown"
}
authType := strings.TrimSpace(record.AuthType)
if authType == "" {
authType = "unknown"
}
apiKey := strings.TrimSpace(record.APIKey)
requestID := strings.TrimSpace(internallogging.GetRequestID(ctx))
if requestID == "" {
if ginCtx, ok := ctx.Value("gin").(*gin.Context); ok && ginCtx != nil {
requestID = strings.TrimSpace(internallogging.GetGinRequestID(ginCtx))
}
}
tokens := internalusage.TokenStats{
InputTokens: record.Detail.InputTokens,
OutputTokens: record.Detail.OutputTokens,
ReasoningTokens: record.Detail.ReasoningTokens,
CachedTokens: record.Detail.CachedTokens,
TotalTokens: record.Detail.TotalTokens,
}
if tokens.TotalTokens == 0 {
tokens.TotalTokens = tokens.InputTokens + tokens.OutputTokens + tokens.ReasoningTokens
}
if tokens.TotalTokens == 0 {
tokens.TotalTokens = tokens.InputTokens + tokens.OutputTokens + tokens.ReasoningTokens + tokens.CachedTokens
}
failed := record.Failed
if !failed {
failed = !resolveSuccess(ctx)
}
detail := internalusage.RequestDetail{
Timestamp: timestamp,
LatencyMs: record.Latency.Milliseconds(),
Source: record.Source,
AuthIndex: record.AuthIndex,
Tokens: tokens,
Failed: failed,
}
payload, err := json.Marshal(queuedUsageDetail{
RequestDetail: detail,
Provider: provider,
Model: modelName,
Endpoint: resolveEndpoint(ctx),
AuthType: authType,
APIKey: apiKey,
RequestID: requestID,
})
if err != nil {
return
}
Enqueue(payload)
}
type queuedUsageDetail struct {
internalusage.RequestDetail
Provider string `json:"provider"`
Model string `json:"model"`
Endpoint string `json:"endpoint"`
AuthType string `json:"auth_type"`
APIKey string `json:"api_key"`
RequestID string `json:"request_id"`
}
func resolveSuccess(ctx context.Context) bool {
if ctx == nil {
return true
}
ginCtx, ok := ctx.Value("gin").(*gin.Context)
if !ok || ginCtx == nil {
return true
}
status := ginCtx.Writer.Status()
if status == 0 {
return true
}
return status < http.StatusBadRequest
}
func resolveEndpoint(ctx context.Context) string {
if ctx == nil {
return ""
}
ginCtx, ok := ctx.Value("gin").(*gin.Context)
if !ok || ginCtx == nil || ginCtx.Request == nil {
return ""
}
path := strings.TrimSpace(ginCtx.FullPath())
if path == "" && ginCtx.Request.URL != nil {
path = strings.TrimSpace(ginCtx.Request.URL.Path)
}
if path == "" {
return ""
}
method := strings.TrimSpace(ginCtx.Request.Method)
if method == "" {
return path
}
return method + " " + path
}

View File

@@ -0,0 +1,160 @@
package redisqueue
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/gin-gonic/gin"
internallogging "github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
internalusage "github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
coreusage "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage"
)
func TestUsageQueuePluginPayloadIncludesStableFieldsAndSuccess(t *testing.T) {
withEnabledQueue(t, func() {
ginCtx := newTestGinContext(t, http.MethodPost, "/v1/chat/completions", http.StatusOK)
internallogging.SetGinRequestID(ginCtx, "gin-request-id-ignored")
ctx := context.WithValue(internallogging.WithRequestID(context.Background(), "ctx-request-id"), "gin", ginCtx)
plugin := &usageQueuePlugin{}
plugin.HandleUsage(ctx, coreusage.Record{
Provider: "openai",
Model: "gpt-5.4",
APIKey: "test-key",
AuthIndex: "0",
AuthType: "apikey",
Source: "user@example.com",
RequestedAt: time.Date(2026, 4, 25, 0, 0, 0, 0, time.UTC),
Latency: 1500 * time.Millisecond,
Detail: coreusage.Detail{
InputTokens: 10,
OutputTokens: 20,
TotalTokens: 30,
},
})
payload := popSinglePayload(t)
requireStringField(t, payload, "provider", "openai")
requireStringField(t, payload, "model", "gpt-5.4")
requireStringField(t, payload, "endpoint", "POST /v1/chat/completions")
requireStringField(t, payload, "auth_type", "apikey")
requireStringField(t, payload, "request_id", "ctx-request-id")
requireBoolField(t, payload, "failed", false)
})
}
func TestUsageQueuePluginPayloadIncludesStableFieldsAndFailureAndGinRequestID(t *testing.T) {
withEnabledQueue(t, func() {
ginCtx := newTestGinContext(t, http.MethodGet, "/v1/responses", http.StatusInternalServerError)
internallogging.SetGinRequestID(ginCtx, "gin-request-id")
ctx := context.WithValue(context.Background(), "gin", ginCtx)
plugin := &usageQueuePlugin{}
plugin.HandleUsage(ctx, coreusage.Record{
Provider: "openai",
Model: "gpt-5.4-mini",
APIKey: "test-key",
AuthIndex: "0",
AuthType: "apikey",
Source: "user@example.com",
RequestedAt: time.Date(2026, 4, 25, 0, 0, 0, 0, time.UTC),
Latency: 2500 * time.Millisecond,
Detail: coreusage.Detail{
InputTokens: 10,
OutputTokens: 20,
TotalTokens: 30,
},
})
payload := popSinglePayload(t)
requireStringField(t, payload, "provider", "openai")
requireStringField(t, payload, "model", "gpt-5.4-mini")
requireStringField(t, payload, "endpoint", "GET /v1/responses")
requireStringField(t, payload, "auth_type", "apikey")
requireStringField(t, payload, "request_id", "gin-request-id")
requireBoolField(t, payload, "failed", true)
})
}
func withEnabledQueue(t *testing.T, fn func()) {
t.Helper()
prevQueueEnabled := Enabled()
prevStatsEnabled := internalusage.StatisticsEnabled()
SetEnabled(false)
SetEnabled(true)
internalusage.SetStatisticsEnabled(true)
defer func() {
SetEnabled(false)
SetEnabled(prevQueueEnabled)
internalusage.SetStatisticsEnabled(prevStatsEnabled)
}()
fn()
}
func newTestGinContext(t *testing.T, method, path string, status int) *gin.Context {
t.Helper()
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(recorder)
ginCtx.Request = httptest.NewRequest(method, "http://example.com"+path, nil)
if status != 0 {
ginCtx.Status(status)
}
return ginCtx
}
func popSinglePayload(t *testing.T) map[string]json.RawMessage {
t.Helper()
items := PopOldest(10)
if len(items) != 1 {
t.Fatalf("PopOldest() items = %d, want 1", len(items))
}
var payload map[string]json.RawMessage
if err := json.Unmarshal(items[0], &payload); err != nil {
t.Fatalf("unmarshal payload: %v", err)
}
return payload
}
func requireStringField(t *testing.T, payload map[string]json.RawMessage, key, want string) {
t.Helper()
raw, ok := payload[key]
if !ok {
t.Fatalf("payload missing %q", key)
}
var got string
if err := json.Unmarshal(raw, &got); err != nil {
t.Fatalf("unmarshal %q: %v", key, err)
}
if got != want {
t.Fatalf("%s = %q, want %q", key, got, want)
}
}
func requireBoolField(t *testing.T, payload map[string]json.RawMessage, key string, want bool) {
t.Helper()
raw, ok := payload[key]
if !ok {
t.Fatalf("payload missing %q", key)
}
var got bool
if err := json.Unmarshal(raw, &got); err != nil {
t.Fatalf("unmarshal %q: %v", key, err)
}
if got != want {
t.Fatalf("%s = %t, want %t", key, got, want)
}
}

View File

@@ -0,0 +1,133 @@
package redisqueue
import (
"sync"
"sync/atomic"
"time"
)
const retentionWindow = time.Minute
type queueItem struct {
enqueuedAt time.Time
payload []byte
}
type queue struct {
mu sync.Mutex
items []queueItem
head int
}
var (
enabled atomic.Bool
global queue
)
func SetEnabled(value bool) {
enabled.Store(value)
if !value {
global.clear()
}
}
func Enabled() bool {
return enabled.Load()
}
func Enqueue(payload []byte) {
if !Enabled() {
return
}
if len(payload) == 0 {
return
}
global.enqueue(payload)
}
func PopOldest(count int) [][]byte {
if !Enabled() {
return nil
}
if count <= 0 {
return nil
}
return global.popOldest(count)
}
func (q *queue) clear() {
q.mu.Lock()
defer q.mu.Unlock()
q.items = nil
q.head = 0
}
func (q *queue) enqueue(payload []byte) {
now := time.Now()
q.mu.Lock()
defer q.mu.Unlock()
q.pruneLocked(now)
q.items = append(q.items, queueItem{
enqueuedAt: now,
payload: append([]byte(nil), payload...),
})
q.maybeCompactLocked()
}
func (q *queue) popOldest(count int) [][]byte {
now := time.Now()
q.mu.Lock()
defer q.mu.Unlock()
q.pruneLocked(now)
available := len(q.items) - q.head
if available <= 0 {
q.items = nil
q.head = 0
return nil
}
if count > available {
count = available
}
out := make([][]byte, 0, count)
for i := 0; i < count; i++ {
item := q.items[q.head+i]
out = append(out, item.payload)
}
q.head += count
q.maybeCompactLocked()
return out
}
func (q *queue) pruneLocked(now time.Time) {
if q.head >= len(q.items) {
q.items = nil
q.head = 0
return
}
cutoff := now.Add(-retentionWindow)
for q.head < len(q.items) && q.items[q.head].enqueuedAt.Before(cutoff) {
q.head++
}
}
func (q *queue) maybeCompactLocked() {
if q.head == 0 {
return
}
if q.head >= len(q.items) {
q.items = nil
q.head = 0
return
}
if q.head < 1024 && q.head*2 < len(q.items) {
return
}
q.items = append([]queueItem(nil), q.items[q.head:]...)
q.head = 0
}

View File

@@ -20,6 +20,7 @@ type UsageReporter struct {
model string
authID string
authIndex string
authType string
apiKey string
source string
requestedAt time.Time
@@ -34,6 +35,7 @@ func NewUsageReporter(ctx context.Context, provider, model string, auth *cliprox
requestedAt: time.Now(),
apiKey: apiKey,
source: resolveUsageSource(auth, apiKey),
authType: resolveUsageAuthType(auth),
}
if auth != nil {
reporter.authID = auth.ID
@@ -98,6 +100,7 @@ func (r *UsageReporter) buildRecord(detail usage.Detail, failed bool) usage.Reco
APIKey: r.apiKey,
AuthID: r.authID,
AuthIndex: r.authIndex,
AuthType: r.authType,
RequestedAt: r.requestedAt,
Latency: r.latency(),
Failed: failed,
@@ -181,6 +184,18 @@ func resolveUsageSource(auth *cliproxyauth.Auth, ctxAPIKey string) string {
return ""
}
func resolveUsageAuthType(auth *cliproxyauth.Auth) string {
if auth == nil {
return ""
}
kind, _ := auth.AccountInfo()
kind = strings.TrimSpace(kind)
if kind == "api_key" {
return "apikey"
}
return kind
}
func ParseCodexUsage(data []byte) (usage.Detail, bool) {
usageNode := gjson.ParseBytes(data).Get("response.usage")
if !usageNode.Exists() {

View File

@@ -13,6 +13,7 @@ import (
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/api"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/redisqueue"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/usage"

View File

@@ -15,6 +15,7 @@ type Record struct {
APIKey string
AuthID string
AuthIndex string
AuthType string
Source string
RequestedAt time.Time
Latency time.Duration