From 9728b49bd4eccc103707faccd302cd60d5ac09f2 Mon Sep 17 00:00:00 2001 From: townwish Date: Fri, 13 Mar 2026 17:12:37 +0800 Subject: [PATCH 1/2] fix: bugs of issue#1 --- internal/app/train.go | 183 ++++++++++++++++++++++------------- ui/app.go | 8 +- ui/model/train.go | 17 +++- ui/panels/train_workspace.go | 4 + workflow/train/controller.go | 6 +- workflow/train/demo.go | 6 +- workflow/train/types.go | 2 +- 7 files changed, 151 insertions(+), 75 deletions(-) diff --git a/internal/app/train.go b/internal/app/train.go index b6ee324..95527dc 100644 --- a/internal/app/train.go +++ b/internal/app/train.go @@ -21,26 +21,26 @@ type trainSnapshot struct { // Used for demo: the user can type natural phrases instead of exact commands. var trainTextAliases = map[string]string{ // start / rerun - "run it": "start", - "run the training": "start", - "rerun": "start", - "rerun training": "start", - "rerun experiment": "start", - "go": "start", - "launch": "start", - "start it": "start", - "start it up": "start", - "run again": "start", - "run it again": "start", - "begin": "start", - "begin training": "start", - "let's go": "start", - "let's run it": "start", - "kick it off": "start", - "proceed": "start", - "execute": "start", - "run experiment": "start", - "start the run": "start", + "run it": "start", + "run the training": "start", + "rerun": "start", + "rerun training": "start", + "rerun experiment": "start", + "go": "start", + "launch": "start", + "start it": "start", + "start it up": "start", + "run again": "start", + "run it again": "start", + "begin": "start", + "begin training": "start", + "let's go": "start", + "let's run it": "start", + "kick it off": "start", + "proceed": "start", + "execute": "start", + "run experiment": "start", + "start the run": "start", // analyze / diagnose "analysis": "analyze", "what went wrong": "analyze", @@ -61,39 +61,39 @@ var trainTextAliases = map[string]string{ "explain the failure": "analyze", "figure it out": "analyze", // diagnose (explicit) - "diagnose it": "diagnose", - "find the issue": "diagnose", - "root cause": "diagnose", - "diagnose the issue": "diagnose", + "diagnose it": "diagnose", + "find the issue": "diagnose", + "root cause": "diagnose", + "diagnose the issue": "diagnose", // retry - "try again": "retry", - "one more time": "retry", - "retry it": "retry", + "try again": "retry", + "one more time": "retry", + "retry it": "retry", // apply fix (confirmation words like "yes"/"ok"/"do it" are // handled in the UI layer — they fire the current focused button) - "fix it": "apply fix", - "apply the fix": "apply fix", - "patch it": "apply fix", - "apply": "apply fix", - "apply patch": "apply fix", - "apply the change": "apply fix", - "make the change": "apply fix", + "fix it": "apply fix", + "apply the fix": "apply fix", + "patch it": "apply fix", + "apply": "apply fix", + "apply patch": "apply fix", + "apply the change": "apply fix", + "make the change": "apply fix", // analyze perf - "check performance": "analyze perf", - "profile it": "analyze perf", - "why is it slow": "analyze perf", - "check perf": "analyze perf", - "perf analysis": "analyze perf", - "optimize": "analyze perf", - "optimize it": "analyze perf", - "make it faster": "analyze perf", - "speed it up": "analyze perf", - "check throughput": "analyze perf", - "check speed": "analyze perf", - "profile": "analyze perf", - "tune performance": "analyze perf", - "bottleneck": "analyze perf", - "why slow": "analyze perf", + "check performance": "analyze perf", + "profile it": "analyze perf", + "why is it slow": "analyze perf", + "check perf": "analyze perf", + "perf analysis": "analyze perf", + "optimize": "analyze perf", + "optimize it": "analyze perf", + "make it faster": "analyze perf", + "speed it up": "analyze perf", + "check throughput": "analyze perf", + "check speed": "analyze perf", + "profile": "analyze perf", + "tune performance": "analyze perf", + "bottleneck": "analyze perf", + "why slow": "analyze perf", // algo-feature "add mhc": "add algo-feature mhc", "try mhc": "add algo-feature mhc", @@ -123,12 +123,12 @@ var trainTextAliases = map[string]string{ "boost perf": "add perf-feature fa2", "optimize perf": "add perf-feature fa2", // stop - "cancel": "stop", - "abort": "stop", - "stop it": "stop", - "halt": "stop", - "kill it": "stop", - "stop training": "stop", + "cancel": "stop", + "abort": "stop", + "stop it": "stop", + "halt": "stop", + "kill it": "stop", + "stop training": "stop", } type bootstrapRunState struct { @@ -159,11 +159,11 @@ func (a *Application) cmdTrain(args []string) { Backend: train.BackendSSHHost, Name: "torch-npu-910b-0", Config: map[string]any{ - "address": "8.9.72.194:22", - "env_manager": "docker", - "demo_ssh_flaky": true, - "demo_libs_missing": true, - "demo_fail_at_step": 50, + "address": "8.9.72.194:22", + "env_manager": "docker", + "demo_ssh_flaky": true, + "demo_libs_missing": true, + "demo_fail_at_step": 0, }, }, } @@ -325,11 +325,11 @@ func (a *Application) runApplyFix(ctx context.Context, runID uint64, req train.R return } - // Fix succeeded — clear failure injection so rerun won't fail again. + // Fix succeeded — disable failure injection so rerun won't fail again. a.trainMu.Lock() if a.trainReq != nil { if a.trainReq.Target.Config != nil { - delete(a.trainReq.Target.Config, "demo_fail_at_step") + a.trainReq.Target.Config["demo_fail_at_step"] = -1 } if issueType == "accuracy" || issueType == "performance" { if a.trainReq.Target.Config == nil { @@ -344,7 +344,7 @@ func (a *Application) runApplyFix(ctx context.Context, runID uint64, req train.R } if r, ok := a.trainReqs[a.trainCurrentRun]; ok { if r.Target.Config != nil { - delete(r.Target.Config, "demo_fail_at_step") + r.Target.Config["demo_fail_at_step"] = -1 } if issueType == "accuracy" || issueType == "performance" { if r.Target.Config == nil { @@ -387,7 +387,7 @@ func (a *Application) handleTrainInput(input string) { // Gate all other commands on the current phase. switch { case lower == "start" || lower == "start training": - if snapshot.phase != "ready" && snapshot.phase != "completed" { + if snapshot.phase != "ready" && snapshot.phase != "completed" && snapshot.phase != "stopped" { a.rejectCommand("start", "setup must complete first") return } @@ -412,6 +412,14 @@ func (a *Application) handleTrainInput(input string) { a.rejectCommand("analyze performance", "performance analysis needs a finished or active run") return } + // Check if perf fix already applied + if a.isPerfFixed() { + a.EventCh <- model.Event{ + Type: model.AgentReply, + Message: "Performance optimization already applied. Training is using fused_adam kernel.", + } + return + } a.setTrainIssueType("performance") a.analyzeTraining() @@ -434,6 +442,14 @@ func (a *Application) handleTrainInput(input string) { a.rejectCommand("add algo-feature", "workspace must be stable before algo-feature iteration") return } + // Check if algo-feature already applied + if a.isAlgoFeatureApplied() { + a.EventCh <- model.Event{ + Type: model.AgentReply, + Message: "MHC algo-feature already applied. Training is using multi-head contrastive loss.", + } + return + } a.addAlgoFeature(strings.TrimSpace(strings.TrimPrefix(lower, "add algo-feature"))) case strings.HasPrefix(lower, "add perf-feature"): @@ -479,8 +495,9 @@ func (a *Application) startTraining() { func (a *Application) stopTraining() { a.stopTrainTask("stopped") a.EventCh <- model.Event{ - Type: model.AgentReply, - Message: "training stopped.", + Type: model.TrainStopped, + Message: "Training stopped.", + Train: &model.TrainEventData{Status: "stopped"}, } } @@ -1292,3 +1309,37 @@ func (a *Application) bootstrapApplied(runID string) map[string]bool { } return out } + +// isPerfFixed checks if the performance fix (fused_adam) has already been applied. +func (a *Application) isPerfFixed() bool { + a.trainMu.RLock() + defer a.trainMu.RUnlock() + if a.trainReq != nil && a.trainReq.Target.Config != nil { + if fixed, ok := a.trainReq.Target.Config["demo_perf_fixed"].(bool); ok && fixed { + return true + } + } + if r, ok := a.trainReqs[a.trainCurrentRun]; ok && r.Target.Config != nil { + if fixed, ok := r.Target.Config["demo_perf_fixed"].(bool); ok && fixed { + return true + } + } + return false +} + +// isAlgoFeatureApplied checks if the algo-feature (MHC) has already been applied. +func (a *Application) isAlgoFeatureApplied() bool { + a.trainMu.RLock() + defer a.trainMu.RUnlock() + if a.trainReq != nil && a.trainReq.Target.Config != nil { + if applied, ok := a.trainReq.Target.Config["demo_trick_applied"].(bool); ok && applied { + return true + } + } + if r, ok := a.trainReqs[a.trainCurrentRun]; ok && r.Target.Config != nil { + if applied, ok := r.Target.Config["demo_trick_applied"].(bool); ok && applied { + return true + } + } + return false +} diff --git a/ui/app.go b/ui/app.go index 2e37c06..beed35f 100644 --- a/ui/app.go +++ b/ui/app.go @@ -1584,7 +1584,13 @@ func (a App) renderTrainLayout(topBar string) string { // Maximized panel — full screen. if panel, ok := a.trainView.MaximizedPanel(); ok { - body := panels.RenderTrainWorkspacePanel(panel, a.trainView, w, a.trainBodyHeight()) + var body string + if panel == model.TrainPanelAgent { + // Agent panel uses the viewport content for display. + body = panels.RenderAgentBox(a.viewport.View(), w, a.trainBodyHeight(), true, a.viewport.TotalLines(), a.viewport.YOffset()) + } else { + body = panels.RenderTrainWorkspacePanel(panel, a.trainView, w, a.trainBodyHeight()) + } input := " " + a.input.View() hintBar := panels.RenderTrainHintBar(w, a.trainFocus, true) return trimViewHeight(lipgloss.JoinVertical(lipgloss.Left, diff --git a/ui/model/train.go b/ui/model/train.go index 71be0af..18e5a96 100644 --- a/ui/model/train.go +++ b/ui/model/train.go @@ -367,9 +367,9 @@ type TrainWorkspaceState struct { SetupContext SetupContext TrainPlan *TrainPlan RunConfig *RunConfig - ActiveRunID string - Compare *CompareViewState - Hosts []TrainHostView + ActiveRunID string + Compare *CompareViewState + Hosts []TrainHostView Panels map[TrainPanelID]*PanelDisplayState Focus TrainPanelID @@ -428,6 +428,11 @@ func NewTrainWorkspaceState() *TrainWorkspaceState { Collapsed: true, Maximized: false, }, + TrainPanelAgent: { + Focused: false, + Collapsed: false, + Maximized: false, + }, TrainPanelCompare: { Focused: false, Collapsed: true, @@ -808,12 +813,18 @@ func (s *TrainWorkspaceState) RefreshActions() { s.GlobalActions.Items = []TrainAction{ {ID: "stop", Label: "stop", Enabled: true, Primary: true}, } + case TrainPhaseStopped: + s.GlobalActions.Items = []TrainAction{ + {ID: "rerun", Label: "rerun", Enabled: true, Primary: true}, + {ID: "close", Label: "exit", Enabled: true}, + } case TrainPhaseCompleted: items := []TrainAction{ {ID: "rerun", Label: "rerun", Enabled: true, Primary: true}, {ID: "analyze_perf", Label: "analyze perf", Enabled: true}, {ID: "add_algo_feature", Label: "algo-feature", Enabled: true}, {ID: "add_perf_feature", Label: "perf-feature", Enabled: true}, + {ID: "close", Label: "exit", Enabled: true}, } s.GlobalActions.Items = items default: diff --git a/ui/panels/train_workspace.go b/ui/panels/train_workspace.go index b5dc994..8028cc5 100644 --- a/ui/panels/train_workspace.go +++ b/ui/panels/train_workspace.go @@ -482,6 +482,10 @@ func panelBody(id model.TrainPanelID, tv model.TrainWorkspaceState, width, heigh return renderMetricsPanel(tv, width, height) case model.TrainPanelLogs: return renderLogsPanel(tv, width, height) + case model.TrainPanelAgent: + // Agent panel body is rendered separately via RenderAgentBox with viewport content. + // When maximized, we show a placeholder indicating content is displayed via viewport. + return "" default: return "" } diff --git a/workflow/train/controller.go b/workflow/train/controller.go index a987352..96c3520 100644 --- a/workflow/train/controller.go +++ b/workflow/train/controller.go @@ -41,8 +41,12 @@ func (c *Controller) Open(_ context.Context, req itrain.Request) *Session { Phase: "setup", } // Read demo failure injection config. - if failStep, ok := req.Target.Config["demo_fail_at_step"].(int); ok && failStep > 0 { + // -1 = disabled, >= 0 = fail at specific step (0 = fail immediately) + if failStep, ok := req.Target.Config["demo_fail_at_step"].(int); ok { s.FailAtStep = failStep + } else { + // Default to -1 (disabled) if not set. + s.FailAtStep = -1 } if fixed, ok := req.Target.Config["demo_drift_fixed"].(bool); ok && fixed { s.DriftFixed = true diff --git a/workflow/train/demo.go b/workflow/train/demo.go index 179fd0a..3a42a46 100644 --- a/workflow/train/demo.go +++ b/workflow/train/demo.go @@ -70,7 +70,7 @@ func (d *DemoBackend) Run(ctx context.Context, session *Session, sink func(Event } // runTraining plays demo training: loading, log lines, metrics, completion. -// If failStep > 0, training crashes at that step with a DSA operator error. +// If failStep >= 0, training crashes at that step with a DSA operator error. // If perfFixed is true, throughput is ~10% higher (fused adam kernel). // If trickApplied is true, loss converges faster (MHC contrastive loss). func (d *DemoBackend) runTraining(ctx context.Context, model, method string, sink func(Event), failStep int, perfFixed, trickApplied bool) error { @@ -168,8 +168,8 @@ func (d *DemoBackend) runTraining(ctx context.Context, model, method string, sin totalSteps := 300 for _, s := range steps { // Failure injection: crash at the specified step. - if failStep > 0 && s.step >= failStep { - crashMsg := fmt.Sprintf("[%s] FATAL: operator init failed at step %d — DSA operator (torch.ops.npu.dsa) not implemented in torch 2.7 for Ascend backend", runID, s.step) + if failStep >= 0 && s.step >= failStep { + crashMsg := fmt.Sprintf("[%s] FATAL: operator init failed at step %d — DSA operator (torch.ops.npu.dsa) not implemented in torch 2.7 for Ascend backend", runID, failStep) e(Event{Kind: EventLogLine, Message: crashMsg, DelayMs: 200}) return fmt.Errorf("operator init failed: DSA operator not implemented in torch 2.7") } diff --git a/workflow/train/types.go b/workflow/train/types.go index ba290c1..e1ba63d 100644 --- a/workflow/train/types.go +++ b/workflow/train/types.go @@ -101,7 +101,7 @@ type Session struct { Phase string // setup, ready, running, completed, failed, stopped // FailAtStep causes the demo backend to crash training at this step. - // 0 means no failure injection. + // -1 means no failure injection. FailAtStep int // DriftFixed indicates the accuracy fix has been applied, From dcfe40c81a77b126e97fd90c33464a5d275811e0 Mon Sep 17 00:00:00 2001 From: townwish Date: Mon, 16 Mar 2026 09:07:46 +0800 Subject: [PATCH 2/2] fix: bugfix --- ui/app.go | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/ui/app.go b/ui/app.go index beed35f..198d36f 100644 --- a/ui/app.go +++ b/ui/app.go @@ -49,10 +49,10 @@ func agentMsg(source, msg string, done bool) string { } var ( - diffAddStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("114")) // green - diffRemoveStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("196")) // red - diffHunkStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("12")) // blue - diffFileStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("220")).Bold(true) + diffAddStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("114")) // green + diffRemoveStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("196")) // red + diffHunkStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("12")) // blue + diffFileStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("220")).Bold(true) diffContextStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("244")) // dim diffSummaryStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("252")).Bold(true) ) @@ -1587,7 +1587,11 @@ func (a App) renderTrainLayout(topBar string) string { var body string if panel == model.TrainPanelAgent { // Agent panel uses the viewport content for display. - body = panels.RenderAgentBox(a.viewport.View(), w, a.trainBodyHeight(), true, a.viewport.TotalLines(), a.viewport.YOffset()) + agentSpinner := "" + if status := a.agentStatus(); status != "" { + agentSpinner = a.thinking.FrameView() + " " + status + } + body = panels.RenderAgentBox(a.viewport.View(), w, a.trainBodyHeight(), true, a.viewport.TotalLines(), a.viewport.YOffset(), agentSpinner) } else { body = panels.RenderTrainWorkspacePanel(panel, a.trainView, w, a.trainBodyHeight()) }