Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
183 changes: 117 additions & 66 deletions internal/app/train.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,26 +21,26 @@ type trainSnapshot struct {
// Used for demo: the user can type natural phrases instead of exact commands.
var trainTextAliases = map[string]string{
// start / rerun
"run it": "start",
"run the training": "start",
"rerun": "start",
"rerun training": "start",
"rerun experiment": "start",
"go": "start",
"launch": "start",
"start it": "start",
"start it up": "start",
"run again": "start",
"run it again": "start",
"begin": "start",
"begin training": "start",
"let's go": "start",
"let's run it": "start",
"kick it off": "start",
"proceed": "start",
"execute": "start",
"run experiment": "start",
"start the run": "start",
"run it": "start",
"run the training": "start",
"rerun": "start",
"rerun training": "start",
"rerun experiment": "start",
"go": "start",
"launch": "start",
"start it": "start",
"start it up": "start",
"run again": "start",
"run it again": "start",
"begin": "start",
"begin training": "start",
"let's go": "start",
"let's run it": "start",
"kick it off": "start",
"proceed": "start",
"execute": "start",
"run experiment": "start",
"start the run": "start",
// analyze / diagnose
"analysis": "analyze",
"what went wrong": "analyze",
Expand All @@ -61,39 +61,39 @@ var trainTextAliases = map[string]string{
"explain the failure": "analyze",
"figure it out": "analyze",
// diagnose (explicit)
"diagnose it": "diagnose",
"find the issue": "diagnose",
"root cause": "diagnose",
"diagnose the issue": "diagnose",
"diagnose it": "diagnose",
"find the issue": "diagnose",
"root cause": "diagnose",
"diagnose the issue": "diagnose",
// retry
"try again": "retry",
"one more time": "retry",
"retry it": "retry",
"try again": "retry",
"one more time": "retry",
"retry it": "retry",
// apply fix (confirmation words like "yes"/"ok"/"do it" are
// handled in the UI layer — they fire the current focused button)
"fix it": "apply fix",
"apply the fix": "apply fix",
"patch it": "apply fix",
"apply": "apply fix",
"apply patch": "apply fix",
"apply the change": "apply fix",
"make the change": "apply fix",
"fix it": "apply fix",
"apply the fix": "apply fix",
"patch it": "apply fix",
"apply": "apply fix",
"apply patch": "apply fix",
"apply the change": "apply fix",
"make the change": "apply fix",
// analyze perf
"check performance": "analyze perf",
"profile it": "analyze perf",
"why is it slow": "analyze perf",
"check perf": "analyze perf",
"perf analysis": "analyze perf",
"optimize": "analyze perf",
"optimize it": "analyze perf",
"make it faster": "analyze perf",
"speed it up": "analyze perf",
"check throughput": "analyze perf",
"check speed": "analyze perf",
"profile": "analyze perf",
"tune performance": "analyze perf",
"bottleneck": "analyze perf",
"why slow": "analyze perf",
"check performance": "analyze perf",
"profile it": "analyze perf",
"why is it slow": "analyze perf",
"check perf": "analyze perf",
"perf analysis": "analyze perf",
"optimize": "analyze perf",
"optimize it": "analyze perf",
"make it faster": "analyze perf",
"speed it up": "analyze perf",
"check throughput": "analyze perf",
"check speed": "analyze perf",
"profile": "analyze perf",
"tune performance": "analyze perf",
"bottleneck": "analyze perf",
"why slow": "analyze perf",
// algo-feature
"add mhc": "add algo-feature mhc",
"try mhc": "add algo-feature mhc",
Expand Down Expand Up @@ -123,12 +123,12 @@ var trainTextAliases = map[string]string{
"boost perf": "add perf-feature fa2",
"optimize perf": "add perf-feature fa2",
// stop
"cancel": "stop",
"abort": "stop",
"stop it": "stop",
"halt": "stop",
"kill it": "stop",
"stop training": "stop",
"cancel": "stop",
"abort": "stop",
"stop it": "stop",
"halt": "stop",
"kill it": "stop",
"stop training": "stop",
}

type bootstrapRunState struct {
Expand Down Expand Up @@ -159,11 +159,11 @@ func (a *Application) cmdTrain(args []string) {
Backend: train.BackendSSHHost,
Name: "torch-npu-910b-0",
Config: map[string]any{
"address": "8.9.72.194:22",
"env_manager": "docker",
"demo_ssh_flaky": true,
"demo_libs_missing": true,
"demo_fail_at_step": 50,
"address": "8.9.72.194:22",
"env_manager": "docker",
"demo_ssh_flaky": true,
"demo_libs_missing": true,
"demo_fail_at_step": 0,
},
},
}
Expand Down Expand Up @@ -325,11 +325,11 @@ func (a *Application) runApplyFix(ctx context.Context, runID uint64, req train.R
return
}

// Fix succeeded — clear failure injection so rerun won't fail again.
// Fix succeeded — disable failure injection so rerun won't fail again.
a.trainMu.Lock()
if a.trainReq != nil {
if a.trainReq.Target.Config != nil {
delete(a.trainReq.Target.Config, "demo_fail_at_step")
a.trainReq.Target.Config["demo_fail_at_step"] = -1
}
if issueType == "accuracy" || issueType == "performance" {
if a.trainReq.Target.Config == nil {
Expand All @@ -344,7 +344,7 @@ func (a *Application) runApplyFix(ctx context.Context, runID uint64, req train.R
}
if r, ok := a.trainReqs[a.trainCurrentRun]; ok {
if r.Target.Config != nil {
delete(r.Target.Config, "demo_fail_at_step")
r.Target.Config["demo_fail_at_step"] = -1
}
if issueType == "accuracy" || issueType == "performance" {
if r.Target.Config == nil {
Expand Down Expand Up @@ -387,7 +387,7 @@ func (a *Application) handleTrainInput(input string) {
// Gate all other commands on the current phase.
switch {
case lower == "start" || lower == "start training":
if snapshot.phase != "ready" && snapshot.phase != "completed" {
if snapshot.phase != "ready" && snapshot.phase != "completed" && snapshot.phase != "stopped" {
a.rejectCommand("start", "setup must complete first")
return
}
Expand All @@ -412,6 +412,14 @@ func (a *Application) handleTrainInput(input string) {
a.rejectCommand("analyze performance", "performance analysis needs a finished or active run")
return
}
// Check if perf fix already applied
if a.isPerfFixed() {
a.EventCh <- model.Event{
Type: model.AgentReply,
Message: "Performance optimization already applied. Training is using fused_adam kernel.",
}
return
}
a.setTrainIssueType("performance")
a.analyzeTraining()

Expand All @@ -434,6 +442,14 @@ func (a *Application) handleTrainInput(input string) {
a.rejectCommand("add algo-feature", "workspace must be stable before algo-feature iteration")
return
}
// Check if algo-feature already applied
if a.isAlgoFeatureApplied() {
a.EventCh <- model.Event{
Type: model.AgentReply,
Message: "MHC algo-feature already applied. Training is using multi-head contrastive loss.",
}
return
}
a.addAlgoFeature(strings.TrimSpace(strings.TrimPrefix(lower, "add algo-feature")))

case strings.HasPrefix(lower, "add perf-feature"):
Expand Down Expand Up @@ -479,8 +495,9 @@ func (a *Application) startTraining() {
func (a *Application) stopTraining() {
a.stopTrainTask("stopped")
a.EventCh <- model.Event{
Type: model.AgentReply,
Message: "training stopped.",
Type: model.TrainStopped,
Message: "Training stopped.",
Train: &model.TrainEventData{Status: "stopped"},
}
}

Expand Down Expand Up @@ -1292,3 +1309,37 @@ func (a *Application) bootstrapApplied(runID string) map[string]bool {
}
return out
}

// demoFlagSet reports whether the given demo config flag is set to true on
// either the pending train request or the request recorded for the current
// run. Callers must hold trainMu (a read lock is sufficient).
func (a *Application) demoFlagSet(key string) bool {
	// Check the pending/base request first.
	if a.trainReq != nil && a.trainReq.Target.Config != nil {
		if v, ok := a.trainReq.Target.Config[key].(bool); ok && v {
			return true
		}
	}
	// Then the per-run request for whatever run is currently active.
	if r, ok := a.trainReqs[a.trainCurrentRun]; ok && r.Target.Config != nil {
		if v, ok := r.Target.Config[key].(bool); ok && v {
			return true
		}
	}
	return false
}

// isPerfFixed checks if the performance fix (fused_adam) has already been applied.
func (a *Application) isPerfFixed() bool {
	a.trainMu.RLock()
	defer a.trainMu.RUnlock()
	return a.demoFlagSet("demo_perf_fixed")
}

// isAlgoFeatureApplied checks if the algo-feature (MHC) has already been applied.
func (a *Application) isAlgoFeatureApplied() bool {
	a.trainMu.RLock()
	defer a.trainMu.RUnlock()
	return a.demoFlagSet("demo_trick_applied")
}
20 changes: 15 additions & 5 deletions ui/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,10 @@ func agentMsg(source, msg string, done bool) string {
}

var (
diffAddStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("114")) // green
diffRemoveStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("196")) // red
diffHunkStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("12")) // blue
diffFileStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("220")).Bold(true)
diffAddStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("114")) // green
diffRemoveStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("196")) // red
diffHunkStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("12")) // blue
diffFileStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("220")).Bold(true)
diffContextStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("244")) // dim
diffSummaryStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("252")).Bold(true)
)
Expand Down Expand Up @@ -1584,7 +1584,17 @@ func (a App) renderTrainLayout(topBar string) string {

// Maximized panel — full screen.
if panel, ok := a.trainView.MaximizedPanel(); ok {
body := panels.RenderTrainWorkspacePanel(panel, a.trainView, w, a.trainBodyHeight())
var body string
if panel == model.TrainPanelAgent {
// Agent panel uses the viewport content for display.
agentSpinner := ""
if status := a.agentStatus(); status != "" {
agentSpinner = a.thinking.FrameView() + " " + status
}
body = panels.RenderAgentBox(a.viewport.View(), w, a.trainBodyHeight(), true, a.viewport.TotalLines(), a.viewport.YOffset(), agentSpinner)
} else {
body = panels.RenderTrainWorkspacePanel(panel, a.trainView, w, a.trainBodyHeight())
}
input := " " + a.input.View()
hintBar := panels.RenderTrainHintBar(w, a.trainFocus, true)
return trimViewHeight(lipgloss.JoinVertical(lipgloss.Left,
Expand Down
17 changes: 14 additions & 3 deletions ui/model/train.go
Original file line number Diff line number Diff line change
Expand Up @@ -367,9 +367,9 @@ type TrainWorkspaceState struct {
SetupContext SetupContext
TrainPlan *TrainPlan
RunConfig *RunConfig
ActiveRunID string
Compare *CompareViewState
Hosts []TrainHostView
ActiveRunID string
Compare *CompareViewState
Hosts []TrainHostView

Panels map[TrainPanelID]*PanelDisplayState
Focus TrainPanelID
Expand Down Expand Up @@ -428,6 +428,11 @@ func NewTrainWorkspaceState() *TrainWorkspaceState {
Collapsed: true,
Maximized: false,
},
TrainPanelAgent: {
Focused: false,
Collapsed: false,
Maximized: false,
},
TrainPanelCompare: {
Focused: false,
Collapsed: true,
Expand Down Expand Up @@ -808,12 +813,18 @@ func (s *TrainWorkspaceState) RefreshActions() {
s.GlobalActions.Items = []TrainAction{
{ID: "stop", Label: "stop", Enabled: true, Primary: true},
}
case TrainPhaseStopped:
s.GlobalActions.Items = []TrainAction{
{ID: "rerun", Label: "rerun", Enabled: true, Primary: true},
{ID: "close", Label: "exit", Enabled: true},
}
case TrainPhaseCompleted:
items := []TrainAction{
{ID: "rerun", Label: "rerun", Enabled: true, Primary: true},
{ID: "analyze_perf", Label: "analyze perf", Enabled: true},
{ID: "add_algo_feature", Label: "algo-feature", Enabled: true},
{ID: "add_perf_feature", Label: "perf-feature", Enabled: true},
{ID: "close", Label: "exit", Enabled: true},
}
s.GlobalActions.Items = items
default:
Expand Down
4 changes: 4 additions & 0 deletions ui/panels/train_workspace.go
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,10 @@ func panelBody(id model.TrainPanelID, tv model.TrainWorkspaceState, width, heigh
return renderMetricsPanel(tv, width, height)
case model.TrainPanelLogs:
return renderLogsPanel(tv, width, height)
case model.TrainPanelAgent:
// Agent panel body is rendered separately via RenderAgentBox with viewport content.
// When maximized, we show a placeholder indicating content is displayed via viewport.
return ""
default:
return ""
}
Expand Down
6 changes: 5 additions & 1 deletion workflow/train/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,12 @@ func (c *Controller) Open(_ context.Context, req itrain.Request) *Session {
Phase: "setup",
}
// Read demo failure injection config.
if failStep, ok := req.Target.Config["demo_fail_at_step"].(int); ok && failStep > 0 {
// -1 = disabled, >= 0 = fail at specific step (0 = fail immediately)
if failStep, ok := req.Target.Config["demo_fail_at_step"].(int); ok {
s.FailAtStep = failStep
} else {
// Default to -1 (disabled) if not set.
s.FailAtStep = -1
}
if fixed, ok := req.Target.Config["demo_drift_fixed"].(bool); ok && fixed {
s.DriftFixed = true
Expand Down
6 changes: 3 additions & 3 deletions workflow/train/demo.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func (d *DemoBackend) Run(ctx context.Context, session *Session, sink func(Event
}

// runTraining plays demo training: loading, log lines, metrics, completion.
// If failStep > 0, training crashes at that step with a DSA operator error.
// If failStep >= 0, training crashes at that step with a DSA operator error.
// If perfFixed is true, throughput is ~10% higher (fused adam kernel).
// If trickApplied is true, loss converges faster (MHC contrastive loss).
func (d *DemoBackend) runTraining(ctx context.Context, model, method string, sink func(Event), failStep int, perfFixed, trickApplied bool) error {
Expand Down Expand Up @@ -168,8 +168,8 @@ func (d *DemoBackend) runTraining(ctx context.Context, model, method string, sin
totalSteps := 300
for _, s := range steps {
// Failure injection: crash at the specified step.
if failStep > 0 && s.step >= failStep {
crashMsg := fmt.Sprintf("[%s] FATAL: operator init failed at step %d — DSA operator (torch.ops.npu.dsa) not implemented in torch 2.7 for Ascend backend", runID, s.step)
if failStep >= 0 && s.step >= failStep {
crashMsg := fmt.Sprintf("[%s] FATAL: operator init failed at step %d — DSA operator (torch.ops.npu.dsa) not implemented in torch 2.7 for Ascend backend", runID, failStep)
e(Event{Kind: EventLogLine, Message: crashMsg, DelayMs: 200})
return fmt.Errorf("operator init failed: DSA operator not implemented in torch 2.7")
}
Expand Down
2 changes: 1 addition & 1 deletion workflow/train/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ type Session struct {
Phase string // setup, ready, running, completed, failed, stopped

// FailAtStep causes the demo backend to crash training at this step.
// 0 means no failure injection.
// -1 means no failure injection.
FailAtStep int

// DriftFixed indicates the accuracy fix has been applied,
Expand Down