mirror of
https://mirror.skon.top/github.com/router-for-me/CLIProxyAPI
synced 2026-04-20 16:10:12 +08:00
feat(auth): implement auto-refresh loop for managing auth token schedule
- Introduced `authAutoRefreshLoop` to handle token refresh scheduling. - Replaced semaphore-based refresh logic in `Manager` with the new loop. - Added unit tests to verify refresh schedule logic and edge cases.
This commit is contained in:
444
sdk/cliproxy/auth/auto_refresh_loop.go
Normal file
444
sdk/cliproxy/auth/auto_refresh_loop.go
Normal file
@@ -0,0 +1,444 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"container/heap"
|
||||
"context"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// authAutoRefreshLoop schedules credential refreshes for the auths owned by a
// Manager. A single scheduler goroutine (loop) maintains a min-heap of
// per-auth check times; a fixed pool of workers performs the actual refreshes.
type authAutoRefreshLoop struct {
	manager  *Manager
	interval time.Duration // fallback spacing between checks when no better signal exists

	mu    sync.Mutex
	queue refreshMinHeap              // min-heap ordered by next check time; guarded by mu
	index map[string]*refreshHeapItem // auth ID -> heap entry, for targeted Fix/Remove; guarded by mu
	dirty map[string]struct{}         // auth IDs whose schedule must be recomputed; guarded by mu

	wakeCh chan struct{} // capacity-1 nudge channel for the scheduler goroutine
	jobs   chan string   // auth IDs handed off to the worker pool
}
|
||||
|
||||
func newAuthAutoRefreshLoop(manager *Manager, interval time.Duration) *authAutoRefreshLoop {
|
||||
if interval <= 0 {
|
||||
interval = refreshCheckInterval
|
||||
}
|
||||
jobBuffer := refreshMaxConcurrency * 4
|
||||
if jobBuffer < 64 {
|
||||
jobBuffer = 64
|
||||
}
|
||||
return &authAutoRefreshLoop{
|
||||
manager: manager,
|
||||
interval: interval,
|
||||
index: make(map[string]*refreshHeapItem),
|
||||
dirty: make(map[string]struct{}),
|
||||
wakeCh: make(chan struct{}, 1),
|
||||
jobs: make(chan string, jobBuffer),
|
||||
}
|
||||
}
|
||||
|
||||
func (l *authAutoRefreshLoop) queueReschedule(authID string) {
|
||||
if l == nil || authID == "" {
|
||||
return
|
||||
}
|
||||
l.mu.Lock()
|
||||
l.dirty[authID] = struct{}{}
|
||||
l.mu.Unlock()
|
||||
select {
|
||||
case l.wakeCh <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (l *authAutoRefreshLoop) run(ctx context.Context) {
|
||||
if l == nil || l.manager == nil {
|
||||
return
|
||||
}
|
||||
|
||||
for i := 0; i < refreshMaxConcurrency; i++ {
|
||||
go l.worker(ctx)
|
||||
}
|
||||
|
||||
l.loop(ctx)
|
||||
}
|
||||
|
||||
// worker consumes auth IDs from the jobs channel until ctx is cancelled.
// After each refresh attempt it marks the auth dirty again so the scheduler
// recomputes its next check time from the post-refresh state.
func (l *authAutoRefreshLoop) worker(ctx context.Context) {
	for {
		select {
		case <-ctx.Done():
			return
		case authID := <-l.jobs:
			if authID == "" {
				continue
			}
			l.manager.refreshAuth(ctx, authID)
			// Reschedule regardless of the refresh outcome.
			l.queueReschedule(authID)
		}
	}
}
|
||||
|
||||
func (l *authAutoRefreshLoop) rebuild(now time.Time) {
|
||||
type entry struct {
|
||||
id string
|
||||
next time.Time
|
||||
}
|
||||
|
||||
entries := make([]entry, 0)
|
||||
|
||||
l.manager.mu.RLock()
|
||||
for id, auth := range l.manager.auths {
|
||||
next, ok := nextRefreshCheckAt(now, auth, l.interval)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
entries = append(entries, entry{id: id, next: next})
|
||||
}
|
||||
l.manager.mu.RUnlock()
|
||||
|
||||
l.mu.Lock()
|
||||
l.queue = l.queue[:0]
|
||||
l.index = make(map[string]*refreshHeapItem, len(entries))
|
||||
for _, e := range entries {
|
||||
item := &refreshHeapItem{id: e.id, next: e.next}
|
||||
heap.Push(&l.queue, item)
|
||||
l.index[e.id] = item
|
||||
}
|
||||
l.mu.Unlock()
|
||||
}
|
||||
|
||||
// loop is the scheduler goroutine body. It sleeps until the earliest
// scheduled check (or an external wake-up), services due entries, folds in
// pending dirty reschedules, and re-arms the timer.
func (l *authAutoRefreshLoop) loop(ctx context.Context) {
	// Start from a stopped, drained timer so resetTimer can safely arm it.
	timer := time.NewTimer(time.Hour)
	if !timer.Stop() {
		select {
		case <-timer.C:
		default:
		}
	}
	defer timer.Stop()

	// timerCh stays nil (never fires) while nothing is scheduled.
	var timerCh <-chan time.Time
	l.resetTimer(timer, &timerCh, time.Now())

	for {
		select {
		case <-ctx.Done():
			return
		case <-l.wakeCh:
			// Reschedule request: recompute dirty entries, then re-arm.
			now := time.Now()
			l.applyDirty(now)
			l.resetTimer(timer, &timerCh, now)
		case <-timerCh:
			// Earliest entry is due: process it, then fold in any
			// reschedules the processing produced before re-arming.
			now := time.Now()
			l.handleDue(ctx, now)
			l.applyDirty(now)
			l.resetTimer(timer, &timerCh, now)
		}
	}
}
|
||||
|
||||
// resetTimer re-arms timer for the earliest scheduled check, or disables it
// (by nilling *timerCh) when the queue is empty. The stop-and-drain dance
// before Reset prevents a tick from a previously armed timer leaking into
// the caller's next select.
func (l *authAutoRefreshLoop) resetTimer(timer *time.Timer, timerCh *<-chan time.Time, now time.Time) {
	next, ok := l.peek()
	if !ok {
		if !timer.Stop() {
			select {
			case <-timer.C:
			default:
			}
		}
		// A nil channel never fires, so the loop waits on wakeCh/ctx only.
		*timerCh = nil
		return
	}

	wait := next.Sub(now)
	if wait < 0 {
		wait = 0 // already due: fire immediately
	}
	if !timer.Stop() {
		select {
		case <-timer.C:
		default:
		}
	}
	timer.Reset(wait)
	*timerCh = timer.C
}
|
||||
|
||||
func (l *authAutoRefreshLoop) peek() (time.Time, bool) {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
if len(l.queue) == 0 {
|
||||
return time.Time{}, false
|
||||
}
|
||||
return l.queue[0].next, true
|
||||
}
|
||||
|
||||
func (l *authAutoRefreshLoop) handleDue(ctx context.Context, now time.Time) {
|
||||
due := l.popDue(now)
|
||||
if len(due) == 0 {
|
||||
return
|
||||
}
|
||||
if log.IsLevelEnabled(log.DebugLevel) {
|
||||
log.Debugf("auto-refresh scheduler due auths: %d", len(due))
|
||||
}
|
||||
for _, authID := range due {
|
||||
l.handleDueAuth(ctx, now, authID)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *authAutoRefreshLoop) popDue(now time.Time) []string {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
var due []string
|
||||
for len(l.queue) > 0 {
|
||||
item := l.queue[0]
|
||||
if item == nil || item.next.After(now) {
|
||||
break
|
||||
}
|
||||
popped := heap.Pop(&l.queue).(*refreshHeapItem)
|
||||
if popped == nil {
|
||||
continue
|
||||
}
|
||||
delete(l.index, popped.id)
|
||||
due = append(due, popped.id)
|
||||
}
|
||||
return due
|
||||
}
|
||||
|
||||
// handleDueAuth re-evaluates a single due auth and either drops it from the
// schedule, re-arms it for later, or hands it to the worker pool. The manager
// lock is held only while snapshotting decision inputs; all heap mutation
// happens afterwards via upsert/remove.
func (l *authAutoRefreshLoop) handleDueAuth(ctx context.Context, now time.Time, authID string) {
	if authID == "" {
		return
	}

	manager := l.manager

	manager.mu.RLock()
	auth := manager.auths[authID]
	if auth == nil {
		// Auth was removed after it was scheduled; nothing to do.
		manager.mu.RUnlock()
		return
	}
	next, shouldSchedule := nextRefreshCheckAt(now, auth, l.interval)
	shouldRefresh := manager.shouldRefresh(auth, now)
	exec := manager.executors[auth.Provider]
	manager.mu.RUnlock()

	if !shouldSchedule {
		l.remove(authID)
		return
	}

	if !shouldRefresh {
		// Not actually due (state changed since scheduling); re-arm.
		l.upsert(authID, next)
		return
	}

	if exec == nil {
		// No executor registered for this provider yet; retry after a
		// full interval instead of dropping the auth.
		l.upsert(authID, now.Add(l.interval))
		return
	}

	if !manager.markRefreshPending(authID, now) {
		// Another refresher claimed this auth first; recompute the
		// schedule from the possibly-updated auth state.
		manager.mu.RLock()
		auth = manager.auths[authID]
		next, shouldSchedule = nextRefreshCheckAt(now, auth, l.interval)
		manager.mu.RUnlock()
		if shouldSchedule {
			l.upsert(authID, next)
		} else {
			l.remove(authID)
		}
		return
	}

	// Hand off to a worker; bail out if the loop is being shut down.
	select {
	case <-ctx.Done():
		return
	case l.jobs <- authID:
	}
}
|
||||
|
||||
func (l *authAutoRefreshLoop) applyDirty(now time.Time) {
|
||||
dirty := l.drainDirty()
|
||||
if len(dirty) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
for _, authID := range dirty {
|
||||
l.manager.mu.RLock()
|
||||
auth := l.manager.auths[authID]
|
||||
next, ok := nextRefreshCheckAt(now, auth, l.interval)
|
||||
l.manager.mu.RUnlock()
|
||||
|
||||
if !ok {
|
||||
l.remove(authID)
|
||||
continue
|
||||
}
|
||||
l.upsert(authID, next)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *authAutoRefreshLoop) drainDirty() []string {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
if len(l.dirty) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]string, 0, len(l.dirty))
|
||||
for authID := range l.dirty {
|
||||
out = append(out, authID)
|
||||
delete(l.dirty, authID)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (l *authAutoRefreshLoop) upsert(authID string, next time.Time) {
|
||||
if authID == "" || next.IsZero() {
|
||||
return
|
||||
}
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
if item, ok := l.index[authID]; ok && item != nil {
|
||||
item.next = next
|
||||
heap.Fix(&l.queue, item.index)
|
||||
return
|
||||
}
|
||||
item := &refreshHeapItem{id: authID, next: next}
|
||||
heap.Push(&l.queue, item)
|
||||
l.index[authID] = item
|
||||
}
|
||||
|
||||
func (l *authAutoRefreshLoop) remove(authID string) {
|
||||
if authID == "" {
|
||||
return
|
||||
}
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
item, ok := l.index[authID]
|
||||
if !ok || item == nil {
|
||||
return
|
||||
}
|
||||
heap.Remove(&l.queue, item.index)
|
||||
delete(l.index, authID)
|
||||
}
|
||||
|
||||
// nextRefreshCheckAt computes when auth should next be considered for
// refresh. It returns ok=false when the auth must not be scheduled at all:
// nil, disabled, an "api_key" account (nothing to refresh), or a provider
// with no refresh lead configured.
func nextRefreshCheckAt(now time.Time, auth *Auth, interval time.Duration) (time.Time, bool) {
	if auth == nil || auth.Disabled {
		return time.Time{}, false
	}

	accountType, _ := auth.AccountInfo()
	if accountType == "api_key" {
		return time.Time{}, false
	}

	// Respect an explicit backoff gate set by a previous refresh attempt.
	if !auth.NextRefreshAfter.IsZero() && now.Before(auth.NextRefreshAfter) {
		return auth.NextRefreshAfter, true
	}

	// Runtimes implementing RefreshEvaluator decide per-check whether a
	// refresh is needed, so just poll them on the fixed interval.
	if evaluator, ok := auth.Runtime.(RefreshEvaluator); ok && evaluator != nil {
		if interval <= 0 {
			interval = refreshCheckInterval
		}
		return now.Add(interval), true
	}

	lastRefresh := auth.LastRefreshedAt
	if lastRefresh.IsZero() {
		// Fall back to a refresh timestamp recorded elsewhere on the
		// auth, if one is available.
		if ts, ok := authLastRefreshTimestamp(auth); ok {
			lastRefresh = ts
		}
	}

	expiry, hasExpiry := auth.ExpirationTime()

	// Auths with a preferred interval refresh every `pref`, but never
	// later than `pref` before the credential expires.
	if pref := authPreferredInterval(auth); pref > 0 {
		candidates := make([]time.Time, 0, 2)
		if hasExpiry && !expiry.IsZero() {
			if !expiry.After(now) || expiry.Sub(now) <= pref {
				// Expired or inside the lead window: refresh now.
				return now, true
			}
			candidates = append(candidates, expiry.Add(-pref))
		}
		if lastRefresh.IsZero() {
			// Never refreshed: do it immediately.
			return now, true
		}
		candidates = append(candidates, lastRefresh.Add(pref))
		// Pick the earliest candidate, clamped to now.
		next := candidates[0]
		for _, candidate := range candidates[1:] {
			if candidate.Before(next) {
				next = candidate
			}
		}
		if !next.After(now) {
			return now, true
		}
		return next, true
	}

	// Otherwise fall back to the provider-level refresh lead, if any.
	provider := strings.ToLower(auth.Provider)
	lead := ProviderRefreshLead(provider, auth.Runtime)
	if lead == nil {
		return time.Time{}, false
	}
	if hasExpiry && !expiry.IsZero() {
		// Refresh `lead` before expiry, clamped to now.
		dueAt := expiry.Add(-*lead)
		if !dueAt.After(now) {
			return now, true
		}
		return dueAt, true
	}
	if !lastRefresh.IsZero() {
		// No expiry known: refresh `lead` after the last refresh.
		dueAt := lastRefresh.Add(*lead)
		if !dueAt.After(now) {
			return now, true
		}
		return dueAt, true
	}
	// No expiry and never refreshed: check immediately.
	return now, true
}
|
||||
|
||||
// refreshHeapItem is one scheduled entry in the refresh min-heap. The index
// field tracks the item's current heap slot so heap.Fix and heap.Remove can
// target it; it becomes -1 once the item is popped.
type refreshHeapItem struct {
	id    string
	next  time.Time
	index int
}

// refreshMinHeap orders items by their next check time, earliest first.
// It implements heap.Interface.
type refreshMinHeap []*refreshHeapItem

func (h refreshMinHeap) Len() int { return len(h) }

func (h refreshMinHeap) Less(i, j int) bool { return h[i].next.Before(h[j].next) }

func (h refreshMinHeap) Swap(i, j int) {
	h[i], h[j] = h[j], h[i]
	h[i].index, h[j].index = i, j
}

// Push appends x and records its heap slot. Values that are not a non-nil
// *refreshHeapItem are silently ignored so a malformed push cannot insert a
// nil entry.
func (h *refreshMinHeap) Push(x any) {
	if item, ok := x.(*refreshHeapItem); ok && item != nil {
		item.index = len(*h)
		*h = append(*h, item)
	}
}

// Pop removes and returns the last element (heap.Pop has already swapped the
// minimum into that slot). An empty heap yields a typed nil item.
func (h *refreshMinHeap) Pop() any {
	old := *h
	if len(old) == 0 {
		return (*refreshHeapItem)(nil)
	}
	last := old[len(old)-1]
	last.index = -1
	*h = old[:len(old)-1]
	return last
}
|
||||
137
sdk/cliproxy/auth/auto_refresh_loop_test.go
Normal file
137
sdk/cliproxy/auth/auto_refresh_loop_test.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// testRefreshEvaluator is a stub runtime whose ShouldRefresh always reports
// false; it exists to steer nextRefreshCheckAt into the RefreshEvaluator
// branch during tests.
type testRefreshEvaluator struct{}

func (testRefreshEvaluator) ShouldRefresh(time.Time, *Auth) bool { return false }
|
||||
|
||||
// setRefreshLeadFactory installs a refresh-lead factory for provider for the
// duration of the test. A nil factory deletes the registration. The previous
// registration (or its absence) is restored via t.Cleanup.
func setRefreshLeadFactory(t *testing.T, provider string, factory func() *time.Duration) {
	t.Helper()
	key := strings.ToLower(strings.TrimSpace(provider))
	refreshLeadMu.Lock()
	prev, hadPrev := refreshLeadFactories[key]
	if factory == nil {
		delete(refreshLeadFactories, key)
	} else {
		refreshLeadFactories[key] = factory
	}
	refreshLeadMu.Unlock()
	t.Cleanup(func() {
		refreshLeadMu.Lock()
		if hadPrev {
			refreshLeadFactories[key] = prev
		} else {
			delete(refreshLeadFactories, key)
		}
		refreshLeadMu.Unlock()
	})
}
|
||||
|
||||
func TestNextRefreshCheckAt_DisabledUnschedule(t *testing.T) {
|
||||
now := time.Date(2026, 4, 12, 0, 0, 0, 0, time.UTC)
|
||||
auth := &Auth{ID: "a1", Provider: "test", Disabled: true}
|
||||
if _, ok := nextRefreshCheckAt(now, auth, 15*time.Minute); ok {
|
||||
t.Fatalf("nextRefreshCheckAt() ok = true, want false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNextRefreshCheckAt_APIKeyUnschedule(t *testing.T) {
|
||||
now := time.Date(2026, 4, 12, 0, 0, 0, 0, time.UTC)
|
||||
auth := &Auth{ID: "a1", Provider: "test", Attributes: map[string]string{"api_key": "k"}}
|
||||
if _, ok := nextRefreshCheckAt(now, auth, 15*time.Minute); ok {
|
||||
t.Fatalf("nextRefreshCheckAt() ok = true, want false")
|
||||
}
|
||||
}
|
||||
|
||||
// A future NextRefreshAfter gate must be returned verbatim as the next
// check time.
func TestNextRefreshCheckAt_NextRefreshAfterGate(t *testing.T) {
	now := time.Date(2026, 4, 12, 0, 0, 0, 0, time.UTC)
	nextAfter := now.Add(30 * time.Minute)
	auth := &Auth{
		ID:               "a1",
		Provider:         "test",
		NextRefreshAfter: nextAfter,
		Metadata:         map[string]any{"email": "x@example.com"},
	}
	got, ok := nextRefreshCheckAt(now, auth, 15*time.Minute)
	if !ok {
		t.Fatalf("nextRefreshCheckAt() ok = false, want true")
	}
	if !got.Equal(nextAfter) {
		t.Fatalf("nextRefreshCheckAt() = %s, want %s", got, nextAfter)
	}
}
|
||||
|
||||
// With a preferred interval configured, the scheduler must pick the earliest
// of (expiry - interval) and (lastRefresh + interval); here expiry is 20m out
// with a 15m interval, so expiry-15m (5m from now) wins over lastRefresh+15m.
func TestNextRefreshCheckAt_PreferredInterval_PicksEarliestCandidate(t *testing.T) {
	now := time.Date(2026, 4, 12, 0, 0, 0, 0, time.UTC)
	expiry := now.Add(20 * time.Minute)
	auth := &Auth{
		ID:              "a1",
		Provider:        "test",
		LastRefreshedAt: now,
		Metadata: map[string]any{
			"email":                    "x@example.com",
			"expires_at":               expiry.Format(time.RFC3339),
			"refresh_interval_seconds": 900, // 15m
		},
	}
	got, ok := nextRefreshCheckAt(now, auth, 15*time.Minute)
	if !ok {
		t.Fatalf("nextRefreshCheckAt() ok = false, want true")
	}
	want := expiry.Add(-15 * time.Minute)
	if !got.Equal(want) {
		t.Fatalf("nextRefreshCheckAt() = %s, want %s", got, want)
	}
}
|
||||
|
||||
// With a provider refresh lead registered and a known expiry, the next check
// must land exactly `lead` before expiry.
func TestNextRefreshCheckAt_ProviderLead_Expiry(t *testing.T) {
	now := time.Date(2026, 4, 12, 0, 0, 0, 0, time.UTC)
	expiry := now.Add(time.Hour)
	lead := 10 * time.Minute
	// Register a lead factory for this synthetic provider; restored by cleanup.
	setRefreshLeadFactory(t, "provider-lead-expiry", func() *time.Duration {
		d := lead
		return &d
	})

	auth := &Auth{
		ID:       "a1",
		Provider: "provider-lead-expiry",
		Metadata: map[string]any{
			"email":      "x@example.com",
			"expires_at": expiry.Format(time.RFC3339),
		},
	}

	got, ok := nextRefreshCheckAt(now, auth, 15*time.Minute)
	if !ok {
		t.Fatalf("nextRefreshCheckAt() ok = false, want true")
	}
	want := expiry.Add(-lead)
	if !got.Equal(want) {
		t.Fatalf("nextRefreshCheckAt() = %s, want %s", got, want)
	}
}
|
||||
|
||||
// A runtime implementing RefreshEvaluator is polled on the fixed interval:
// the next check must be exactly now + interval.
func TestNextRefreshCheckAt_RefreshEvaluatorFallback(t *testing.T) {
	now := time.Date(2026, 4, 12, 0, 0, 0, 0, time.UTC)
	interval := 15 * time.Minute
	auth := &Auth{
		ID:       "a1",
		Provider: "test",
		Metadata: map[string]any{"email": "x@example.com"},
		Runtime:  testRefreshEvaluator{},
	}
	got, ok := nextRefreshCheckAt(now, auth, interval)
	if !ok {
		t.Fatalf("nextRefreshCheckAt() ok = false, want true")
	}
	want := now.Add(interval)
	if !got.Equal(want) {
		t.Fatalf("nextRefreshCheckAt() = %s, want %s", got, want)
	}
}
|
||||
@@ -162,8 +162,8 @@ type Manager struct {
|
||||
rtProvider RoundTripperProvider
|
||||
|
||||
// Auto refresh state
|
||||
refreshCancel context.CancelFunc
|
||||
refreshSemaphore chan struct{}
|
||||
refreshCancel context.CancelFunc
|
||||
refreshLoop *authAutoRefreshLoop
|
||||
}
|
||||
|
||||
// NewManager constructs a manager with optional custom selector and hook.
|
||||
@@ -182,7 +182,6 @@ func NewManager(store Store, selector Selector, hook Hook) *Manager {
|
||||
auths: make(map[string]*Auth),
|
||||
providerOffsets: make(map[string]int),
|
||||
modelPoolOffsets: make(map[string]int),
|
||||
refreshSemaphore: make(chan struct{}, refreshMaxConcurrency),
|
||||
}
|
||||
// atomic.Value requires non-nil initial value.
|
||||
manager.runtimeConfig.Store(&internalconfig.Config{})
|
||||
@@ -214,6 +213,16 @@ func (m *Manager) syncScheduler() {
|
||||
m.syncSchedulerFromSnapshot(m.snapshotAuths())
|
||||
}
|
||||
|
||||
func (m *Manager) snapshotAuths() []*Auth {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
out := make([]*Auth, 0, len(m.auths))
|
||||
for _, a := range m.auths {
|
||||
out = append(out, a.Clone())
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// RefreshSchedulerEntry re-upserts a single auth into the scheduler so that its
|
||||
// supportedModelSet is rebuilt from the current global model registry state.
|
||||
// This must be called after models have been registered for a newly added auth,
|
||||
@@ -1088,6 +1097,7 @@ func (m *Manager) Register(ctx context.Context, auth *Auth) (*Auth, error) {
|
||||
if m.scheduler != nil {
|
||||
m.scheduler.upsertAuth(authClone)
|
||||
}
|
||||
m.queueRefreshReschedule(auth.ID)
|
||||
_ = m.persist(ctx, auth)
|
||||
m.hook.OnAuthRegistered(ctx, auth.Clone())
|
||||
return auth.Clone(), nil
|
||||
@@ -1118,6 +1128,7 @@ func (m *Manager) Update(ctx context.Context, auth *Auth) (*Auth, error) {
|
||||
if m.scheduler != nil {
|
||||
m.scheduler.upsertAuth(authClone)
|
||||
}
|
||||
m.queueRefreshReschedule(auth.ID)
|
||||
_ = m.persist(ctx, auth)
|
||||
m.hook.OnAuthUpdated(ctx, auth.Clone())
|
||||
return auth.Clone(), nil
|
||||
@@ -2890,80 +2901,51 @@ func (m *Manager) StartAutoRefresh(parent context.Context, interval time.Duratio
|
||||
if interval <= 0 {
|
||||
interval = refreshCheckInterval
|
||||
}
|
||||
if m.refreshCancel != nil {
|
||||
m.refreshCancel()
|
||||
m.refreshCancel = nil
|
||||
|
||||
m.mu.Lock()
|
||||
cancel := m.refreshCancel
|
||||
m.refreshCancel = nil
|
||||
m.refreshLoop = nil
|
||||
m.mu.Unlock()
|
||||
if cancel != nil {
|
||||
cancel()
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(parent)
|
||||
loop := newAuthAutoRefreshLoop(m, interval)
|
||||
|
||||
m.mu.Lock()
|
||||
m.refreshCancel = cancel
|
||||
go func() {
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
m.checkRefreshes(ctx)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
m.checkRefreshes(ctx)
|
||||
}
|
||||
}
|
||||
}()
|
||||
m.refreshLoop = loop
|
||||
m.mu.Unlock()
|
||||
|
||||
loop.rebuild(time.Now())
|
||||
go loop.run(ctx)
|
||||
}
|
||||
|
||||
// StopAutoRefresh cancels the background refresh loop, if running.
|
||||
func (m *Manager) StopAutoRefresh() {
|
||||
if m.refreshCancel != nil {
|
||||
m.refreshCancel()
|
||||
m.refreshCancel = nil
|
||||
m.mu.Lock()
|
||||
cancel := m.refreshCancel
|
||||
m.refreshCancel = nil
|
||||
m.refreshLoop = nil
|
||||
m.mu.Unlock()
|
||||
if cancel != nil {
|
||||
cancel()
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) checkRefreshes(ctx context.Context) {
|
||||
// log.Debugf("checking refreshes")
|
||||
now := time.Now()
|
||||
snapshot := m.snapshotAuths()
|
||||
for _, a := range snapshot {
|
||||
typ, _ := a.AccountInfo()
|
||||
if typ != "api_key" {
|
||||
if !m.shouldRefresh(a, now) {
|
||||
continue
|
||||
}
|
||||
log.Debugf("checking refresh for %s, %s, %s", a.Provider, a.ID, typ)
|
||||
|
||||
if exec := m.executorFor(a.Provider); exec == nil {
|
||||
continue
|
||||
}
|
||||
if !m.markRefreshPending(a.ID, now) {
|
||||
continue
|
||||
}
|
||||
go m.refreshAuthWithLimit(ctx, a.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) refreshAuthWithLimit(ctx context.Context, id string) {
|
||||
if m.refreshSemaphore == nil {
|
||||
m.refreshAuth(ctx, id)
|
||||
func (m *Manager) queueRefreshReschedule(authID string) {
|
||||
if m == nil || authID == "" {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case m.refreshSemaphore <- struct{}{}:
|
||||
defer func() { <-m.refreshSemaphore }()
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
m.refreshAuth(ctx, id)
|
||||
}
|
||||
|
||||
func (m *Manager) snapshotAuths() []*Auth {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
out := make([]*Auth, 0, len(m.auths))
|
||||
for _, a := range m.auths {
|
||||
out = append(out, a.Clone())
|
||||
loop := m.refreshLoop
|
||||
m.mu.RUnlock()
|
||||
if loop == nil {
|
||||
return
|
||||
}
|
||||
return out
|
||||
loop.queueReschedule(authID)
|
||||
}
|
||||
|
||||
func (m *Manager) shouldRefresh(a *Auth, now time.Time) bool {
|
||||
@@ -3173,16 +3155,20 @@ func lookupMetadataTime(meta map[string]any, keys ...string) (time.Time, bool) {
|
||||
|
||||
func (m *Manager) markRefreshPending(id string, now time.Time) bool {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
auth, ok := m.auths[id]
|
||||
if !ok || auth == nil || auth.Disabled {
|
||||
m.mu.Unlock()
|
||||
return false
|
||||
}
|
||||
if !auth.NextRefreshAfter.IsZero() && now.Before(auth.NextRefreshAfter) {
|
||||
m.mu.Unlock()
|
||||
return false
|
||||
}
|
||||
auth.NextRefreshAfter = now.Add(refreshPendingBackoff)
|
||||
m.auths[id] = auth
|
||||
m.mu.Unlock()
|
||||
|
||||
m.queueRefreshReschedule(id)
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -3209,16 +3195,21 @@ func (m *Manager) refreshAuth(ctx context.Context, id string) {
|
||||
log.Debugf("refreshed %s, %s, %v", auth.Provider, auth.ID, err)
|
||||
now := time.Now()
|
||||
if err != nil {
|
||||
shouldReschedule := false
|
||||
m.mu.Lock()
|
||||
if current := m.auths[id]; current != nil {
|
||||
current.NextRefreshAfter = now.Add(refreshFailureBackoff)
|
||||
current.LastError = &Error{Message: err.Error()}
|
||||
m.auths[id] = current
|
||||
shouldReschedule = true
|
||||
if m.scheduler != nil {
|
||||
m.scheduler.upsertAuth(current.Clone())
|
||||
}
|
||||
}
|
||||
m.mu.Unlock()
|
||||
if shouldReschedule {
|
||||
m.queueRefreshReschedule(id)
|
||||
}
|
||||
return
|
||||
}
|
||||
if updated == nil {
|
||||
|
||||
Reference in New Issue
Block a user