diff --git a/core/gallery/models.go b/core/gallery/models.go index 82a36d7a5d8b..baea4c1bb647 100644 --- a/core/gallery/models.go +++ b/core/gallery/models.go @@ -196,7 +196,7 @@ func InstallModel(ctx context.Context, systemState *system.SystemState, nameOver } } uri := downloader.URI(file.URI) - if err := uri.DownloadFileWithContext(ctx, filePath, file.SHA256, i, len(config.Files), downloadStatus); err != nil { + if err := uri.DownloadFileWithContext(downloader.ContextWithModelID(ctx, config.Name), filePath, file.SHA256, i, len(config.Files), downloadStatus); err != nil { return nil, err } } diff --git a/core/http/routes/ui_api.go b/core/http/routes/ui_api.go index e2689427339a..761366cb6856 100644 --- a/core/http/routes/ui_api.go +++ b/core/http/routes/ui_api.go @@ -151,6 +151,47 @@ func getDirectorySize(path string) (int64, error) { return totalSize, nil } +// parseRateString converts a human-readable bandwidth string (e.g. "2mb", +// "500kb", "10mb") to bytes per second. Returns ≤ 0 for unlimited. +func parseRateString(s string) (int64, error) { + s = strings.TrimSpace(strings.ToLower(s)) + if s == "" || s == "0" || s == "unlimited" || s == "-1" { + return 0, nil + } + var multiplier int64 = 1 + switch { + case strings.HasSuffix(s, "gb"): + multiplier = 1 << 30 + s = strings.TrimSuffix(s, "gb") + case strings.HasSuffix(s, "g"): + multiplier = 1 << 30 + s = strings.TrimSuffix(s, "g") + case strings.HasSuffix(s, "mb"): + multiplier = 1 << 20 + s = strings.TrimSuffix(s, "mb") + case strings.HasSuffix(s, "m"): + multiplier = 1 << 20 + s = strings.TrimSuffix(s, "m") + case strings.HasSuffix(s, "kb"): + multiplier = 1 << 10 + s = strings.TrimSuffix(s, "kb") + case strings.HasSuffix(s, "k"): + multiplier = 1 << 10 + s = strings.TrimSuffix(s, "k") + case strings.HasSuffix(s, "b"): + multiplier = 1 + s = strings.TrimSuffix(s, "b") + } + val, err := strconv.ParseInt(s, 10, 64) + if err != nil { + return 0, fmt.Errorf("cannot parse %q as a number", s) + } + if val <= 0 { + return 0, nil + } + return val * multiplier, nil +} + // RegisterUIAPIRoutes registers JSON API routes for the web UI func RegisterUIAPIRoutes(app *echo.Echo, cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig, galleryService *galleryop.GalleryService, opcache *galleryop.OpCache, applicationInstance *application.Application, adminMiddleware echo.MiddlewareFunc) { @@ -362,6 +403,80 @@ func RegisterUIAPIRoutes(app *echo.Echo, cl *config.ModelConfigLoader, ml *model }) }, adminMiddleware) + // Pause operation endpoint (admin only) + app.POST("/api/operations/:jobID/pause", func(c echo.Context) error { + jobID := c.Param("jobID") + xlog.Debug("API request to pause operation", "jobID", jobID) + + err := galleryService.PauseOperation(jobID) + if err != nil { + xlog.Error("Failed to pause operation", "error", err, "jobID", jobID) + return c.JSON(http.StatusBadRequest, map[string]any{ + "error": err.Error(), + }) + } + + return c.JSON(200, map[string]any{ + "success": true, + "message": "Operation paused", + }) + }, adminMiddleware) + + // Resume operation endpoint (admin only) + app.POST("/api/operations/:jobID/resume", func(c echo.Context) error { + jobID := c.Param("jobID") + xlog.Debug("API request to resume operation", "jobID", jobID) + + err := galleryService.ResumeOperation(jobID) + if err != nil { + xlog.Error("Failed to resume operation", "error", err, "jobID", jobID) + return c.JSON(http.StatusBadRequest, map[string]any{ + "error": err.Error(), + }) + } + + return c.JSON(200, map[string]any{ + "success": true, + "message": "Operation resumed", + }) + }, adminMiddleware) + + // Pause all operations (admin only) + app.POST("/api/operations/pause-all", func(c echo.Context) error { + xlog.Debug("API request to pause all operations") + + err := galleryService.PauseAllOperations() + if err != nil { + xlog.Error("Failed to pause all operations", "error", err) + return c.JSON(http.StatusBadRequest, map[string]any{ + "error": err.Error(), + }) + } + + return c.JSON(200, map[string]any{ + "success": true, + "message": "All operations paused", + }) + }, adminMiddleware) + + // Resume all operations (admin only) + app.POST("/api/operations/resume-all", func(c echo.Context) error { + xlog.Debug("API request to resume all operations") + + err := galleryService.ResumeAllOperations() + if err != nil { + xlog.Error("Failed to resume all operations", "error", err) + return c.JSON(http.StatusBadRequest, map[string]any{ + "error": err.Error(), + }) + } + + return c.JSON(200, map[string]any{ + "success": true, + "message": "All operations resumed", + }) + }, adminMiddleware) + // Dismiss a failed operation (acknowledge the error and remove it from the list) app.POST("/api/operations/:jobID/dismiss", func(c echo.Context) error { jobID := c.Param("jobID") @@ -376,6 +491,37 @@ func RegisterUIAPIRoutes(app *echo.Echo, cl *config.ModelConfigLoader, ml *model }) }, adminMiddleware) + // Throttle (rate-limit) an active download (admin only) + // Query param: ?rate=2mb or ?rate=500kb. Use 0 or -1 to remove the limit. + app.POST("/api/operations/:jobID/throttle", func(c echo.Context) error { + jobID := c.Param("jobID") + rateStr := c.QueryParam("rate") + if rateStr == "" { + return c.JSON(http.StatusBadRequest, map[string]any{ + "error": "query parameter 'rate' is required (e.g. rate=2mb, rate=500kb)", + }) + } + bytesPerSec, err := parseRateString(rateStr) + if err != nil { + return c.JSON(http.StatusBadRequest, map[string]any{ + "error": fmt.Sprintf("invalid rate %q: %v", rateStr, err), + }) + } + + xlog.Debug("API request to throttle operation", "jobID", jobID, "rate", bytesPerSec) + if err := galleryService.SetOperationRateLimit(jobID, bytesPerSec); err != nil { + xlog.Error("Failed to throttle operation", "error", err, "jobID", jobID) + return c.JSON(http.StatusBadRequest, map[string]any{ + "error": err.Error(), + }) + } + + return c.JSON(200, map[string]any{ + "success": true, + "message": fmt.Sprintf("Operation throttled to %d bytes/sec", bytesPerSec), + }) + }, adminMiddleware) + // Model Gallery APIs (admin only) app.GET("/api/models", func(c echo.Context) error { term := c.QueryParam("term") diff --git a/core/services/galleryop/models.go b/core/services/galleryop/models.go index 3d50a330edb2..2509e6e13de5 100644 --- a/core/services/galleryop/models.go +++ b/core/services/galleryop/models.go @@ -10,6 +10,7 @@ import ( "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/gallery" "github.com/mudler/LocalAI/core/services/messaging" + "github.com/mudler/LocalAI/pkg/downloader" "github.com/mudler/LocalAI/pkg/model" "github.com/mudler/LocalAI/pkg/system" "github.com/mudler/LocalAI/pkg/utils" @@ -83,6 +84,26 @@ func (g *GalleryService) modelHandler(op *ManagementOp[gallery.GalleryModel, gal }) return err } + // Check if the download was paused — the .partial is preserved for + // later resume, so this is not a terminal failure. + if errors.Is(err, downloader.ErrUserPaused) { + g.UpdateStatus(op.ID, &OpStatus{ + Paused: true, + Message: "paused", + GalleryElementName: op.GalleryElementName, + Cancellable: true, + }) + // Store the operation metadata so ResumeOperation can re-queue it. + g.storePausedOp(op.ID, &PausedModelOp{ + Galleries: op.Galleries, + BackendGalleries: op.BackendGalleries, + Req: op.Req, + GalleryElementName: op.GalleryElementName, + }) + // Return nil so Start() does not call updateError — this is not a + // failure, it's a deliberate pause. + return nil + } return err } diff --git a/core/services/galleryop/operation.go b/core/services/galleryop/operation.go index d4c12e8bc211..8aef1f0c2340 100644 --- a/core/services/galleryop/operation.go +++ b/core/services/galleryop/operation.go @@ -59,6 +59,7 @@ type OpStatus struct { GalleryElementName string `json:"gallery_element_name"` Cancelled bool `json:"cancelled"` // Cancelled is true if the operation was cancelled Cancellable bool `json:"cancellable"` // Cancellable is true if the operation can be cancelled + Paused bool `json:"paused"` // Paused is true if the operation was paused (resumable) // Nodes is the per-node breakdown for a fanned-out backend install. // Populated by DistributedBackendManager (per-node terminal status) @@ -87,6 +88,7 @@ type opStatusWire struct { GalleryElementName string `json:"gallery_element_name"` Cancelled bool `json:"cancelled"` Cancellable bool `json:"cancellable"` + Paused bool `json:"paused"` Nodes []NodeProgress `json:"nodes,omitempty"` } @@ -102,6 +104,7 @@ func (o OpStatus) MarshalJSON() ([]byte, error) { GalleryElementName: o.GalleryElementName, Cancelled: o.Cancelled, Cancellable: o.Cancellable, + Paused: o.Paused, Nodes: o.Nodes, } if o.Error != nil { @@ -125,6 +128,7 @@ func (o *OpStatus) UnmarshalJSON(data []byte) error { o.GalleryElementName = w.GalleryElementName o.Cancelled = w.Cancelled o.Cancellable = w.Cancellable + o.Paused = w.Paused o.Nodes = w.Nodes if w.ErrorMessage != "" { o.Error = errors.New(w.ErrorMessage) @@ -161,6 +165,14 @@ type GalleryCancelEvent struct { JobID string `json:"id"` } +// GalleryPauseEvent is the NATS payload for a gallery pause. Mirroring the +// cancel pattern, the pause func may live on a different frontend replica; +// the broadcast subscriber applies the pause locally. A paused operation can +// be resumed later — the .partial file is preserved. +type GalleryPauseEvent struct { + JobID string `json:"id"` +} + // NodeStatus values shared between NodeProgress (per-node tick) and the // NodeOpStatus surfaced by DistributedBackendManager's fan-out. Defined // as exported constants so producers (the manager, the progress bridge) diff --git a/core/services/galleryop/service.go b/core/services/galleryop/service.go index d01d9cc19f21..7e0f084066d7 100644 --- a/core/services/galleryop/service.go +++ b/core/services/galleryop/service.go @@ -4,9 +4,12 @@ import ( "context" "errors" "fmt" + "os" + "path/filepath" "sync" "time" + "github.com/google/uuid" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/gallery" "github.com/mudler/LocalAI/core/services/distributed" @@ -27,7 +30,9 @@ type GalleryService struct { modelManager ModelManager backendManager BackendManager statuses map[string]*OpStatus - cancellations map[string]context.CancelFunc + cancellations map[string]context.CancelCauseFunc + pausedOps map[string]*PausedModelOp + rateLimiters map[string]*downloader.DynamicRateLimiter // Distributed mode (nil when not in distributed mode). // natsClient is the wider MessagingClient (Publisher + subscribe methods) @@ -66,7 +71,8 @@ func NewGalleryService(appConfig *config.ApplicationConfig, ml *model.ModelLoade modelManager: NewLocalModelManager(appConfig, ml), backendManager: NewLocalBackendManager(appConfig, ml), statuses: make(map[string]*OpStatus), - cancellations: make(map[string]context.CancelFunc), + cancellations: make(map[string]context.CancelCauseFunc), + rateLimiters: make(map[string]*downloader.DynamicRateLimiter), } } @@ -314,7 +320,7 @@ func (g *GalleryService) CancelOperation(id string) error { return fmt.Errorf("operation %q is already cancelled", id) } - cancelFunc, localExists := g.cancellations[id] + cancelCause, localExists := g.cancellations[id] if localExists { delete(g.cancellations, id) } @@ -355,8 +361,8 @@ func (g *GalleryService) CancelOperation(id string) error { // I/O and user-provided callback after Unlock — the cancel-wildcard // subscriber loops back into applyCancel on this same replica, which // would otherwise deadlock on g.Mutex. - if cancelFunc != nil { - cancelFunc() + if cancelCause != nil { + cancelCause(downloader.ErrUserCancelled) } if nc != nil { if err := nc.Publish(messaging.SubjectGalleryCancel(id), GalleryCancelEvent{JobID: id}); err != nil { @@ -374,7 +380,7 @@ func (g *GalleryService) CancelOperation(id string) error { // already cancelled this op locally treats the inbound event as a no-op. func (g *GalleryService) applyCancel(id string) { g.Lock() - cancelFunc, hasCancel := g.cancellations[id] + cancelCause, hasCancel := g.cancellations[id] if hasCancel { delete(g.cancellations, id) } @@ -399,22 +405,22 @@ func (g *GalleryService) applyCancel(id string) { // Invoke the cancel func after Unlock so a callback that touches // GalleryService doesn't re-enter the mutex. if hasCancel { - cancelFunc() + cancelCause(downloader.ErrUserCancelled) } } -// newUserCancellableContext returns a child context whose CancelFunc cancels -// with the downloader.ErrUserCancelled cause. This lets the download layer -// distinguish a deliberate user cancel (discard the half-downloaded .partial) -// from an incidental cancellation such as process shutdown (keep the .partial -// so the next run resumes via Range instead of restarting from zero). -func newUserCancellableContext(parent context.Context) (context.Context, context.CancelFunc) { +// newUserCancellableContext returns a child context whose CancelCauseFunc +// can be called with either downloader.ErrUserCancelled (discards .partial) +// or downloader.ErrUserPaused (preserves .partial for later resume). This +// lets the download layer distinguish deliberate user actions from incidental +// cancellations such as process shutdown. +func newUserCancellableContext(parent context.Context) (context.Context, context.CancelCauseFunc) { ctx, cancelCause := context.WithCancelCause(parent) - return ctx, func() { cancelCause(downloader.ErrUserCancelled) } + return ctx, cancelCause } // storeCancellation stores a cancellation function for an operation -func (g *GalleryService) storeCancellation(id string, cancelFunc context.CancelFunc) { +func (g *GalleryService) storeCancellation(id string, cancelFunc context.CancelCauseFunc) { g.Lock() defer g.Unlock() g.cancellations[id] = cancelFunc @@ -423,7 +429,7 @@ func (g *GalleryService) storeCancellation(id string, cancelFunc context.CancelF // StoreCancellation is a public method to store a cancellation function for an operation // This allows cancellation functions to be stored immediately when operations are created, // enabling cancellation of queued operations that haven't started processing yet. -func (g *GalleryService) StoreCancellation(id string, cancelFunc context.CancelFunc) { +func (g *GalleryService) StoreCancellation(id string, cancelFunc context.CancelCauseFunc) { g.storeCancellation(id, cancelFunc) } @@ -434,7 +440,298 @@ func (g *GalleryService) removeCancellation(id string) { delete(g.cancellations, id) } +// storePausedOp saves the paused operation metadata for a later Resume call. +func (g *GalleryService) storePausedOp(id string, op *PausedModelOp) { + g.Lock() + defer g.Unlock() + if g.pausedOps == nil { + g.pausedOps = make(map[string]*PausedModelOp) + } + g.pausedOps[id] = op +} + +// removePausedOp deletes the stored paused operation metadata. +func (g *GalleryService) removePausedOp(id string) { + g.Lock() + defer g.Unlock() + delete(g.pausedOps, id) +} + +// getPausedOpLocked retrieves the paused operation metadata without locking. +// Caller must hold g.Lock(). +func (g *GalleryService) getPausedOpLocked(id string) *PausedModelOp { + if g.pausedOps == nil { + return nil + } + return g.pausedOps[id] +} + +// storeRateLimiter stores a rate limiter for an active download operation. +func (g *GalleryService) storeRateLimiter(id string, rl *downloader.DynamicRateLimiter) { + g.Lock() + defer g.Unlock() + g.rateLimiters[id] = rl +} + +// removeRateLimiter removes the rate limiter for a completed operation. +func (g *GalleryService) removeRateLimiter(id string) { + g.Lock() + defer g.Unlock() + delete(g.rateLimiters, id) +} + +// PauseAllOperations pauses every active (non-paused, non-cancelled) model +// download. Each operation's context is cancelled with ErrUserPaused so the +// download layer preserves the .partial file. Broadcast is NOT sent for +// individual ops — the callers sees a single API response and the result is +// the same set of paused statuses regardless of which replica it hits. +func (g *GalleryService) PauseAllOperations() error { + g.Lock() + ids := make([]string, 0, len(g.cancellations)) + for id := range g.cancellations { + if status, ok := g.statuses[id]; ok { + if status.Paused || (status.Processed && status.Cancelled) { + continue + } + } + ids = append(ids, id) + } + g.Unlock() + + var errs []error + for _, id := range ids { + if err := g.PauseOperation(id); err != nil { + errs = append(errs, fmt.Errorf("op %q: %w", id, err)) + } + } + if len(errs) > 0 { + return fmt.Errorf("failed to pause %d/%d operations: %v", len(errs), len(ids), errors.Join(errs...)) + } + return nil +} + +// ResumeAllOperations resumes every paused model download by re-queuing +// each stored PausedModelOp. Operations that were paused before a restart +// (recovered from sidecar files) are included. +func (g *GalleryService) ResumeAllOperations() error { + g.Lock() + ids := make([]string, 0, len(g.pausedOps)) + for id, p := range g.pausedOps { + if p == nil { + continue + } + ids = append(ids, id) + } + g.Unlock() + + var errs []error + for _, id := range ids { + if err := g.ResumeOperation(id); err != nil { + errs = append(errs, fmt.Errorf("op %q: %w", id, err)) + } + } + if len(errs) > 0 { + return fmt.Errorf("failed to resume %d/%d operations: %v", len(errs), len(ids), errors.Join(errs...)) + } + return nil +} + +// SetOperationRateLimit overrides the download rate limit for an active +// operation. A value <= 0 removes the limit (unlimited). The format +// convenience (e.g. "2mb", "500kb") should be pre-parsed by the caller. +func (g *GalleryService) SetOperationRateLimit(id string, bytesPerSec int64) error { + g.Lock() + defer g.Unlock() + rl, ok := g.rateLimiters[id] + if !ok { + return fmt.Errorf("operation %q not found or does not support rate limiting", id) + } + rl.SetRate(bytesPerSec) + return nil +} + +// autoResumePausedDownloads scans the models directory for .partial.json +// sidecar files (written when a download was paused) and re-queues the +// corresponding model operations. This provides crash-resilience: even if +// the process restarted, paused downloads are automatically resumed. +func (g *GalleryService) autoResumePausedDownloads(systemState *system.SystemState) { + modelsPath := systemState.Model.ModelsPath + sidecarFiles, err := filepath.Glob(filepath.Join(modelsPath, "*.partial.json")) + if err != nil { + xlog.Warn("Failed to scan for download sidecar files", "path", modelsPath, "error", err) + return + } + for _, scPath := range sidecarFiles { + sc, err := downloader.ReadPartialSidecar(scPath) + if err != nil { + xlog.Warn("Failed to read download sidecar for auto-resume, removing", "file", scPath, "error", err) + _ = os.Remove(scPath) + continue + } + if sc.ModelID == "" { + xlog.Warn("Download sidecar missing model_id, removing", "file", scPath) + _ = os.Remove(scPath) + continue + } + + opID := uuid.New().String() + // Reconstruct a minimal GalleryModel. The gallery layer will re-fetch + // the full model config from the configured galleries by name. + req := gallery.GalleryModel{ + Metadata: gallery.Metadata{ + Name: sc.ModelID, + }, + } + + g.ModelGalleryChannel <- ManagementOp[gallery.GalleryModel, gallery.ModelConfig]{ + ID: opID, + GalleryElementName: sc.ModelID, + Req: req, + Galleries: g.appConfig.Galleries, + BackendGalleries: g.appConfig.BackendGalleries, + } + xlog.Info("Auto-resumed paused download", "model", sc.ModelID, "op_id", opID) + } +} + +// PauseOperation pauses an in-progress model download. The download layer +// preserves the .partial file so a subsequent ResumeOperation can continue +// from the saved byte offset. In distributed mode the pause event is +// broadcast so the peer replica holding the cancellation func can apply it. +func (g *GalleryService) PauseOperation(id string) error { + g.Lock() + + if status, ok := g.statuses[id]; ok && status.Paused { + g.Unlock() + return fmt.Errorf("operation %q is already paused", id) + } + if status, ok := g.statuses[id]; ok && status.Processed && status.Cancelled { + g.Unlock() + return fmt.Errorf("operation %q is already cancelled, cannot pause", id) + } + + cancelCause, localExists := g.cancellations[id] + if !localExists { + // The cancel func may live on a different replica in distributed mode. + nc := g.natsClient + if nc == nil { + g.Unlock() + return fmt.Errorf("operation %q not found or already completed", id) + } + // Broadcast to the peer that holds the cancel func. + g.Unlock() + if err := nc.Publish(messaging.SubjectGalleryPause(id), GalleryPauseEvent{JobID: id}); err != nil { + return fmt.Errorf("failed to broadcast pause for operation %q: %w", id, err) + } + return nil + } + + delete(g.cancellations, id) + + if status, ok := g.statuses[id]; ok { + status.Paused = true + status.Message = "paused" + } else { + g.statuses[id] = &OpStatus{ + Paused: true, + Message: "paused", + Cancellable: true, + } + } + g.Unlock() + + // Cancel the context with ErrUserPaused so the download layer preserves + // the .partial file. + cancelCause(downloader.ErrUserPaused) + + // Broadcast the pause event so peer replicas update their status maps. + nc := g.natsClient + if nc != nil { + if err := nc.Publish(messaging.SubjectGalleryPause(id), GalleryPauseEvent{JobID: id}); err != nil { + xlog.Warn("Failed to broadcast gallery pause", "op_id", id, "error", err) + } + } + + return nil +} + +// ResumeOperation resumes a previously paused model download. It re-creates +// the download context and pushes a fresh ManagementOp to the model channel, +// where the existing .partial file will be picked up automatically via Range. +func (g *GalleryService) ResumeOperation(id string) error { + g.Lock() + status, statusExists := g.statuses[id] + if !statusExists || !status.Paused { + g.Unlock() + return fmt.Errorf("operation %q is not paused", id) + } + + pausedOp := g.getPausedOpLocked(id) + if pausedOp == nil { + g.Unlock() + return fmt.Errorf("no paused operation metadata found for %q", id) + } + + // Remove the paused op metadata so a second Resume fails cleanly. + delete(g.pausedOps, id) + + // Reset the status: paused → downloading. + status.Paused = false + status.Processed = false + status.Message = "resuming download" + g.Unlock() + + // Push a new ManagementOp to the model channel. Start() will create a + // fresh context with newUserCancellableContext so the user can still + // cancel the resumed download. + g.ModelGalleryChannel <- ManagementOp[gallery.GalleryModel, gallery.ModelConfig]{ + ID: id, + GalleryElementName: pausedOp.GalleryElementName, + Req: pausedOp.Req, + Galleries: pausedOp.Galleries, + BackendGalleries: pausedOp.BackendGalleries, + // Context and CancelFunc are nil — Start() creates them. + } + + return nil +} + +// applyPause is the broadcast-side counterpart to PauseOperation. The +// wildcard subscriber calls it when a peer publishes a pause event: +// run the local cancel func if we have one, and reflect the pause in the +// local statuses map. +func (g *GalleryService) applyPause(id string) { + g.Lock() + cancelCause, hasCancel := g.cancellations[id] + if hasCancel { + delete(g.cancellations, id) + } + if status, ok := g.statuses[id]; ok { + if status.Paused { + g.Unlock() + return + } + status.Paused = true + status.Message = "paused" + } else { + g.statuses[id] = &OpStatus{ + Paused: true, + Message: "paused", + Cancellable: true, + } + } + g.Unlock() + + if hasCancel { + cancelCause(downloader.ErrUserPaused) + } +} + func (g *GalleryService) Start(c context.Context, cl *config.ModelConfigLoader, systemState *system.SystemState) error { + // Auto-resume downloads that were paused before a restart. Sidecar + // files persisted by the download layer survive process crashes. + g.autoResumePausedDownloads(systemState) + // updates the status with an error var updateError func(id string, e error) if !g.appConfig.OpaqueErrors { @@ -455,10 +752,15 @@ func (g *GalleryService) Start(c context.Context, cl *config.ModelConfigLoader, case op := <-g.BackendGalleryChannel: // Create context if not provided if op.Context == nil { - op.Context, op.CancelFunc = newUserCancellableContext(c) - g.storeCancellation(op.ID, op.CancelFunc) + op.Context, cancelCause := newUserCancellableContext(c) + op.CancelFunc = func() { cancelCause(downloader.ErrUserCancelled) } + g.storeCancellation(op.ID, cancelCause) } else if op.CancelFunc != nil { - g.storeCancellation(op.ID, op.CancelFunc) + // The caller provided a CancelFunc; wrap it as a CancelCauseFunc + // that we can also use for pause. We store the wrapped version + // that always cancels with ErrUserCancelled. + cc := op.CancelFunc + g.storeCancellation(op.ID, func(error) { cc() }) } // Create DB record for distributed tracking if g.galleryStore != nil { @@ -484,11 +786,18 @@ func (g *GalleryService) Start(c context.Context, cl *config.ModelConfigLoader, case op := <-g.ModelGalleryChannel: // Create context if not provided if op.Context == nil { - op.Context, op.CancelFunc = newUserCancellableContext(c) - g.storeCancellation(op.ID, op.CancelFunc) + op.Context, cancelCause := newUserCancellableContext(c) + op.CancelFunc = func() { cancelCause(downloader.ErrUserCancelled) } + g.storeCancellation(op.ID, cancelCause) } else if op.CancelFunc != nil { - g.storeCancellation(op.ID, op.CancelFunc) + cc := op.CancelFunc + g.storeCancellation(op.ID, func(error) { cc() }) } + // Attach a dynamic rate limiter so the download can be throttled + // at runtime via SetOperationRateLimit. + rl := &downloader.DynamicRateLimiter{} + op.Context = downloader.ContextWithRateLimiter(op.Context, rl) + g.storeRateLimiter(op.ID, rl) // Create DB record for distributed tracking if g.galleryStore != nil { opType := "model_install" @@ -509,6 +818,7 @@ func (g *GalleryService) Start(c context.Context, cl *config.ModelConfigLoader, updateError(op.ID, err) } g.removeCancellation(op.ID) + g.removeRateLimiter(op.ID) } } }() @@ -554,6 +864,22 @@ func (g *GalleryService) SubscribeBroadcasts() error { return fmt.Errorf("subscribing to gallery cancel wildcard: %w", err) } + pauseSub, err := messaging.SubscribeJSON(nc, messaging.SubjectGalleryPauseWildcard, func(evt GalleryPauseEvent) { + if evt.JobID == "" { + return + } + g.applyPause(evt.JobID) + }) + if err != nil { + if uerr := progressSub.Unsubscribe(); uerr != nil { + xlog.Warn("failed to unsubscribe partial gallery progress sub", "error", uerr) + } + if uerr := cancelSub.Unsubscribe(); uerr != nil { + xlog.Warn("failed to unsubscribe partial gallery cancel sub", "error", uerr) + } + return fmt.Errorf("subscribing to gallery pause wildcard: %w", err) + } + modelsSub, err := messaging.SubscribeJSON(nc, messaging.SubjectCacheInvalidateModels, func(evt messaging.CacheInvalidateEvent) { g.Lock() cb := g.OnModelsChanged @@ -569,6 +895,9 @@ func (g *GalleryService) SubscribeBroadcasts() error { if uerr := cancelSub.Unsubscribe(); uerr != nil { xlog.Warn("failed to unsubscribe partial gallery cancel sub", "error", uerr) } + if uerr := pauseSub.Unsubscribe(); uerr != nil { + xlog.Warn("failed to unsubscribe partial gallery pause sub", "error", uerr) + } return fmt.Errorf("subscribing to models invalidation: %w", err) } @@ -589,6 +918,9 @@ func (g *GalleryService) SubscribeBroadcasts() error { if uerr := cancelSub.Unsubscribe(); uerr != nil { xlog.Warn("failed to unsubscribe partial gallery cancel sub", "error", uerr) } + if uerr := pauseSub.Unsubscribe(); uerr != nil { + xlog.Warn("failed to unsubscribe partial gallery pause sub", "error", uerr) + } if uerr := modelsSub.Unsubscribe(); uerr != nil { xlog.Warn("failed to unsubscribe partial models sub", "error", uerr) } @@ -596,7 +928,7 @@ func (g *GalleryService) SubscribeBroadcasts() error { } g.Lock() - g.broadcastSubs = append(g.broadcastSubs, progressSub, cancelSub, modelsSub, backendsSub) + g.broadcastSubs = append(g.broadcastSubs, progressSub, cancelSub, pauseSub, modelsSub, backendsSub) g.Unlock() return nil } diff --git a/core/services/messaging/subjects.go b/core/services/messaging/subjects.go index 7d099460c3f0..decd4cd716a6 100644 --- a/core/services/messaging/subjects.go +++ b/core/services/messaging/subjects.go @@ -105,6 +105,7 @@ const ( SubjectJobProgressWildcard = "jobs.*.progress" SubjectAgentCancelWildcard = "agent.*.cancel" SubjectGalleryCancelWildcard = "gallery.*.cancel" + SubjectGalleryPauseWildcard = "gallery.*.pause" SubjectGalleryProgressWildcard = "gallery.*.progress" ) @@ -128,6 +129,11 @@ func SubjectGalleryCancel(opID string) string { return subjectGalleryCancelPrefix + sanitizeSubjectToken(opID) + ".cancel" } +// SubjectGalleryPause returns the NATS subject to pause a gallery download. +func SubjectGalleryPause(opID string) string { + return subjectGalleryCancelPrefix + sanitizeSubjectToken(opID) + ".pause" +} + // Node Backend Lifecycle (Pub/Sub — targeted to specific nodes) // // These subjects control the backend *process* lifecycle on a serve-backend node, diff --git a/pkg/downloader/pause_test.go b/pkg/downloader/pause_test.go new file mode 100644 index 000000000000..3a6d368414a6 --- /dev/null +++ b/pkg/downloader/pause_test.go @@ -0,0 +1,185 @@ +package downloader_test + +import ( + "context" + "crypto/rand" + "crypto/sha256" + "errors" + "fmt" + "net/http" + "net/http/httptest" + "os" + "strconv" + "strings" + "time" + + . "github.com/mudler/LocalAI/pkg/downloader" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("Download pause and resume", func() { + var filePath string + + pauseRangeServer := func(data []byte) *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == "HEAD" { + w.Header().Set("Accept-Ranges", "bytes") + w.WriteHeader(http.StatusOK) + return + } + start := 0 + if rh := r.Header.Get("Range"); rh != "" { + _, _ = fmt.Sscanf(strings.TrimPrefix(rh, "bytes="), "%d-", &start) + } + w.Header().Set("Content-Length", strconv.Itoa(len(data)-start)) + if start > 0 { + w.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", start, len(data)-1, len(data))) + w.WriteHeader(http.StatusPartialContent) + } else { + w.WriteHeader(http.StatusOK) + } + f, _ := w.(http.Flusher) + for i := start; i < len(data); i += 256 { + end := i + 256 + if end > len(data) { + end = len(data) + } + if _, err := w.Write(data[i:end]); err != nil { + return + } + if f != nil { + f.Flush() + } + time.Sleep(20 * time.Millisecond) + } + })) + } + + BeforeEach(func() { + dir, err := os.Getwd() + Expect(err).ToNot(HaveOccurred()) + filePath = dir + "/pause_model" + }) + + AfterEach(func() { + _ = os.Remove(filePath) + _ = os.Remove(filePath + ".partial") + }) + + It("preserves the .partial when paused with ErrUserPaused (critical: no delete)", func() { + data := make([]byte, 8192) + _, err := rand.Read(data) + Expect(err).ToNot(HaveOccurred()) + server := pauseRangeServer(data) + defer server.Close() + + ctx, cancel := context.WithCancelCause(context.Background()) + go func() { + time.Sleep(150 * time.Millisecond) + cancel(ErrUserPaused) + }() + + err = URI(server.URL).DownloadFileWithContext(ctx, filePath, "", 1, 1, func(s1, s2, s3 string, f float64) {}) + Expect(err).To(HaveOccurred()) + Expect(errors.Is(err, ErrUserPaused)).To(BeTrue(), "should return ErrUserPaused, not context.Canceled") + + info, statErr := os.Stat(filePath + ".partial") + Expect(statErr).ToNot(HaveOccurred(), + "CRITICAL: .partial must exist after pause — a deleted .partial means resume is impossible") + Expect(info.Size()).To(BeNumerically(">", 0)) + Expect(info.Size()).To(BeNumerically("<", int64(len(data)))) + }) + + It("resumes a paused download from the .partial offset and completes with correct SHA", func() { + data := make([]byte, 16384) + _, err := rand.Read(data) + Expect(err).ToNot(HaveOccurred()) + sum := sha256.Sum256(data) + sha := fmt.Sprintf("%x", sum) + server := pauseRangeServer(data) + defer server.Close() + + // First attempt: pause mid-stream with ErrUserPaused. + ctx, cancel := context.WithCancelCause(context.Background()) + go func() { + time.Sleep(150 * time.Millisecond) + cancel(ErrUserPaused) + }() + err = URI(server.URL).DownloadFileWithContext(ctx, filePath, sha, 1, 1, func(s1, s2, s3 string, f float64) {}) + Expect(err).To(HaveOccurred()) + Expect(errors.Is(err, ErrUserPaused)).To(BeTrue()) + + partialInfo, statErr := os.Stat(filePath + ".partial") + Expect(statErr).ToNot(HaveOccurred()) + resumedFrom := partialInfo.Size() + Expect(resumedFrom).To(BeNumerically(">", 0)) + + // Second attempt: fresh context, must resume via Range and verify SHA. + err = URI(server.URL).DownloadFileWithContext(context.Background(), filePath, sha, 1, 1, func(s1, s2, s3 string, f float64) {}) + Expect(err).ToNot(HaveOccurred()) + + final, rerr := os.ReadFile(filePath) + Expect(rerr).ToNot(HaveOccurred()) + Expect(final).To(Equal(data)) + }) + + It("pauses and resumes multiple times", func() { + data := make([]byte, 32768) + _, err := rand.Read(data) + Expect(err).ToNot(HaveOccurred()) + sum := sha256.Sum256(data) + sha := fmt.Sprintf("%x", sum) + server := pauseRangeServer(data) + defer server.Close() + + var prevSize int64 + + for i := 0; i < 3; i++ { + ctx, cancel := context.WithCancelCause(context.Background()) + go func() { + time.Sleep(100 * time.Millisecond) + cancel(ErrUserPaused) + }() + + err := URI(server.URL).DownloadFileWithContext(ctx, filePath, sha, 1, 1, func(s1, s2, s3 string, f float64) {}) + Expect(err).To(HaveOccurred()) + Expect(errors.Is(err, ErrUserPaused)).To(BeTrue()) + + info, statErr := os.Stat(filePath + ".partial") + Expect(statErr).ToNot(HaveOccurred(), "round %d: .partial must survive pause", i) + Expect(info.Size()).To(BeNumerically(">", prevSize), + "round %d: download must have made progress after resume", i) + prevSize = info.Size() + } + + // Final resume: complete the download. + err = URI(server.URL).DownloadFileWithContext(context.Background(), filePath, sha, 1, 1, func(s1, s2, s3 string, f float64) {}) + Expect(err).ToNot(HaveOccurred()) + + final, rerr := os.ReadFile(filePath) + Expect(rerr).ToNot(HaveOccurred()) + Expect(final).To(Equal(data)) + }) + + It("returns ErrUserPaused not context.Canceled (caller can distinguish)", func() { + data := make([]byte, 4096) + _, err := rand.Read(data) + Expect(err).ToNot(HaveOccurred()) + server := pauseRangeServer(data) + defer server.Close() + + ctx, cancel := context.WithCancelCause(context.Background()) + go func() { + time.Sleep(100 * time.Millisecond) + cancel(ErrUserPaused) + }() + + err = URI(server.URL).DownloadFileWithContext(ctx, filePath, "", 1, 1, func(s1, s2, s3 string, f float64) {}) + Expect(err).To(HaveOccurred()) + Expect(errors.Is(err, ErrUserPaused)).To(BeTrue(), + "must return ErrUserPaused so GalleryService can distinguish pause from failure") + Expect(errors.Is(err, context.Canceled)).To(BeFalse(), + "must NOT return context.Canceled — the caller would treat it as a system cancel, not a user pause") + }) +}) diff --git a/pkg/downloader/ratelimit.go b/pkg/downloader/ratelimit.go new file mode 100644 index 000000000000..77d46dfccb90 --- /dev/null +++ b/pkg/downloader/ratelimit.go @@ -0,0 +1,109 @@ +package downloader + +import ( + "context" + "io" + "sync" + "time" +) + +// DynamicRateLimiter implements a token-bucket rate limiter whose rate can +// be changed at runtime. A zero-value limiter is unlimited (no waiting). +// All methods are safe for concurrent use. +type DynamicRateLimiter struct { + mu sync.Mutex + rate float64 // bytes per second; 0 means unlimited + tokens float64 + lastTime time.Time +} + +// SetRate changes the target rate in bytes per second. A value <= 0 means +// unlimited (the Wait method becomes a no-op). +func (d *DynamicRateLimiter) SetRate(bytesPerSec int64) { + d.mu.Lock() + defer d.mu.Unlock() + d.rate = float64(bytesPerSec) + if d.rate <= 0 { + d.tokens = 0 + d.lastTime = time.Time{} + } +} + +// Wait blocks until a token is available for one byte, honouring ctx +// cancellation. It returns nil immediately when the rate is unlimited or +// when the context is already done. +func (d *DynamicRateLimiter) Wait(ctx context.Context) error { + d.mu.Lock() + rate := d.rate + if rate <= 0 { + d.mu.Unlock() + return nil + } + + now := time.Now() + if d.lastTime.IsZero() { + d.lastTime = now + d.tokens = rate // start fully charged + } + + // Refill tokens based on elapsed time since last call. + elapsed := now.Sub(d.lastTime).Seconds() + d.tokens += elapsed * rate + if d.tokens > rate { + d.tokens = rate + } + + if d.tokens >= 1 { + d.tokens-- + d.lastTime = now + d.mu.Unlock() + return nil + } + + // How long until we have at least one token? + waitDur := time.Duration((1 - d.tokens) / rate * float64(time.Second)) + d.lastTime = now + d.tokens = 0 + d.mu.Unlock() + + select { + case <-time.After(waitDur): + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + +// rateLimitedReader wraps an io.ReadCloser with a DynamicRateLimiter so that +// reads respect the configured byte-per-second rate. +type rateLimitedReader struct { + inner io.ReadCloser + rl *DynamicRateLimiter +} + +func newRateLimitedReader(inner io.ReadCloser, rl *DynamicRateLimiter) io.ReadCloser { + return &rateLimitedReader{inner: inner, rl: rl} +} + +func (r *rateLimitedReader) Read(p []byte) (int, error) { + if r.rl == nil { + return r.inner.Read(p) + } + // Throttle byte-by-byte so bursty reads don't exceed the budget. + // An alternative would be to release all n bytes at once, but that + // would allow a large burst up to the buffer size. + for i := 0; i < len(p); i++ { + if err := r.rl.Wait(context.Background()); err != nil { + return i, err + } + n, err := r.inner.Read(p[i : i+1]) + if n == 0 { + return i, err + } + } + return len(p), nil +} + +func (r *rateLimitedReader) Close() error { + return r.inner.Close() +} diff --git a/pkg/downloader/sidecar.go b/pkg/downloader/sidecar.go new file mode 100644 index 000000000000..394bffd07cb2 --- /dev/null +++ b/pkg/downloader/sidecar.go @@ -0,0 +1,79 @@ +package downloader + +import ( + "context" + "encoding/json" + "os" + "time" +) + +type dlCtxKey string + +const ( + ctxKeyModelID dlCtxKey = "model_id" + ctxKeyModelURL dlCtxKey = "model_url" + ctxKeyRateLimiter dlCtxKey = "rate_limiter" +) + +// ContextWithRateLimiter attaches a DynamicRateLimiter to ctx so +// DownloadFileWithContext can throttle the download speed. +func ContextWithRateLimiter(ctx context.Context, rl *DynamicRateLimiter) context.Context { + return context.WithValue(ctx, ctxKeyRateLimiter, rl) +} + +// RateLimiterFromContext returns the DynamicRateLimiter attached to ctx, or +// nil if none is set. +func RateLimiterFromContext(ctx context.Context) *DynamicRateLimiter { + if rl, ok := ctx.Value(ctxKeyRateLimiter).(*DynamicRateLimiter); ok { + return rl + } + return nil +} + +// PartialSidecar is the metadata written alongside a .partial file when a +// download is paused. It survives restarts so the auto-resume boot hook can +// reconstruct the download operation. +type PartialSidecar struct { + URL string `json:"url"` + ModelID string `json:"model_id"` + PausedAt string `json:"paused_at"` +} + +// ContextWithModelID attaches a model identifier to ctx so the download +// layer can include it in the sidecar when paused. +func ContextWithModelID(ctx context.Context, modelID string) context.Context { + return context.WithValue(ctx, ctxKeyModelID, modelID) +} + +// WritePartialSidecar writes a sidecar JSON file next to the .partial file. +func WritePartialSidecar(partialPath, url, modelID string) error { + sc := PartialSidecar{ + URL: url, + ModelID: modelID, + PausedAt: time.Now().UTC().Format(time.RFC3339), + } + data, err := json.Marshal(sc) + if err != nil { + return err + } + return os.WriteFile(partialPath+".json", data, 0644) +} + +// RemovePartialSidecar deletes the sidecar file next to the .partial path. +// No error is returned if the file does not exist. +func RemovePartialSidecar(partialPath string) { + _ = os.Remove(partialPath + ".json") +} + +// ReadPartialSidecar reads and parses a .partial.json file. +func ReadPartialSidecar(path string) (*PartialSidecar, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, err + } + var sc PartialSidecar + if err := json.Unmarshal(data, &sc); err != nil { + return nil, err + } + return &sc, nil +} diff --git a/pkg/downloader/uri.go b/pkg/downloader/uri.go index 41bdbe672038..90d9a90a6ae8 100644 --- a/pkg/downloader/uri.go +++ b/pkg/downloader/uri.go @@ -342,6 +342,13 @@ func (s URI) ResolveURL() string { // so the next run resumes via Range instead of restarting from zero. var ErrUserCancelled = errors.New("download cancelled by user") +// ErrUserPaused is returned when the download is paused by the user. Unlike +// ErrUserCancelled, the .partial file is preserved so a subsequent Resume +// call can continue from where it left off via Range. Callers should check +// for this error with errors.Is and treat it as a transient suspension, not a +// terminal failure. +var ErrUserPaused = errors.New("download paused by user") + func removePartialFile(tmpFilePath string) error { xlog.Debug("Removing temporary file", "file", tmpFilePath) if err := os.Remove(tmpFilePath); err != nil && !errors.Is(err, os.ErrNotExist) { @@ -349,9 +356,24 @@ func removePartialFile(tmpFilePath string) error { xlog.Warn("failed to remove temporary download file", "error", err1) return err1 } + // Also clean up the sidecar metadata file if present. + RemovePartialSidecar(tmpFilePath) return nil } +// writePartialSidecarMetadata writes the .partial.json sidecar next to a +// paused download. The model_id is extracted from ctx (set by the gallery +// layer); if absent the sidecar carries "unknown". +func writePartialSidecarMetadata(ctx context.Context, tmpFilePath, url string) { + modelID, _ := ctx.Value(ctxKeyModelID).(string) + if modelID == "" { + modelID = "unknown" + } + if err := WritePartialSidecar(tmpFilePath, url, modelID); err != nil { + xlog.Warn("Failed to write download sidecar", "file", tmpFilePath+".json", "error", err) + } +} + func calculateHashForPartialFile(file *os.File) (hash.Hash, error) { hash := sha256.New() _, err := io.Copy(hash, file) @@ -613,9 +635,14 @@ func (uri URI) DownloadFileWithContext(ctx context.Context, filePath, sha string // take long enough that deleting progress means they never finish — // but discard it on a deliberate user abort (ErrUserCancelled). if ctx.Err() != nil { - if errors.Is(context.Cause(ctx), ErrUserCancelled) { + cause := context.Cause(ctx) + if errors.Is(cause, ErrUserCancelled) { _ = removePartialFile(tmpFilePath) } + if errors.Is(cause, ErrUserPaused) { + writePartialSidecarMetadata(ctx, tmpFilePath, url) + return ErrUserPaused + } return ctx.Err() } return fmt.Errorf("failed to download file %q: %v", filePath, err) @@ -635,6 +662,11 @@ func (uri URI) DownloadFileWithContext(ctx context.Context, filePath, sha string } contentLength = resp.ContentLength } + // Wrap with a rate limiter if one is attached to the context. The limiter + // is shared and dynamically adjustable, so reads honour the latest rate. + if rl, ok := ctx.Value(ctxKeyRateLimiter).(*DynamicRateLimiter); ok { + source = newRateLimitedReader(source, rl) + } defer source.Close() // Create parent directory @@ -671,9 +703,14 @@ func (uri URI) DownloadFileWithContext(ctx context.Context, filePath, sha string // it. A stall-guard abort leaves ctx uncancelled, so it falls through // to the error path below and likewise preserves the partial. if ctx.Err() != nil { - if errors.Is(context.Cause(ctx), ErrUserCancelled) { + cause := context.Cause(ctx) + if errors.Is(cause, ErrUserCancelled) { _ = removePartialFile(tmpFilePath) } + if errors.Is(cause, ErrUserPaused) { + writePartialSidecarMetadata(ctx, tmpFilePath, url) + return ErrUserPaused + } return ctx.Err() } return fmt.Errorf("failed to write file %q: %v", filePath, err) @@ -683,9 +720,14 @@ func (uri URI) DownloadFileWithContext(ctx context.Context, filePath, sha string // unless the user deliberately aborted. select { case <-ctx.Done(): - if errors.Is(context.Cause(ctx), ErrUserCancelled) { + cause := context.Cause(ctx) + if errors.Is(cause, ErrUserCancelled) { _ = removePartialFile(tmpFilePath) } + if errors.Is(cause, ErrUserPaused) { + writePartialSidecarMetadata(ctx, tmpFilePath, url) + return ErrUserPaused + } return ctx.Err() default: } @@ -713,6 +755,8 @@ func (uri URI) DownloadFileWithContext(ctx context.Context, filePath, sha string if err != nil { return fmt.Errorf("failed to rename temporary file %s -> %s: %v", tmpFilePath, filePath, err) } + // Download succeeded — the .partial is gone, so the sidecar is stale. + RemovePartialSidecar(tmpFilePath) xlog.Info("File downloaded and verified", "file", filePath) if utils.IsArchive(filePath) {