From a583463d6048731c6b87eee94f458b1e01bfe3b3 Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Sun, 12 Apr 2026 02:06:40 +0800 Subject: [PATCH] feat(auth): implement auto-refresh loop for managing auth token schedule - Introduced `authAutoRefreshLoop` to handle token refresh scheduling. - Replaced semaphore-based refresh logic in `Manager` with the new loop. - Added unit tests to verify refresh schedule logic and edge cases. --- sdk/cliproxy/auth/auto_refresh_loop.go | 444 ++++++++++++++++++++ sdk/cliproxy/auth/auto_refresh_loop_test.go | 137 ++++++ sdk/cliproxy/auth/conductor.go | 119 +++--- 3 files changed, 636 insertions(+), 64 deletions(-) create mode 100644 sdk/cliproxy/auth/auto_refresh_loop.go create mode 100644 sdk/cliproxy/auth/auto_refresh_loop_test.go diff --git a/sdk/cliproxy/auth/auto_refresh_loop.go b/sdk/cliproxy/auth/auto_refresh_loop.go new file mode 100644 index 000000000..3350b603e --- /dev/null +++ b/sdk/cliproxy/auth/auto_refresh_loop.go @@ -0,0 +1,444 @@ +package auth + +import ( + "container/heap" + "context" + "strings" + "sync" + "time" + + log "github.com/sirupsen/logrus" +) + +type authAutoRefreshLoop struct { + manager *Manager + interval time.Duration + + mu sync.Mutex + queue refreshMinHeap + index map[string]*refreshHeapItem + dirty map[string]struct{} + + wakeCh chan struct{} + jobs chan string +} + +func newAuthAutoRefreshLoop(manager *Manager, interval time.Duration) *authAutoRefreshLoop { + if interval <= 0 { + interval = refreshCheckInterval + } + jobBuffer := refreshMaxConcurrency * 4 + if jobBuffer < 64 { + jobBuffer = 64 + } + return &authAutoRefreshLoop{ + manager: manager, + interval: interval, + index: make(map[string]*refreshHeapItem), + dirty: make(map[string]struct{}), + wakeCh: make(chan struct{}, 1), + jobs: make(chan string, jobBuffer), + } +} + +func (l *authAutoRefreshLoop) queueReschedule(authID string) { + if l == nil || authID == "" { + return + } + l.mu.Lock() + l.dirty[authID] = struct{}{} + l.mu.Unlock() + select { + case l.wakeCh <- struct{}{}: + default: + } +} + +func (l *authAutoRefreshLoop) run(ctx context.Context) { + if l == nil || l.manager == nil { + return + } + + for i := 0; i < refreshMaxConcurrency; i++ { + go l.worker(ctx) + } + + l.loop(ctx) +} + +func (l *authAutoRefreshLoop) worker(ctx context.Context) { + for { + select { + case <-ctx.Done(): + return + case authID := <-l.jobs: + if authID == "" { + continue + } + l.manager.refreshAuth(ctx, authID) + l.queueReschedule(authID) + } + } +} + +func (l *authAutoRefreshLoop) rebuild(now time.Time) { + type entry struct { + id string + next time.Time + } + + entries := make([]entry, 0) + + l.manager.mu.RLock() + for id, auth := range l.manager.auths { + next, ok := nextRefreshCheckAt(now, auth, l.interval) + if !ok { + continue + } + entries = append(entries, entry{id: id, next: next}) + } + l.manager.mu.RUnlock() + + l.mu.Lock() + l.queue = l.queue[:0] + l.index = make(map[string]*refreshHeapItem, len(entries)) + for _, e := range entries { + item := &refreshHeapItem{id: e.id, next: e.next} + heap.Push(&l.queue, item) + l.index[e.id] = item + } + l.mu.Unlock() +} + +func (l *authAutoRefreshLoop) loop(ctx context.Context) { + timer := time.NewTimer(time.Hour) + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } + defer timer.Stop() + + var timerCh <-chan time.Time + l.resetTimer(timer, &timerCh, time.Now()) + + for { + select { + case <-ctx.Done(): + return + case <-l.wakeCh: + now := time.Now() + l.applyDirty(now) + l.resetTimer(timer, &timerCh, now) + case <-timerCh: + now := time.Now() + l.handleDue(ctx, now) + l.applyDirty(now) + l.resetTimer(timer, &timerCh, now) + } + } +} + +func (l *authAutoRefreshLoop) resetTimer(timer *time.Timer, timerCh *<-chan time.Time, now time.Time) { + next, ok := l.peek() + if !ok { + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } + *timerCh = nil + return + } + + wait := next.Sub(now) + if wait < 0 { + wait = 0 + } + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } + timer.Reset(wait) + *timerCh = timer.C +} + +func (l *authAutoRefreshLoop) peek() (time.Time, bool) { + l.mu.Lock() + defer l.mu.Unlock() + if len(l.queue) == 0 { + return time.Time{}, false + } + return l.queue[0].next, true +} + +func (l *authAutoRefreshLoop) handleDue(ctx context.Context, now time.Time) { + due := l.popDue(now) + if len(due) == 0 { + return + } + if log.IsLevelEnabled(log.DebugLevel) { + log.Debugf("auto-refresh scheduler due auths: %d", len(due)) + } + for _, authID := range due { + l.handleDueAuth(ctx, now, authID) + } +} + +func (l *authAutoRefreshLoop) popDue(now time.Time) []string { + l.mu.Lock() + defer l.mu.Unlock() + + var due []string + for len(l.queue) > 0 { + item := l.queue[0] + if item == nil || item.next.After(now) { + break + } + popped := heap.Pop(&l.queue).(*refreshHeapItem) + if popped == nil { + continue + } + delete(l.index, popped.id) + due = append(due, popped.id) + } + return due +} + +func (l *authAutoRefreshLoop) handleDueAuth(ctx context.Context, now time.Time, authID string) { + if authID == "" { + return + } + + manager := l.manager + + manager.mu.RLock() + auth := manager.auths[authID] + if auth == nil { + manager.mu.RUnlock() + return + } + next, shouldSchedule := nextRefreshCheckAt(now, auth, l.interval) + shouldRefresh := manager.shouldRefresh(auth, now) + exec := manager.executors[auth.Provider] + manager.mu.RUnlock() + + if !shouldSchedule { + l.remove(authID) + return + } + + if !shouldRefresh { + l.upsert(authID, next) + return + } + + if exec == nil { + l.upsert(authID, now.Add(l.interval)) + return + } + + if !manager.markRefreshPending(authID, now) { + manager.mu.RLock() + auth = manager.auths[authID] + next, shouldSchedule = nextRefreshCheckAt(now, auth, l.interval) + manager.mu.RUnlock() + if shouldSchedule { + l.upsert(authID, next) + } else { + l.remove(authID) + } + return + } + + select { + case <-ctx.Done(): + return + case l.jobs <- authID: + } +} + +func (l *authAutoRefreshLoop) applyDirty(now time.Time) { + dirty := l.drainDirty() + if len(dirty) == 0 { + return + } + + for _, authID := range dirty { + l.manager.mu.RLock() + auth := l.manager.auths[authID] + next, ok := nextRefreshCheckAt(now, auth, l.interval) + l.manager.mu.RUnlock() + + if !ok { + l.remove(authID) + continue + } + l.upsert(authID, next) + } +} + +func (l *authAutoRefreshLoop) drainDirty() []string { + l.mu.Lock() + defer l.mu.Unlock() + if len(l.dirty) == 0 { + return nil + } + out := make([]string, 0, len(l.dirty)) + for authID := range l.dirty { + out = append(out, authID) + delete(l.dirty, authID) + } + return out +} + +func (l *authAutoRefreshLoop) upsert(authID string, next time.Time) { + if authID == "" || next.IsZero() { + return + } + l.mu.Lock() + defer l.mu.Unlock() + if item, ok := l.index[authID]; ok && item != nil { + item.next = next + heap.Fix(&l.queue, item.index) + return + } + item := &refreshHeapItem{id: authID, next: next} + heap.Push(&l.queue, item) + l.index[authID] = item +} + +func (l *authAutoRefreshLoop) remove(authID string) { + if authID == "" { + return + } + l.mu.Lock() + defer l.mu.Unlock() + item, ok := l.index[authID] + if !ok || item == nil { + return + } + heap.Remove(&l.queue, item.index) + delete(l.index, authID) +} + +func nextRefreshCheckAt(now time.Time, auth *Auth, interval time.Duration) (time.Time, bool) { + if auth == nil || auth.Disabled { + return time.Time{}, false + } + + accountType, _ := auth.AccountInfo() + if accountType == "api_key" { + return time.Time{}, false + } + + if !auth.NextRefreshAfter.IsZero() && now.Before(auth.NextRefreshAfter) { + return auth.NextRefreshAfter, true + } + + if evaluator, ok := auth.Runtime.(RefreshEvaluator); ok && evaluator != nil { + if interval <= 0 { + interval = refreshCheckInterval + } + return now.Add(interval), true + } + + lastRefresh := auth.LastRefreshedAt + if lastRefresh.IsZero() { + if ts, ok := authLastRefreshTimestamp(auth); ok { + lastRefresh = ts + } + } + + expiry, hasExpiry := auth.ExpirationTime() + + if pref := authPreferredInterval(auth); pref > 0 { + candidates := make([]time.Time, 0, 2) + if hasExpiry && !expiry.IsZero() { + if !expiry.After(now) || expiry.Sub(now) <= pref { + return now, true + } + candidates = append(candidates, expiry.Add(-pref)) + } + if lastRefresh.IsZero() { + return now, true + } + candidates = append(candidates, lastRefresh.Add(pref)) + next := candidates[0] + for _, candidate := range candidates[1:] { + if candidate.Before(next) { + next = candidate + } + } + if !next.After(now) { + return now, true + } + return next, true + } + + provider := strings.ToLower(auth.Provider) + lead := ProviderRefreshLead(provider, auth.Runtime) + if lead == nil { + return time.Time{}, false + } + if hasExpiry && !expiry.IsZero() { + dueAt := expiry.Add(-*lead) + if !dueAt.After(now) { + return now, true + } + return dueAt, true + } + if !lastRefresh.IsZero() { + dueAt := lastRefresh.Add(*lead) + if !dueAt.After(now) { + return now, true + } + return dueAt, true + } + return now, true +} + +type refreshHeapItem struct { + id string + next time.Time + index int +} + +type refreshMinHeap []*refreshHeapItem + +func (h refreshMinHeap) Len() int { return len(h) } + +func (h refreshMinHeap) Less(i, j int) bool { + return h[i].next.Before(h[j].next) +} + +func (h refreshMinHeap) Swap(i, j int) { + h[i], h[j] = h[j], h[i] + h[i].index = i + h[j].index = j +} + +func (h *refreshMinHeap) Push(x any) { + item, ok := x.(*refreshHeapItem) + if !ok || item == nil { + return + } + item.index = len(*h) + *h = append(*h, item) +} + +func (h *refreshMinHeap) Pop() any { + old := *h + n := len(old) + if n == 0 { + return (*refreshHeapItem)(nil) + } + item := old[n-1] + item.index = -1 + *h = old[:n-1] + return item +} diff --git a/sdk/cliproxy/auth/auto_refresh_loop_test.go b/sdk/cliproxy/auth/auto_refresh_loop_test.go new file mode 100644 index 000000000..420aae237 --- /dev/null +++ b/sdk/cliproxy/auth/auto_refresh_loop_test.go @@ -0,0 +1,137 @@ +package auth + +import ( + "strings" + "testing" + "time" +) + +type testRefreshEvaluator struct{} + +func (testRefreshEvaluator) ShouldRefresh(time.Time, *Auth) bool { return false } + +func setRefreshLeadFactory(t *testing.T, provider string, factory func() *time.Duration) { + t.Helper() + key := strings.ToLower(strings.TrimSpace(provider)) + refreshLeadMu.Lock() + prev, hadPrev := refreshLeadFactories[key] + if factory == nil { + delete(refreshLeadFactories, key) + } else { + refreshLeadFactories[key] = factory + } + refreshLeadMu.Unlock() + t.Cleanup(func() { + refreshLeadMu.Lock() + if hadPrev { + refreshLeadFactories[key] = prev + } else { + delete(refreshLeadFactories, key) + } + refreshLeadMu.Unlock() + }) +} + +func TestNextRefreshCheckAt_DisabledUnschedule(t *testing.T) { + now := time.Date(2026, 4, 12, 0, 0, 0, 0, time.UTC) + auth := &Auth{ID: "a1", Provider: "test", Disabled: true} + if _, ok := nextRefreshCheckAt(now, auth, 15*time.Minute); ok { + t.Fatalf("nextRefreshCheckAt() ok = true, want false") + } +} + +func TestNextRefreshCheckAt_APIKeyUnschedule(t *testing.T) { + now := time.Date(2026, 4, 12, 0, 0, 0, 0, time.UTC) + auth := &Auth{ID: "a1", Provider: "test", Attributes: map[string]string{"api_key": "k"}} + if _, ok := nextRefreshCheckAt(now, auth, 15*time.Minute); ok { + t.Fatalf("nextRefreshCheckAt() ok = true, want false") + } +} + +func TestNextRefreshCheckAt_NextRefreshAfterGate(t *testing.T) { + now := time.Date(2026, 4, 12, 0, 0, 0, 0, time.UTC) + nextAfter := now.Add(30 * time.Minute) + auth := &Auth{ + ID: "a1", + Provider: "test", + NextRefreshAfter: nextAfter, + Metadata: map[string]any{"email": "x@example.com"}, + } + got, ok := nextRefreshCheckAt(now, auth, 15*time.Minute) + if !ok { + t.Fatalf("nextRefreshCheckAt() ok = false, want true") + } + if !got.Equal(nextAfter) { + t.Fatalf("nextRefreshCheckAt() = %s, want %s", got, nextAfter) + } +} + +func TestNextRefreshCheckAt_PreferredInterval_PicksEarliestCandidate(t *testing.T) { + now := time.Date(2026, 4, 12, 0, 0, 0, 0, time.UTC) + expiry := now.Add(20 * time.Minute) + auth := &Auth{ + ID: "a1", + Provider: "test", + LastRefreshedAt: now, + Metadata: map[string]any{ + "email": "x@example.com", + "expires_at": expiry.Format(time.RFC3339), + "refresh_interval_seconds": 900, // 15m + }, + } + got, ok := nextRefreshCheckAt(now, auth, 15*time.Minute) + if !ok { + t.Fatalf("nextRefreshCheckAt() ok = false, want true") + } + want := expiry.Add(-15 * time.Minute) + if !got.Equal(want) { + t.Fatalf("nextRefreshCheckAt() = %s, want %s", got, want) + } +} + +func TestNextRefreshCheckAt_ProviderLead_Expiry(t *testing.T) { + now := time.Date(2026, 4, 12, 0, 0, 0, 0, time.UTC) + expiry := now.Add(time.Hour) + lead := 10 * time.Minute + setRefreshLeadFactory(t, "provider-lead-expiry", func() *time.Duration { + d := lead + return &d + }) + + auth := &Auth{ + ID: "a1", + Provider: "provider-lead-expiry", + Metadata: map[string]any{ + "email": "x@example.com", + "expires_at": expiry.Format(time.RFC3339), + }, + } + + got, ok := nextRefreshCheckAt(now, auth, 15*time.Minute) + if !ok { + t.Fatalf("nextRefreshCheckAt() ok = false, want true") + } + want := expiry.Add(-lead) + if !got.Equal(want) { + t.Fatalf("nextRefreshCheckAt() = %s, want %s", got, want) + } +} + +func TestNextRefreshCheckAt_RefreshEvaluatorFallback(t *testing.T) { + now := time.Date(2026, 4, 12, 0, 0, 0, 0, time.UTC) + interval := 15 * time.Minute + auth := &Auth{ + ID: "a1", + Provider: "test", + Metadata: map[string]any{"email": "x@example.com"}, + Runtime: testRefreshEvaluator{}, + } + got, ok := nextRefreshCheckAt(now, auth, interval) + if !ok { + t.Fatalf("nextRefreshCheckAt() ok = false, want true") + } + want := now.Add(interval) + if !got.Equal(want) { + t.Fatalf("nextRefreshCheckAt() = %s, want %s", got, want) + } +} diff --git a/sdk/cliproxy/auth/conductor.go b/sdk/cliproxy/auth/conductor.go index 25cc7221a..fc25ca2b3 100644 --- a/sdk/cliproxy/auth/conductor.go +++ b/sdk/cliproxy/auth/conductor.go @@ -162,8 +162,8 @@ type Manager struct { rtProvider RoundTripperProvider // Auto refresh state - refreshCancel context.CancelFunc - refreshSemaphore chan struct{} + refreshCancel context.CancelFunc + refreshLoop *authAutoRefreshLoop } // NewManager constructs a manager with optional custom selector and hook. @@ -182,7 +182,6 @@ func NewManager(store Store, selector Selector, hook Hook) *Manager { auths: make(map[string]*Auth), providerOffsets: make(map[string]int), modelPoolOffsets: make(map[string]int), - refreshSemaphore: make(chan struct{}, refreshMaxConcurrency), } // atomic.Value requires non-nil initial value. manager.runtimeConfig.Store(&internalconfig.Config{}) @@ -214,6 +213,16 @@ func (m *Manager) syncScheduler() { m.syncSchedulerFromSnapshot(m.snapshotAuths()) } +func (m *Manager) snapshotAuths() []*Auth { + m.mu.RLock() + defer m.mu.RUnlock() + out := make([]*Auth, 0, len(m.auths)) + for _, a := range m.auths { + out = append(out, a.Clone()) + } + return out +} + // RefreshSchedulerEntry re-upserts a single auth into the scheduler so that its // supportedModelSet is rebuilt from the current global model registry state. // This must be called after models have been registered for a newly added auth, @@ -1088,6 +1097,7 @@ func (m *Manager) Register(ctx context.Context, auth *Auth) (*Auth, error) { if m.scheduler != nil { m.scheduler.upsertAuth(authClone) } + m.queueRefreshReschedule(auth.ID) _ = m.persist(ctx, auth) m.hook.OnAuthRegistered(ctx, auth.Clone()) return auth.Clone(), nil @@ -1118,6 +1128,7 @@ func (m *Manager) Update(ctx context.Context, auth *Auth) (*Auth, error) { if m.scheduler != nil { m.scheduler.upsertAuth(authClone) } + m.queueRefreshReschedule(auth.ID) _ = m.persist(ctx, auth) m.hook.OnAuthUpdated(ctx, auth.Clone()) return auth.Clone(), nil @@ -2890,80 +2901,51 @@ func (m *Manager) StartAutoRefresh(parent context.Context, interval time.Duratio if interval <= 0 { interval = refreshCheckInterval } - if m.refreshCancel != nil { - m.refreshCancel() - m.refreshCancel = nil + + m.mu.Lock() + cancel := m.refreshCancel + m.refreshCancel = nil + m.refreshLoop = nil + m.mu.Unlock() + if cancel != nil { + cancel() } + ctx, cancel := context.WithCancel(parent) + loop := newAuthAutoRefreshLoop(m, interval) + + m.mu.Lock() m.refreshCancel = cancel - go func() { - ticker := time.NewTicker(interval) - defer ticker.Stop() - m.checkRefreshes(ctx) - for { - select { - case <-ctx.Done(): - return - case <-ticker.C: - m.checkRefreshes(ctx) - } - } - }() + m.refreshLoop = loop + m.mu.Unlock() + + loop.rebuild(time.Now()) + go loop.run(ctx) } // StopAutoRefresh cancels the background refresh loop, if running. func (m *Manager) StopAutoRefresh() { - if m.refreshCancel != nil { - m.refreshCancel() - m.refreshCancel = nil + m.mu.Lock() + cancel := m.refreshCancel + m.refreshCancel = nil + m.refreshLoop = nil + m.mu.Unlock() + if cancel != nil { + cancel() } } -func (m *Manager) checkRefreshes(ctx context.Context) { - // log.Debugf("checking refreshes") - now := time.Now() - snapshot := m.snapshotAuths() - for _, a := range snapshot { - typ, _ := a.AccountInfo() - if typ != "api_key" { - if !m.shouldRefresh(a, now) { - continue - } - log.Debugf("checking refresh for %s, %s, %s", a.Provider, a.ID, typ) - - if exec := m.executorFor(a.Provider); exec == nil { - continue - } - if !m.markRefreshPending(a.ID, now) { - continue - } - go m.refreshAuthWithLimit(ctx, a.ID) - } - } -} - -func (m *Manager) refreshAuthWithLimit(ctx context.Context, id string) { - if m.refreshSemaphore == nil { - m.refreshAuth(ctx, id) +func (m *Manager) queueRefreshReschedule(authID string) { + if m == nil || authID == "" { return } - select { - case m.refreshSemaphore <- struct{}{}: - defer func() { <-m.refreshSemaphore }() - case <-ctx.Done(): - return - } - m.refreshAuth(ctx, id) -} - -func (m *Manager) snapshotAuths() []*Auth { m.mu.RLock() - defer m.mu.RUnlock() - out := make([]*Auth, 0, len(m.auths)) - for _, a := range m.auths { - out = append(out, a.Clone()) + loop := m.refreshLoop + m.mu.RUnlock() + if loop == nil { + return } - return out + loop.queueReschedule(authID) } func (m *Manager) shouldRefresh(a *Auth, now time.Time) bool { @@ -3173,16 +3155,20 @@ func lookupMetadataTime(meta map[string]any, keys ...string) (time.Time, bool) { func (m *Manager) markRefreshPending(id string, now time.Time) bool { m.mu.Lock() - defer m.mu.Unlock() auth, ok := m.auths[id] if !ok || auth == nil || auth.Disabled { + m.mu.Unlock() return false } if !auth.NextRefreshAfter.IsZero() && now.Before(auth.NextRefreshAfter) { + m.mu.Unlock() return false } auth.NextRefreshAfter = now.Add(refreshPendingBackoff) m.auths[id] = auth + m.mu.Unlock() + + m.queueRefreshReschedule(id) return true } @@ -3209,16 +3195,21 @@ func (m *Manager) refreshAuth(ctx context.Context, id string) { log.Debugf("refreshed %s, %s, %v", auth.Provider, auth.ID, err) now := time.Now() if err != nil { + shouldReschedule := false m.mu.Lock() if current := m.auths[id]; current != nil { current.NextRefreshAfter = now.Add(refreshFailureBackoff) current.LastError = &Error{Message: err.Error()} m.auths[id] = current + shouldReschedule = true if m.scheduler != nil { m.scheduler.upsertAuth(current.Clone()) } } m.mu.Unlock() + if shouldReschedule { + m.queueRefreshReschedule(id) + } return } if updated == nil {