diff --git a/internal/migration/service.go b/internal/migration/service.go index fc7b01f..bcbcf22 100644 --- a/internal/migration/service.go +++ b/internal/migration/service.go @@ -12,6 +12,8 @@ import ( "log/slog" "math" "os" + "sync" + "time" "github.com/libp2p/go-libp2p/core/host" "github.com/libp2p/go-libp2p/core/network" @@ -27,6 +29,23 @@ import ( // MigrateProtocol is the libp2p protocol ID for agent migration. const MigrateProtocol protocol.ID = "/igor/migrate/1.0.0" +type managedAgent struct { + instance *agent.Instance + cancel context.CancelFunc + done chan struct{} + closeOnce sync.Once +} + +func (m *managedAgent) close(ctx context.Context) error { + var closeErr error + m.closeOnce.Do(func() { + if m.instance != nil { + closeErr = m.instance.Close(ctx) + } + }) + return closeErr +} + // Service coordinates agent migration between nodes. type Service struct { host host.Host @@ -34,8 +53,9 @@ type Service struct { storageProvider storage.Provider logger *slog.Logger + mu sync.RWMutex // Active agents running on this node - activeAgents map[string]*agent.Instance + activeAgents map[string]*managedAgent } // NewService creates a new migration service. @@ -50,7 +70,7 @@ func NewService( runtimeEngine: engine, storageProvider: storage, logger: logger, - activeAgents: make(map[string]*agent.Instance), + activeAgents: make(map[string]*managedAgent), } // Register migration protocol handler @@ -163,12 +183,9 @@ func (s *Service) MigrateAgent( "target_node", started.NodeID, ) - // Terminate local instance if exists - if instance, exists := s.activeAgents[agentID]; exists { - if err := instance.Close(ctx); err != nil { - s.logger.Error("Failed to close local instance", "error", err) - } - delete(s.activeAgents, agentID) + // Terminate local instance if this process currently runs the agent. + if managed, exists := s.getManagedAgent(agentID); exists { + s.stopManagedAgent(ctx, agentID, managed) s.logger.Info("Local agent instance terminated", "agent_id", agentID) } @@ -202,6 +219,11 @@ func (s *Service) handleIncomingMigration(stream network.Stream) { } pkg := transfer.Package + if pkg.AgentID == "" { + s.sendStartConfirmation(stream, "", false, "agent_id is required") + return + } + s.logger.Info("Agent package received", "agent_id", pkg.AgentID, "wasm_size", len(pkg.WASMBinary), @@ -217,13 +239,27 @@ func (s *Service) handleIncomingMigration(stream network.Stream) { return } - // Write WASM binary to temporary file - wasmPath := fmt.Sprintf("/tmp/igor-agent-%s.wasm", pkg.AgentID) - if err := os.WriteFile(wasmPath, pkg.WASMBinary, 0644); err != nil { + // Write WASM binary to a secure temporary file. + tmpFile, err := os.CreateTemp("", "igor-agent-*.wasm") + if err != nil { + s.logger.Error("Failed to create temp WASM file", "error", err) + s.sendStartConfirmation(stream, pkg.AgentID, false, err.Error()) + return + } + wasmPath := tmpFile.Name() + defer os.Remove(wasmPath) + + if _, err := tmpFile.Write(pkg.WASMBinary); err != nil { + _ = tmpFile.Close() s.logger.Error("Failed to write WASM binary", "error", err) s.sendStartConfirmation(stream, pkg.AgentID, false, err.Error()) return } + if err := tmpFile.Close(); err != nil { + s.logger.Error("Failed to close temp WASM file", "error", err) + s.sendStartConfirmation(stream, pkg.AgentID, false, err.Error()) + return + } // Load agent with budget from package instance, err := agent.LoadAgent( @@ -245,7 +281,7 @@ func (s *Service) handleIncomingMigration(stream network.Stream) { // Initialize agent if err := instance.Init(ctx); err != nil { s.logger.Error("Failed to initialize agent", "error", err) - instance.Close(ctx) + _ = instance.Close(ctx) s.sendStartConfirmation(stream, pkg.AgentID, false, err.Error()) return } @@ -253,13 +289,18 @@ func (s *Service) handleIncomingMigration(stream network.Stream) { // Resume from checkpoint if err := instance.LoadCheckpointFromStorage(ctx); err != nil { s.logger.Error("Failed to resume agent", "error", err) - instance.Close(ctx) + _ = instance.Close(ctx) s.sendStartConfirmation(stream, pkg.AgentID, false, err.Error()) return } - // Store as active agent - s.activeAgents[pkg.AgentID] = instance + // Start target-side execution loop and register as active. + if err := s.startManagedAgentLoop(pkg.AgentID, instance); err != nil { + s.logger.Error("Failed to start migrated agent", "error", err) + _ = instance.Close(ctx) + s.sendStartConfirmation(stream, pkg.AgentID, false, err.Error()) + return + } s.logger.Info("Agent migration accepted and started", "agent_id", pkg.AgentID, @@ -270,6 +311,132 @@ func (s *Service) handleIncomingMigration(stream network.Stream) { s.sendStartConfirmation(stream, pkg.AgentID, true, "") } +func (s *Service) startManagedAgentLoop(agentID string, instance *agent.Instance) error { + agentCtx, cancel := context.WithCancel(context.Background()) + managed := &managedAgent{ + instance: instance, + cancel: cancel, + done: make(chan struct{}), + } + + if err := s.registerManagedAgent(agentID, managed); err != nil { + cancel() + close(managed.done) + return err + } + + go s.runManagedAgentLoop(agentCtx, agentID, managed) + return nil +} + +func (s *Service) runManagedAgentLoop(ctx context.Context, agentID string, managed *managedAgent) { + defer close(managed.done) + defer s.unregisterManagedAgent(agentID, managed) + + ticker := time.NewTicker(1 * time.Second) + defer ticker.Stop() + + checkpointTicker := time.NewTicker(5 * time.Second) + defer checkpointTicker.Stop() + + s.logger.Info("Starting migrated agent tick loop", "agent_id", agentID) + + for { + select { + case <-ctx.Done(): + if err := managed.instance.SaveCheckpointToStorage(context.Background()); err != nil { + s.logger.Error("Failed to save checkpoint on agent stop", "agent_id", agentID, "error", err) + } + if err := managed.close(context.Background()); err != nil { + s.logger.Error("Failed to close agent instance", "agent_id", agentID, "error", err) + } + s.logger.Info("Stopped migrated agent tick loop", "agent_id", agentID) + return + + case <-ticker.C: + if err := managed.instance.Tick(ctx); err != nil { + if managed.instance.Budget <= 0 { + s.logger.Info("Migrated agent budget exhausted, terminating", + "agent_id", agentID, + "reason", "budget_exhausted", + ) + } else { + s.logger.Error("Migrated agent tick failed", "agent_id", agentID, "error", err) + } + + if saveErr := managed.instance.SaveCheckpointToStorage(context.Background()); saveErr != nil { + s.logger.Error("Failed to save checkpoint on agent termination", "agent_id", agentID, "error", saveErr) + } + if closeErr := managed.close(context.Background()); closeErr != nil { + s.logger.Error("Failed to close agent instance", "agent_id", agentID, "error", closeErr) + } + return + } + + case <-checkpointTicker.C: + if err := managed.instance.SaveCheckpointToStorage(ctx); err != nil { + s.logger.Error("Failed to save periodic checkpoint", "agent_id", agentID, "error", err) + } + } + } +} + +func (s *Service) stopManagedAgent(ctx context.Context, agentID string, managed *managedAgent) { + if managed.cancel != nil { + managed.cancel() + } + + if managed.done != nil { + select { + case <-managed.done: + case <-ctx.Done(): + case <-time.After(2 * time.Second): + s.logger.Warn("Timed out waiting for agent loop shutdown", "agent_id", agentID) + } + } + + if err := managed.close(context.Background()); err != nil { + s.logger.Error("Failed to close local instance", "agent_id", agentID, "error", err) + } + + s.unregisterManagedAgent(agentID, managed) +} + +func (s *Service) registerManagedAgent(agentID string, managed *managedAgent) error { + s.mu.Lock() + defer s.mu.Unlock() + + if _, exists := s.activeAgents[agentID]; exists { + return fmt.Errorf("agent %s is already active on this node", agentID) + } + + s.activeAgents[agentID] = managed + return nil +} + +func (s *Service) getManagedAgent(agentID string) (*managedAgent, bool) { + s.mu.RLock() + defer s.mu.RUnlock() + + managed, exists := s.activeAgents[agentID] + return managed, exists +} + +func (s *Service) unregisterManagedAgent(agentID string, expected *managedAgent) { + s.mu.Lock() + defer s.mu.Unlock() + + current, exists := s.activeAgents[agentID] + if !exists { + return + } + if expected != nil && current != expected { + return + } + + delete(s.activeAgents, agentID) +} + // sendStartConfirmation sends an AgentStarted message. func (s *Service) sendStartConfirmation( stream io.Writer, @@ -292,12 +459,23 @@ func (s *Service) sendStartConfirmation( // RegisterAgent registers an actively running agent with the migration service. func (s *Service) RegisterAgent(agentID string, instance *agent.Instance) { - s.activeAgents[agentID] = instance + managed := &managedAgent{instance: instance} + if err := s.registerManagedAgent(agentID, managed); err != nil { + s.logger.Error("Failed to register agent with migration service", + "agent_id", agentID, + "error", err, + ) + return + } + s.logger.Info("Agent registered with migration service", "agent_id", agentID) } // GetActiveAgents returns the list of active agent IDs. func (s *Service) GetActiveAgents() []string { + s.mu.RLock() + defer s.mu.RUnlock() + agents := make([]string, 0, len(s.activeAgents)) for id := range s.activeAgents { agents = append(agents, id) diff --git a/internal/storage/fs_provider.go b/internal/storage/fs_provider.go index 8d466a1..44d8617 100644 --- a/internal/storage/fs_provider.go +++ b/internal/storage/fs_provider.go @@ -6,6 +6,7 @@ import ( "log/slog" "os" "path/filepath" + "regexp" ) // FSProvider implements Provider using the local filesystem. @@ -14,6 +15,8 @@ type FSProvider struct { logger *slog.Logger } +var validAgentIDPattern = regexp.MustCompile(`^[A-Za-z0-9][A-Za-z0-9._-]{0,127}$`) + // NewFSProvider creates a new filesystem-based storage provider. // The baseDir will be created if it doesn't exist. func NewFSProvider(baseDir string, logger *slog.Logger) (*FSProvider, error) { @@ -37,7 +40,10 @@ func (p *FSProvider) SaveCheckpoint( agentID string, state []byte, ) error { - checkpointPath := p.checkpointPath(agentID) + checkpointPath, pathErr := p.checkpointPath(agentID) + if pathErr != nil { + return pathErr + } tempPath := checkpointPath + ".tmp" // Write to temporary file @@ -89,7 +95,10 @@ func (p *FSProvider) LoadCheckpoint( ctx context.Context, agentID string, ) ([]byte, error) { - checkpointPath := p.checkpointPath(agentID) + checkpointPath, pathErr := p.checkpointPath(agentID) + if pathErr != nil { + return nil, pathErr + } data, err := os.ReadFile(checkpointPath) if err != nil { @@ -113,7 +122,10 @@ func (p *FSProvider) DeleteCheckpoint( ctx context.Context, agentID string, ) error { - checkpointPath := p.checkpointPath(agentID) + checkpointPath, pathErr := p.checkpointPath(agentID) + if pathErr != nil { + return pathErr + } err := os.Remove(checkpointPath) if err != nil && !os.IsNotExist(err) { @@ -124,7 +136,18 @@ func (p *FSProvider) DeleteCheckpoint( return nil } +func validateAgentID(agentID string) error { + if !validAgentIDPattern.MatchString(agentID) { + return fmt.Errorf("invalid agent_id %q", agentID) + } + return nil +} + // checkpointPath returns the filesystem path for an agent's checkpoint. -func (p *FSProvider) checkpointPath(agentID string) string { - return filepath.Join(p.baseDir, agentID+".checkpoint") +func (p *FSProvider) checkpointPath(agentID string) (string, error) { + if err := validateAgentID(agentID); err != nil { + return "", err + } + + return filepath.Join(p.baseDir, agentID+".checkpoint"), nil }