Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pkg/constants/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ const (
// LabelSupport indicates support status for a runtime, e.g. "deprecated".
LabelSupport string = "trainer.kubeflow.org/support"

// LabelJobName is the label to identify job-owned ConfigMap and Secret resources.
LabelJobName string = "trainer.kubeflow.org/trainjob-name"

// SupportDeprecated indicates the runtime is deprecated when used with LabelSupport.
SupportDeprecated string = "deprecated"

Expand Down
2 changes: 2 additions & 0 deletions pkg/runtime/core/trainingruntime_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1950,6 +1950,7 @@ func TestTrainingRuntimeNewObjects(t *testing.T) {
Obj(),
wantObjs: []runtime.Object{
testingutil.MakeConfigMapWrapper(fmt.Sprintf("test-job%s", constants.MPIHostfileConfigMapSuffix), metav1.NamespaceDefault).
WithLabels(map[string]string{constants.LabelJobName: "test-job"}).
ControllerReference(trainer.SchemeGroupVersion.WithKind(trainer.TrainJobKind), "test-job", "uid").
WithData(map[string]string{
constants.MPIHostfileName: `test-job-node-0-0.test-job slots=8
Expand All @@ -1958,6 +1959,7 @@ test-job-node-0-1.test-job slots=8
}).
Obj(),
testingutil.MakeSecretWrapper(fmt.Sprintf("test-job%s", constants.MPISSHAuthSecretSuffix), metav1.NamespaceDefault).
WithLabels(map[string]string{constants.LabelJobName: "test-job"}).
ControllerReference(trainer.SchemeGroupVersion.WithKind(trainer.TrainJobKind), "test-job", "uid").
WithImmutable(true).
WithData(map[string][]byte{
Expand Down
2 changes: 2 additions & 0 deletions pkg/runtime/framework/core/framework_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1297,6 +1297,7 @@ func TestRunComponentBuilderPlugins(t *testing.T) {
}).
Obj(),
testingutil.MakeSecretWrapper(fmt.Sprintf("test-job%s", constants.MPISSHAuthSecretSuffix), metav1.NamespaceDefault).
WithLabels(map[string]string{constants.LabelJobName: "test-job"}).
WithImmutable(true).
WithType(corev1.SecretTypeSSHAuth).
WithData(map[string][]byte{
Expand All @@ -1306,6 +1307,7 @@ func TestRunComponentBuilderPlugins(t *testing.T) {
ControllerReference(trainer.SchemeGroupVersion.WithKind(trainer.TrainJobKind), "test-job", "uid").
Obj(),
testingutil.MakeConfigMapWrapper(fmt.Sprintf("test-job%s", constants.MPIHostfileConfigMapSuffix), metav1.NamespaceDefault).
WithLabels(map[string]string{constants.LabelJobName: "test-job"}).
WithData(map[string]string{
constants.MPIHostfileName: `test-job-launcher-0-0.test-job slots=1
test-job-node-0-0.test-job slots=1
Expand Down
24 changes: 21 additions & 3 deletions pkg/runtime/framework/plugins/mpi/mpi.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (

"golang.org/x/crypto/ssh"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
apiruntime "k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/util/validation/field"
corev1ac "k8s.io/client-go/applyconfigurations/core/v1"
Expand All @@ -38,6 +39,7 @@ import (
"sigs.k8s.io/controller-runtime/pkg/cache"
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/handler"
"sigs.k8s.io/controller-runtime/pkg/predicate"
"sigs.k8s.io/controller-runtime/pkg/webhook/admission"

configapi "github.com/kubeflow/trainer/v2/pkg/apis/config/v1alpha1"
Expand Down Expand Up @@ -218,19 +220,27 @@ func (m *MPI) EnforceMLPolicy(info *runtime.Info, trainJob *trainer.TrainJob) er
func (m *MPI) ReconcilerBuilders() []runtime.ReconcilerBuilder {
return []runtime.ReconcilerBuilder{
func(b *builder.Builder, cl client.Client, cache cache.Cache) *builder.Builder {
return b.Watches(
return b.WatchesMetadata(
&corev1.ConfigMap{},
handler.EnqueueRequestForOwner(
m.client.Scheme(), m.client.RESTMapper(), &trainer.TrainJob{}, handler.OnlyControllerOwner(),
),
builder.WithPredicates(predicate.NewPredicateFuncs(func(obj client.Object) bool {
_, ok := obj.GetLabels()[constants.LabelJobName]
return ok
})),
Comment on lines +228 to +231
Copy link

Copilot AI Mar 29, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new label-based predicate will drop events for existing MPI-owned Secrets/ConfigMaps created before this change (they have owner refs but won’t have the new label), so deletions/updates of those objects may no longer trigger TrainJob reconciliation after an upgrade.

Copilot uses AI. Check for mistakes.
)
Comment on lines +223 to 232
Copy link

Copilot AI Mar 29, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

builder.WithPredicates only filters which events enqueue reconciles; it does not apply a label selector to the underlying informer list/watch, so the cache will still store metadata for all Secrets/ConfigMaps cluster-wide (just without Data). If the intent is to scope the cache itself to MPI-owned objects, this needs a cache-level selector (e.g., manager cache ByObject/label selector) rather than a predicate alone.

Copilot uses AI. Check for mistakes.
},
func(b *builder.Builder, cl client.Client, cache cache.Cache) *builder.Builder {
return b.Watches(
return b.WatchesMetadata(
&corev1.Secret{},
handler.EnqueueRequestForOwner(
m.client.Scheme(), m.client.RESTMapper(), &trainer.TrainJob{}, handler.OnlyControllerOwner(),
),
builder.WithPredicates(predicate.NewPredicateFuncs(func(obj client.Object) bool {
_, ok := obj.GetLabels()[constants.LabelJobName]
return ok
})),
)
},
}
Expand All @@ -244,7 +254,9 @@ func (m *MPI) Build(ctx context.Context, info *runtime.Info, trainJob *trainer.T
var objects []apiruntime.ApplyConfiguration

// SSHAuthSecret is immutable.
if err := m.client.Get(ctx, client.ObjectKey{Name: sshAuthSecretName(trainJob.Name), Namespace: trainJob.Namespace}, &corev1.Secret{}); err != nil {
partialSecret := &metav1.PartialObjectMetadata{}
partialSecret.SetGroupVersionKind(corev1.SchemeGroupVersion.WithKind("Secret"))
if err := m.client.Get(ctx, client.ObjectKey{Name: sshAuthSecretName(trainJob.Name), Namespace: trainJob.Namespace}, partialSecret); err != nil {
if client.IgnoreNotFound(err) != nil {
return nil, err
}
Expand Down Expand Up @@ -275,6 +287,9 @@ func (m *MPI) buildSSHAuthSecret(trainJob *trainer.TrainJob) (*corev1ac.SecretAp
return nil, err
}
return corev1ac.Secret(sshAuthSecretName(trainJob.Name), trainJob.Namespace).
WithLabels(map[string]string{
constants.LabelJobName: trainJob.Name,
}).
WithType(corev1.SecretTypeSSHAuth).
WithData(map[string][]byte{
corev1.SSHAuthPrivateKey: privatePEM,
Expand Down Expand Up @@ -310,6 +325,9 @@ func (m *MPI) buildHostFileConfigMap(info *runtime.Info, trainJob *trainer.Train
}
}
return corev1ac.ConfigMap(fmt.Sprintf("%s%s", trainJob.Name, constants.MPIHostfileConfigMapSuffix), trainJob.Namespace).
WithLabels(map[string]string{
constants.LabelJobName: trainJob.Name,
}).
WithData(map[string]string{
constants.MPIHostfileName: hostFile.String(),
}).
Expand Down
12 changes: 11 additions & 1 deletion pkg/runtime/framework/plugins/mpi/mpi_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ func TestMPI(t *testing.T) {
},
wantObjs: []apiruntime.Object{
utiltesting.MakeSecretWrapper(fmt.Sprintf("trainJob%s", constants.MPISSHAuthSecretSuffix), metav1.NamespaceDefault).
WithLabels(map[string]string{constants.LabelJobName: "trainJob"}).
WithImmutable(true).
WithType(corev1.SecretTypeSSHAuth).
WithData(map[string][]byte{
Expand All @@ -212,6 +213,7 @@ func TestMPI(t *testing.T) {
ControllerReference(trainer.SchemeGroupVersion.WithKind(trainer.TrainJobKind), "trainJob", "trainJob").
Obj(),
utiltesting.MakeConfigMapWrapper(fmt.Sprintf("trainJob%s", constants.MPIHostfileConfigMapSuffix), metav1.NamespaceDefault).
WithLabels(map[string]string{constants.LabelJobName: "trainJob"}).
WithData(map[string]string{
constants.MPIHostfileName: `trainJob-node-1-0.trainJob slots=1
trainJob-node-1-1.trainJob slots=1
Expand Down Expand Up @@ -339,6 +341,7 @@ trainJob-node-1-1.trainJob slots=1
},
wantObjs: []apiruntime.Object{
utiltesting.MakeSecretWrapper(fmt.Sprintf("trainJob%s", constants.MPISSHAuthSecretSuffix), metav1.NamespaceDefault).
WithLabels(map[string]string{constants.LabelJobName: "trainJob"}).
WithImmutable(true).
WithType(corev1.SecretTypeSSHAuth).
WithData(map[string][]byte{
Expand All @@ -348,6 +351,7 @@ trainJob-node-1-1.trainJob slots=1
ControllerReference(trainer.SchemeGroupVersion.WithKind(trainer.TrainJobKind), "trainJob", "trainJob").
Obj(),
utiltesting.MakeConfigMapWrapper(fmt.Sprintf("trainJob%s", constants.MPIHostfileConfigMapSuffix), metav1.NamespaceDefault).
WithLabels(map[string]string{constants.LabelJobName: "trainJob"}).
WithData(map[string]string{
constants.MPIHostfileName: `trainJob-node-1-0.trainJob slots=2
`,
Expand Down Expand Up @@ -476,6 +480,7 @@ trainJob-node-1-1.trainJob slots=1
},
wantObjs: []apiruntime.Object{
utiltesting.MakeSecretWrapper(fmt.Sprintf("trainJob%s", constants.MPISSHAuthSecretSuffix), metav1.NamespaceDefault).
WithLabels(map[string]string{constants.LabelJobName: "trainJob"}).
WithImmutable(true).
WithType(corev1.SecretTypeSSHAuth).
WithData(map[string][]byte{
Expand All @@ -485,6 +490,7 @@ trainJob-node-1-1.trainJob slots=1
ControllerReference(trainer.SchemeGroupVersion.WithKind(trainer.TrainJobKind), "trainJob", "trainJob").
Obj(),
utiltesting.MakeConfigMapWrapper(fmt.Sprintf("trainJob%s", constants.MPIHostfileConfigMapSuffix), metav1.NamespaceDefault).
WithLabels(map[string]string{constants.LabelJobName: "trainJob"}).
WithData(map[string]string{
constants.MPIHostfileName: `trainJob-node-1-0.trainJob slots=5
`,
Expand Down Expand Up @@ -647,6 +653,7 @@ trainJob-node-1-1.trainJob slots=1
},
wantObjs: []apiruntime.Object{
utiltesting.MakeSecretWrapper(fmt.Sprintf("trainJob%s", constants.MPISSHAuthSecretSuffix), metav1.NamespaceDefault).
WithLabels(map[string]string{constants.LabelJobName: "trainJob"}).
WithImmutable(true).
WithType(corev1.SecretTypeSSHAuth).
WithData(map[string][]byte{
Expand All @@ -656,6 +663,7 @@ trainJob-node-1-1.trainJob slots=1
ControllerReference(trainer.SchemeGroupVersion.WithKind(trainer.TrainJobKind), "trainJob", "trainJob").
Obj(),
utiltesting.MakeConfigMapWrapper(fmt.Sprintf("trainJob%s", constants.MPIHostfileConfigMapSuffix), metav1.NamespaceDefault).
WithLabels(map[string]string{constants.LabelJobName: "trainJob"}).
WithData(map[string]string{
constants.MPIHostfileName: `trainJob-launcher-0-0.trainJob slots=1
trainJob-node-1-0.trainJob slots=1
Expand All @@ -668,6 +676,7 @@ trainJob-node-1-0.trainJob slots=1
"sshAuth secret already has existed in the cluster": {
objs: []client.Object{
utiltesting.MakeSecretWrapper(sshAuthSecretName("trainJob"), metav1.NamespaceDefault).
WithLabels(map[string]string{constants.LabelJobName: "trainJob"}).
ControllerReference(trainer.SchemeGroupVersion.WithKind(trainer.TrainJobKind), "trainJob", "trainJob").
WithImmutable(true).
Obj(),
Expand Down Expand Up @@ -746,6 +755,7 @@ trainJob-node-1-0.trainJob slots=1
},
wantObjs: []apiruntime.Object{
utiltesting.MakeConfigMapWrapper(fmt.Sprintf("trainJob%s", constants.MPIHostfileConfigMapSuffix), metav1.NamespaceDefault).
WithLabels(map[string]string{constants.LabelJobName: "trainJob"}).
WithData(map[string]string{
constants.MPIHostfileName: `trainJob-launcher-0-0.trainJob slots=1
`,
Expand Down Expand Up @@ -839,7 +849,7 @@ trainJob-node-1-0.trainJob slots=1
b := utiltesting.NewClientBuilder().WithObjects(tc.objs...)
b.WithInterceptorFuncs(interceptor.Funcs{
Get: func(ctx context.Context, client client.WithWatch, key client.ObjectKey, obj client.Object, opts ...client.GetOption) error {
if _, ok := obj.(*corev1.Secret); ok && errors.Is(tc.wantBuildError, errorGetSSHAuthSecretFromAPI) {
if _, ok := obj.(*metav1.PartialObjectMetadata); ok && errors.Is(tc.wantBuildError, errorGetSSHAuthSecretFromAPI) {
return errorGetSSHAuthSecretFromAPI
}
return client.Get(ctx, key, obj, opts...)
Expand Down
21 changes: 21 additions & 0 deletions pkg/util/testing/wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -1425,6 +1425,16 @@ func (c *ConfigMapWrapper) WithData(data map[string]string) *ConfigMapWrapper {
return c
}

// WithLabels merges the given labels into the wrapped ConfigMap's labels and
// returns the wrapper so calls can be chained. Keys already present on the
// ConfigMap are overwritten by the values supplied here; keys not mentioned
// in labels are left untouched.
func (c *ConfigMapWrapper) WithLabels(labels map[string]string) *ConfigMapWrapper {
	merged := c.Labels
	if merged == nil {
		// Lazily allocate the label map; writing to a nil map would panic.
		merged = make(map[string]string, len(labels))
		c.Labels = merged
	}
	for key, value := range labels {
		merged[key] = value
	}
	return c
}

func (c *ConfigMapWrapper) ControllerReference(gvk schema.GroupVersionKind, name, uid string) *ConfigMapWrapper {
c.OwnerReferences = append(c.OwnerReferences, metav1.OwnerReference{
APIVersion: gvk.GroupVersion().String(),
Expand Down Expand Up @@ -1475,6 +1485,17 @@ func (s *SecretWrapper) WithData(data map[string][]byte) *SecretWrapper {
return s
}


// WithLabels merges the given labels into the wrapped Secret's labels and
// returns the wrapper so calls can be chained. Keys already present on the
// Secret are overwritten by the values supplied here; keys not mentioned in
// labels are left untouched.
func (s *SecretWrapper) WithLabels(labels map[string]string) *SecretWrapper {
	merged := s.Labels
	if merged == nil {
		// Lazily allocate the label map; writing to a nil map would panic.
		merged = make(map[string]string, len(labels))
		s.Labels = merged
	}
	for key, value := range labels {
		merged[key] = value
	}
	return s
}

func (s *SecretWrapper) WithImmutable(immutable bool) *SecretWrapper {
s.Immutable = &immutable
return s
Expand Down
Loading