feat(auth): implement auto-refresh loop for managing auth token schedule

- Introduced `authAutoRefreshLoop` to handle token refresh scheduling.
- Replaced semaphore-based refresh logic in `Manager` with the new loop.
- Added unit tests to verify refresh schedule logic and edge cases.
This commit is contained in:
Luis Pater
2026-04-12 02:06:40 +08:00
parent 0ab1f5412f
commit a583463d60
3 changed files with 636 additions and 64 deletions

View File

@@ -0,0 +1,444 @@
package auth
import (
"container/heap"
"context"
"strings"
"sync"
"time"
log "github.com/sirupsen/logrus"
)
type authAutoRefreshLoop struct {
manager *Manager
interval time.Duration
mu sync.Mutex
queue refreshMinHeap
index map[string]*refreshHeapItem
dirty map[string]struct{}
wakeCh chan struct{}
jobs chan string
}
func newAuthAutoRefreshLoop(manager *Manager, interval time.Duration) *authAutoRefreshLoop {
if interval <= 0 {
interval = refreshCheckInterval
}
jobBuffer := refreshMaxConcurrency * 4
if jobBuffer < 64 {
jobBuffer = 64
}
return &authAutoRefreshLoop{
manager: manager,
interval: interval,
index: make(map[string]*refreshHeapItem),
dirty: make(map[string]struct{}),
wakeCh: make(chan struct{}, 1),
jobs: make(chan string, jobBuffer),
}
}
func (l *authAutoRefreshLoop) queueReschedule(authID string) {
if l == nil || authID == "" {
return
}
l.mu.Lock()
l.dirty[authID] = struct{}{}
l.mu.Unlock()
select {
case l.wakeCh <- struct{}{}:
default:
}
}
func (l *authAutoRefreshLoop) run(ctx context.Context) {
if l == nil || l.manager == nil {
return
}
for i := 0; i < refreshMaxConcurrency; i++ {
go l.worker(ctx)
}
l.loop(ctx)
}
func (l *authAutoRefreshLoop) worker(ctx context.Context) {
for {
select {
case <-ctx.Done():
return
case authID := <-l.jobs:
if authID == "" {
continue
}
l.manager.refreshAuth(ctx, authID)
l.queueReschedule(authID)
}
}
}
func (l *authAutoRefreshLoop) rebuild(now time.Time) {
type entry struct {
id string
next time.Time
}
entries := make([]entry, 0)
l.manager.mu.RLock()
for id, auth := range l.manager.auths {
next, ok := nextRefreshCheckAt(now, auth, l.interval)
if !ok {
continue
}
entries = append(entries, entry{id: id, next: next})
}
l.manager.mu.RUnlock()
l.mu.Lock()
l.queue = l.queue[:0]
l.index = make(map[string]*refreshHeapItem, len(entries))
for _, e := range entries {
item := &refreshHeapItem{id: e.id, next: e.next}
heap.Push(&l.queue, item)
l.index[e.id] = item
}
l.mu.Unlock()
}
func (l *authAutoRefreshLoop) loop(ctx context.Context) {
timer := time.NewTimer(time.Hour)
if !timer.Stop() {
select {
case <-timer.C:
default:
}
}
defer timer.Stop()
var timerCh <-chan time.Time
l.resetTimer(timer, &timerCh, time.Now())
for {
select {
case <-ctx.Done():
return
case <-l.wakeCh:
now := time.Now()
l.applyDirty(now)
l.resetTimer(timer, &timerCh, now)
case <-timerCh:
now := time.Now()
l.handleDue(ctx, now)
l.applyDirty(now)
l.resetTimer(timer, &timerCh, now)
}
}
}
func (l *authAutoRefreshLoop) resetTimer(timer *time.Timer, timerCh *<-chan time.Time, now time.Time) {
next, ok := l.peek()
if !ok {
if !timer.Stop() {
select {
case <-timer.C:
default:
}
}
*timerCh = nil
return
}
wait := next.Sub(now)
if wait < 0 {
wait = 0
}
if !timer.Stop() {
select {
case <-timer.C:
default:
}
}
timer.Reset(wait)
*timerCh = timer.C
}
func (l *authAutoRefreshLoop) peek() (time.Time, bool) {
l.mu.Lock()
defer l.mu.Unlock()
if len(l.queue) == 0 {
return time.Time{}, false
}
return l.queue[0].next, true
}
func (l *authAutoRefreshLoop) handleDue(ctx context.Context, now time.Time) {
due := l.popDue(now)
if len(due) == 0 {
return
}
if log.IsLevelEnabled(log.DebugLevel) {
log.Debugf("auto-refresh scheduler due auths: %d", len(due))
}
for _, authID := range due {
l.handleDueAuth(ctx, now, authID)
}
}
func (l *authAutoRefreshLoop) popDue(now time.Time) []string {
l.mu.Lock()
defer l.mu.Unlock()
var due []string
for len(l.queue) > 0 {
item := l.queue[0]
if item == nil || item.next.After(now) {
break
}
popped := heap.Pop(&l.queue).(*refreshHeapItem)
if popped == nil {
continue
}
delete(l.index, popped.id)
due = append(due, popped.id)
}
return due
}
func (l *authAutoRefreshLoop) handleDueAuth(ctx context.Context, now time.Time, authID string) {
if authID == "" {
return
}
manager := l.manager
manager.mu.RLock()
auth := manager.auths[authID]
if auth == nil {
manager.mu.RUnlock()
return
}
next, shouldSchedule := nextRefreshCheckAt(now, auth, l.interval)
shouldRefresh := manager.shouldRefresh(auth, now)
exec := manager.executors[auth.Provider]
manager.mu.RUnlock()
if !shouldSchedule {
l.remove(authID)
return
}
if !shouldRefresh {
l.upsert(authID, next)
return
}
if exec == nil {
l.upsert(authID, now.Add(l.interval))
return
}
if !manager.markRefreshPending(authID, now) {
manager.mu.RLock()
auth = manager.auths[authID]
next, shouldSchedule = nextRefreshCheckAt(now, auth, l.interval)
manager.mu.RUnlock()
if shouldSchedule {
l.upsert(authID, next)
} else {
l.remove(authID)
}
return
}
select {
case <-ctx.Done():
return
case l.jobs <- authID:
}
}
func (l *authAutoRefreshLoop) applyDirty(now time.Time) {
dirty := l.drainDirty()
if len(dirty) == 0 {
return
}
for _, authID := range dirty {
l.manager.mu.RLock()
auth := l.manager.auths[authID]
next, ok := nextRefreshCheckAt(now, auth, l.interval)
l.manager.mu.RUnlock()
if !ok {
l.remove(authID)
continue
}
l.upsert(authID, next)
}
}
func (l *authAutoRefreshLoop) drainDirty() []string {
l.mu.Lock()
defer l.mu.Unlock()
if len(l.dirty) == 0 {
return nil
}
out := make([]string, 0, len(l.dirty))
for authID := range l.dirty {
out = append(out, authID)
delete(l.dirty, authID)
}
return out
}
func (l *authAutoRefreshLoop) upsert(authID string, next time.Time) {
if authID == "" || next.IsZero() {
return
}
l.mu.Lock()
defer l.mu.Unlock()
if item, ok := l.index[authID]; ok && item != nil {
item.next = next
heap.Fix(&l.queue, item.index)
return
}
item := &refreshHeapItem{id: authID, next: next}
heap.Push(&l.queue, item)
l.index[authID] = item
}
func (l *authAutoRefreshLoop) remove(authID string) {
if authID == "" {
return
}
l.mu.Lock()
defer l.mu.Unlock()
item, ok := l.index[authID]
if !ok || item == nil {
return
}
heap.Remove(&l.queue, item.index)
delete(l.index, authID)
}
func nextRefreshCheckAt(now time.Time, auth *Auth, interval time.Duration) (time.Time, bool) {
if auth == nil || auth.Disabled {
return time.Time{}, false
}
accountType, _ := auth.AccountInfo()
if accountType == "api_key" {
return time.Time{}, false
}
if !auth.NextRefreshAfter.IsZero() && now.Before(auth.NextRefreshAfter) {
return auth.NextRefreshAfter, true
}
if evaluator, ok := auth.Runtime.(RefreshEvaluator); ok && evaluator != nil {
if interval <= 0 {
interval = refreshCheckInterval
}
return now.Add(interval), true
}
lastRefresh := auth.LastRefreshedAt
if lastRefresh.IsZero() {
if ts, ok := authLastRefreshTimestamp(auth); ok {
lastRefresh = ts
}
}
expiry, hasExpiry := auth.ExpirationTime()
if pref := authPreferredInterval(auth); pref > 0 {
candidates := make([]time.Time, 0, 2)
if hasExpiry && !expiry.IsZero() {
if !expiry.After(now) || expiry.Sub(now) <= pref {
return now, true
}
candidates = append(candidates, expiry.Add(-pref))
}
if lastRefresh.IsZero() {
return now, true
}
candidates = append(candidates, lastRefresh.Add(pref))
next := candidates[0]
for _, candidate := range candidates[1:] {
if candidate.Before(next) {
next = candidate
}
}
if !next.After(now) {
return now, true
}
return next, true
}
provider := strings.ToLower(auth.Provider)
lead := ProviderRefreshLead(provider, auth.Runtime)
if lead == nil {
return time.Time{}, false
}
if hasExpiry && !expiry.IsZero() {
dueAt := expiry.Add(-*lead)
if !dueAt.After(now) {
return now, true
}
return dueAt, true
}
if !lastRefresh.IsZero() {
dueAt := lastRefresh.Add(*lead)
if !dueAt.After(now) {
return now, true
}
return dueAt, true
}
return now, true
}
type refreshHeapItem struct {
id string
next time.Time
index int
}
type refreshMinHeap []*refreshHeapItem
func (h refreshMinHeap) Len() int { return len(h) }
func (h refreshMinHeap) Less(i, j int) bool {
return h[i].next.Before(h[j].next)
}
func (h refreshMinHeap) Swap(i, j int) {
h[i], h[j] = h[j], h[i]
h[i].index = i
h[j].index = j
}
func (h *refreshMinHeap) Push(x any) {
item, ok := x.(*refreshHeapItem)
if !ok || item == nil {
return
}
item.index = len(*h)
*h = append(*h, item)
}
func (h *refreshMinHeap) Pop() any {
old := *h
n := len(old)
if n == 0 {
return (*refreshHeapItem)(nil)
}
item := old[n-1]
item.index = -1
*h = old[:n-1]
return item
}

View File

@@ -0,0 +1,137 @@
package auth
import (
"strings"
"testing"
"time"
)
type testRefreshEvaluator struct{}
func (testRefreshEvaluator) ShouldRefresh(time.Time, *Auth) bool { return false }
func setRefreshLeadFactory(t *testing.T, provider string, factory func() *time.Duration) {
t.Helper()
key := strings.ToLower(strings.TrimSpace(provider))
refreshLeadMu.Lock()
prev, hadPrev := refreshLeadFactories[key]
if factory == nil {
delete(refreshLeadFactories, key)
} else {
refreshLeadFactories[key] = factory
}
refreshLeadMu.Unlock()
t.Cleanup(func() {
refreshLeadMu.Lock()
if hadPrev {
refreshLeadFactories[key] = prev
} else {
delete(refreshLeadFactories, key)
}
refreshLeadMu.Unlock()
})
}
func TestNextRefreshCheckAt_DisabledUnschedule(t *testing.T) {
now := time.Date(2026, 4, 12, 0, 0, 0, 0, time.UTC)
auth := &Auth{ID: "a1", Provider: "test", Disabled: true}
if _, ok := nextRefreshCheckAt(now, auth, 15*time.Minute); ok {
t.Fatalf("nextRefreshCheckAt() ok = true, want false")
}
}
func TestNextRefreshCheckAt_APIKeyUnschedule(t *testing.T) {
now := time.Date(2026, 4, 12, 0, 0, 0, 0, time.UTC)
auth := &Auth{ID: "a1", Provider: "test", Attributes: map[string]string{"api_key": "k"}}
if _, ok := nextRefreshCheckAt(now, auth, 15*time.Minute); ok {
t.Fatalf("nextRefreshCheckAt() ok = true, want false")
}
}
func TestNextRefreshCheckAt_NextRefreshAfterGate(t *testing.T) {
now := time.Date(2026, 4, 12, 0, 0, 0, 0, time.UTC)
nextAfter := now.Add(30 * time.Minute)
auth := &Auth{
ID: "a1",
Provider: "test",
NextRefreshAfter: nextAfter,
Metadata: map[string]any{"email": "x@example.com"},
}
got, ok := nextRefreshCheckAt(now, auth, 15*time.Minute)
if !ok {
t.Fatalf("nextRefreshCheckAt() ok = false, want true")
}
if !got.Equal(nextAfter) {
t.Fatalf("nextRefreshCheckAt() = %s, want %s", got, nextAfter)
}
}
func TestNextRefreshCheckAt_PreferredInterval_PicksEarliestCandidate(t *testing.T) {
now := time.Date(2026, 4, 12, 0, 0, 0, 0, time.UTC)
expiry := now.Add(20 * time.Minute)
auth := &Auth{
ID: "a1",
Provider: "test",
LastRefreshedAt: now,
Metadata: map[string]any{
"email": "x@example.com",
"expires_at": expiry.Format(time.RFC3339),
"refresh_interval_seconds": 900, // 15m
},
}
got, ok := nextRefreshCheckAt(now, auth, 15*time.Minute)
if !ok {
t.Fatalf("nextRefreshCheckAt() ok = false, want true")
}
want := expiry.Add(-15 * time.Minute)
if !got.Equal(want) {
t.Fatalf("nextRefreshCheckAt() = %s, want %s", got, want)
}
}
func TestNextRefreshCheckAt_ProviderLead_Expiry(t *testing.T) {
now := time.Date(2026, 4, 12, 0, 0, 0, 0, time.UTC)
expiry := now.Add(time.Hour)
lead := 10 * time.Minute
setRefreshLeadFactory(t, "provider-lead-expiry", func() *time.Duration {
d := lead
return &d
})
auth := &Auth{
ID: "a1",
Provider: "provider-lead-expiry",
Metadata: map[string]any{
"email": "x@example.com",
"expires_at": expiry.Format(time.RFC3339),
},
}
got, ok := nextRefreshCheckAt(now, auth, 15*time.Minute)
if !ok {
t.Fatalf("nextRefreshCheckAt() ok = false, want true")
}
want := expiry.Add(-lead)
if !got.Equal(want) {
t.Fatalf("nextRefreshCheckAt() = %s, want %s", got, want)
}
}
func TestNextRefreshCheckAt_RefreshEvaluatorFallback(t *testing.T) {
now := time.Date(2026, 4, 12, 0, 0, 0, 0, time.UTC)
interval := 15 * time.Minute
auth := &Auth{
ID: "a1",
Provider: "test",
Metadata: map[string]any{"email": "x@example.com"},
Runtime: testRefreshEvaluator{},
}
got, ok := nextRefreshCheckAt(now, auth, interval)
if !ok {
t.Fatalf("nextRefreshCheckAt() ok = false, want true")
}
want := now.Add(interval)
if !got.Equal(want) {
t.Fatalf("nextRefreshCheckAt() = %s, want %s", got, want)
}
}

View File

@@ -162,8 +162,8 @@ type Manager struct {
rtProvider RoundTripperProvider
// Auto refresh state
refreshCancel context.CancelFunc
refreshSemaphore chan struct{}
refreshCancel context.CancelFunc
refreshLoop *authAutoRefreshLoop
}
// NewManager constructs a manager with optional custom selector and hook.
@@ -182,7 +182,6 @@ func NewManager(store Store, selector Selector, hook Hook) *Manager {
auths: make(map[string]*Auth),
providerOffsets: make(map[string]int),
modelPoolOffsets: make(map[string]int),
refreshSemaphore: make(chan struct{}, refreshMaxConcurrency),
}
// atomic.Value requires non-nil initial value.
manager.runtimeConfig.Store(&internalconfig.Config{})
@@ -214,6 +213,16 @@ func (m *Manager) syncScheduler() {
m.syncSchedulerFromSnapshot(m.snapshotAuths())
}
func (m *Manager) snapshotAuths() []*Auth {
m.mu.RLock()
defer m.mu.RUnlock()
out := make([]*Auth, 0, len(m.auths))
for _, a := range m.auths {
out = append(out, a.Clone())
}
return out
}
// RefreshSchedulerEntry re-upserts a single auth into the scheduler so that its
// supportedModelSet is rebuilt from the current global model registry state.
// This must be called after models have been registered for a newly added auth,
@@ -1088,6 +1097,7 @@ func (m *Manager) Register(ctx context.Context, auth *Auth) (*Auth, error) {
if m.scheduler != nil {
m.scheduler.upsertAuth(authClone)
}
m.queueRefreshReschedule(auth.ID)
_ = m.persist(ctx, auth)
m.hook.OnAuthRegistered(ctx, auth.Clone())
return auth.Clone(), nil
@@ -1118,6 +1128,7 @@ func (m *Manager) Update(ctx context.Context, auth *Auth) (*Auth, error) {
if m.scheduler != nil {
m.scheduler.upsertAuth(authClone)
}
m.queueRefreshReschedule(auth.ID)
_ = m.persist(ctx, auth)
m.hook.OnAuthUpdated(ctx, auth.Clone())
return auth.Clone(), nil
@@ -2890,80 +2901,51 @@ func (m *Manager) StartAutoRefresh(parent context.Context, interval time.Duratio
if interval <= 0 {
interval = refreshCheckInterval
}
if m.refreshCancel != nil {
m.refreshCancel()
m.refreshCancel = nil
m.mu.Lock()
cancel := m.refreshCancel
m.refreshCancel = nil
m.refreshLoop = nil
m.mu.Unlock()
if cancel != nil {
cancel()
}
ctx, cancel := context.WithCancel(parent)
loop := newAuthAutoRefreshLoop(m, interval)
m.mu.Lock()
m.refreshCancel = cancel
go func() {
ticker := time.NewTicker(interval)
defer ticker.Stop()
m.checkRefreshes(ctx)
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
m.checkRefreshes(ctx)
}
}
}()
m.refreshLoop = loop
m.mu.Unlock()
loop.rebuild(time.Now())
go loop.run(ctx)
}
// StopAutoRefresh cancels the background refresh loop, if running.
func (m *Manager) StopAutoRefresh() {
if m.refreshCancel != nil {
m.refreshCancel()
m.refreshCancel = nil
m.mu.Lock()
cancel := m.refreshCancel
m.refreshCancel = nil
m.refreshLoop = nil
m.mu.Unlock()
if cancel != nil {
cancel()
}
}
func (m *Manager) checkRefreshes(ctx context.Context) {
// log.Debugf("checking refreshes")
now := time.Now()
snapshot := m.snapshotAuths()
for _, a := range snapshot {
typ, _ := a.AccountInfo()
if typ != "api_key" {
if !m.shouldRefresh(a, now) {
continue
}
log.Debugf("checking refresh for %s, %s, %s", a.Provider, a.ID, typ)
if exec := m.executorFor(a.Provider); exec == nil {
continue
}
if !m.markRefreshPending(a.ID, now) {
continue
}
go m.refreshAuthWithLimit(ctx, a.ID)
}
}
}
func (m *Manager) refreshAuthWithLimit(ctx context.Context, id string) {
if m.refreshSemaphore == nil {
m.refreshAuth(ctx, id)
func (m *Manager) queueRefreshReschedule(authID string) {
if m == nil || authID == "" {
return
}
select {
case m.refreshSemaphore <- struct{}{}:
defer func() { <-m.refreshSemaphore }()
case <-ctx.Done():
return
}
m.refreshAuth(ctx, id)
}
func (m *Manager) snapshotAuths() []*Auth {
m.mu.RLock()
defer m.mu.RUnlock()
out := make([]*Auth, 0, len(m.auths))
for _, a := range m.auths {
out = append(out, a.Clone())
loop := m.refreshLoop
m.mu.RUnlock()
if loop == nil {
return
}
return out
loop.queueReschedule(authID)
}
func (m *Manager) shouldRefresh(a *Auth, now time.Time) bool {
@@ -3173,16 +3155,20 @@ func lookupMetadataTime(meta map[string]any, keys ...string) (time.Time, bool) {
func (m *Manager) markRefreshPending(id string, now time.Time) bool {
m.mu.Lock()
defer m.mu.Unlock()
auth, ok := m.auths[id]
if !ok || auth == nil || auth.Disabled {
m.mu.Unlock()
return false
}
if !auth.NextRefreshAfter.IsZero() && now.Before(auth.NextRefreshAfter) {
m.mu.Unlock()
return false
}
auth.NextRefreshAfter = now.Add(refreshPendingBackoff)
m.auths[id] = auth
m.mu.Unlock()
m.queueRefreshReschedule(id)
return true
}
@@ -3209,16 +3195,21 @@ func (m *Manager) refreshAuth(ctx context.Context, id string) {
log.Debugf("refreshed %s, %s, %v", auth.Provider, auth.ID, err)
now := time.Now()
if err != nil {
shouldReschedule := false
m.mu.Lock()
if current := m.auths[id]; current != nil {
current.NextRefreshAfter = now.Add(refreshFailureBackoff)
current.LastError = &Error{Message: err.Error()}
m.auths[id] = current
shouldReschedule = true
if m.scheduler != nil {
m.scheduler.upsertAuth(current.Clone())
}
}
m.mu.Unlock()
if shouldReschedule {
m.queueRefreshReschedule(id)
}
return
}
if updated == nil {