Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions manifests/base/crds/kubeflow.org_pytorchjobs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7324,6 +7324,9 @@ spec:
format: int32
type: integer
type: object
successPolicy:
description: SuccessPolicy defines when the PyTorchJob is marked succeeded.
type: string
required:
- pytorchReplicaSpecs
type: object
Expand Down
9 changes: 9 additions & 0 deletions pkg/apis/pytorch/v1/common.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package v1

// SuccessPolicy decides how worker completion is translated into overall
// job success for a PyTorchJob.
type SuccessPolicy string

const (
	// SuccessPolicyDefault is the empty policy: the job is marked
	// succeeded once worker 0 has completed.
	SuccessPolicyDefault SuccessPolicy = ""
	// SuccessPolicyAllWorkers marks the job succeeded only when every
	// worker has succeeded.
	SuccessPolicyAllWorkers SuccessPolicy = "AllWorkers"
)
12 changes: 10 additions & 2 deletions pkg/apis/pytorch/v1/defaults.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,13 +106,21 @@ func setTypeNameToCamelCase(job *PyTorchJob, typ common.ReplicaType) {
}
}

// SetDefaults_PyTorchJob sets any unspecified values to defaults.
func SetDefaults_PyTorchJob(job *PyTorchJob) {
func SetDefaultRunPolicy(job *PyTorchJob) {
// Set default cleanpod policy to None.
if job.Spec.RunPolicy.CleanPodPolicy == nil {
policy := common.CleanPodPolicyNone
job.Spec.RunPolicy.CleanPodPolicy = &policy
}
if job.Spec.SuccessPolicy == nil {
policy := SuccessPolicyDefault
job.Spec.SuccessPolicy = &policy
}
}

// SetDefaults_PyTorchJob sets any unspecified values to defaults.
func SetDefaults_PyTorchJob(job *PyTorchJob) {
SetDefaultRunPolicy(job)

// Update the key of PyTorchReplicaSpecs to camel case.
setTypeNamesToCamelCase(job)
Expand Down
8 changes: 7 additions & 1 deletion pkg/apis/pytorch/v1/openapi_generated.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions pkg/apis/pytorch/v1/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ type PyTorchJobSpec struct {
//+kubebuilder:validation:Optional
RunPolicy common.RunPolicy `json:"runPolicy"`

SuccessPolicy *SuccessPolicy `json:"successPolicy,omitempty"`

ElasticPolicy *ElasticPolicy `json:"elasticPolicy,omitempty"`

// A map of PyTorchReplicaType (type) to ReplicaSpec (value). Specifies the PyTorch cluster configuration.
Expand Down
5 changes: 5 additions & 0 deletions pkg/apis/pytorch/v1/zz_generated.deepcopy.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 12 additions & 0 deletions pkg/common/util/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,15 @@ func GetSchedulerName(replicas map[commonv1.ReplicaType]*commonv1.ReplicaSpec) s
}
return ""
}

// GetContainerExitCode gets the container exit code from the given pod.
func GetContainerExitCode(pod *corev1.Pod, name string) int32 {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

name refers the containerName?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes.

var exitCode int32 = 0xbeef // magic number
for _, status := range pod.Status.ContainerStatuses {
state := status.State
if status.Name == name && state.Terminated != nil {
exitCode = state.Terminated.ExitCode
}
}
return exitCode
}
70 changes: 64 additions & 6 deletions pkg/controller.v1/pytorch/pytorchjob_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package pytorch
import (
"context"
"fmt"
"strings"

"github.com/go-logr/logr"
commonv1 "github.com/kubeflow/common/pkg/apis/common/v1"
Expand Down Expand Up @@ -326,6 +327,8 @@ func (r *PyTorchJobReconciler) UpdateJobStatus(job interface{},
return fmt.Errorf("%+v is not a type of PyTorchJob", job)
}

logger := commonutil.LoggerForJob(pytorchjob)

for rtype, spec := range replicas {
status := jobStatus.ReplicaStatuses[rtype]
if status.LabelSelector == nil {
Expand All @@ -338,7 +341,7 @@ func (r *PyTorchJobReconciler) UpdateJobStatus(job interface{},
running := status.Active
failed := status.Failed

logrus.Infof("PyTorchJob=%s, ReplicaType=%s expected=%d, running=%d, succeeded=%d , failed=%d",
logger.Infof("PyTorchJob=%s, ReplicaType=%s expected=%d, running=%d, succeeded=%d , failed=%d",
pytorchjob.Name, rtype, expected, running, succeeded, failed)

if ContainsMasterSpec(replicas) {
Expand All @@ -347,32 +350,40 @@ func (r *PyTorchJobReconciler) UpdateJobStatus(job interface{},
msg := fmt.Sprintf("PyTorchJob %s is running.", pytorchjob.Name)
err := commonutil.UpdateJobConditions(jobStatus, commonv1.JobRunning, commonutil.JobRunningReason, msg)
if err != nil {
commonutil.LoggerForJob(pytorchjob).Infof("Append job condition error: %v", err)
logger.Infof("Append job condition error: %v", err)
return err
}
}
// when master is succeed, the job is finished.
if expected == 0 {
msg := fmt.Sprintf("PyTorchJob %s is successfully completed.", pytorchjob.Name)
logrus.Info(msg)
logger.Info(msg)
r.Recorder.Event(pytorchjob, corev1.EventTypeNormal, commonutil.JobSucceededReason, msg)
if jobStatus.CompletionTime == nil {
now := metav1.Now()
jobStatus.CompletionTime = &now
}
err := commonutil.UpdateJobConditions(jobStatus, commonv1.JobSucceeded, commonutil.JobSucceededReason, msg)
if err != nil {
commonutil.LoggerForJob(pytorchjob).Infof("Append job condition error: %v", err)
logger.Infof("Append job condition error: %v", err)
return err
}
trainingoperatorcommon.SuccessfulJobsCounterInc(pytorchjob.Namespace, pytorchv1.FrameworkName)
return nil
}
}
} else {

if rtype == pytorchv1.PyTorchReplicaTypeWorker {
// TODO(gaocegege): Support SuccessPolicy
if expected == 0 {
worker0Completed, err := r.IsWorker0Completed(pytorchjob, replicas)
if err != nil {
logger.Warnf("check if worker 0 completed error %v", err)
return err
}
// Leave a succeeded condition for the following two cases:
// 1. If default success policy is used and worker 0 has completed.
// 2. If `SuccessPolicyAllWorkers` success policy is used and all workers are succeeded.
if expected == 0 || (worker0Completed && *pytorchjob.Spec.SuccessPolicy != pytorchv1.SuccessPolicyAllWorkers) {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

L387 might take a bit of time to understand. Maybe it's because the order of the two conditions are not arranged as the comment lists.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comments and conditions are copied from TF. I will refine it soon.

msg := fmt.Sprintf("TFJob %s/%s successfully completed.",
pytorchjob.Namespace, pytorchjob.Name)
r.recorder.Event(pytorchjob, corev1.EventTypeNormal, commonutil.JobSucceededReason, msg)
Expand Down Expand Up @@ -430,6 +441,53 @@ func (r *PyTorchJobReconciler) UpdateJobStatus(job interface{},
return nil
}

// IsWorker0Completed returns true if pod of worker0 succeeded and exited with 0
func (p *PyTorchJobReconciler) IsWorker0Completed(job *pytorchv1.PyTorchJob,
replicas map[commonv1.ReplicaType]*commonv1.ReplicaSpec) (bool, error) {
worker0Completed := false
_, ok := replicas[pytorchv1.PyTorchReplicaTypeWorker]
if !ok {
return true, nil
}
podSlices, err := p.getPodSlices(job,
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I understand the slice approach definitely works. But what is the point to create a the getPodSlices method while all we need is just the worker pod with index=0? And what kind of disadvantages will it have to use Get function or List with label selector to catch worker0 pod?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is to keep consistency with TF controller. The code is copied from it.

replicas[pytorchv1.PyTorchReplicaTypeWorker].Replicas)
if err != nil {
return false, err
}
for index, podSlice := range podSlices {
if len(podSlice) == 1 {
pod := podSlice[0]
exitCode := util.GetContainerExitCode(pod, pytorchv1.DefaultContainerName)
if index == 0 && exitCode == 0 && pod.Status.Phase == corev1.PodSucceeded {
worker0Completed = true
}
}
}
return worker0Completed, nil
}

// getPodSlices returns the worker pods of the job grouped by replica
// index: element i of the returned slice holds the pods created for
// worker i.
func (r *PyTorchJobReconciler) getPodSlices(
	job *pytorchv1.PyTorchJob, replicasNum *int32) ([][]*corev1.Pod, error) {
	logger := commonutil.LoggerForReplica(job, strings.ToLower(string(pytorchv1.PyTorchReplicaTypeWorker)))

	pods, err := r.GetPodsForJob(job)
	if err != nil {
		logger.Warnf("GetPodsForJob error %v", err)
		return nil, err
	}

	// Keep only the pods that belong to the worker replica type.
	pods, err = r.JobController.FilterPodsForReplicaType(pods, strings.ToLower(string(pytorchv1.PyTorchReplicaTypeWorker)))
	if err != nil {
		return nil, err
	}

	return r.GetPodSlices(pods, int(*replicasNum), logger), nil
}

// ContainsMasterSpec returns true if the tfjob contains master spec.
func ContainsMasterSpec(replicas map[commonv1.ReplicaType]*commonv1.ReplicaSpec) bool {
if _, ok := replicas[pytorchv1.PyTorchReplicaTypeMaster]; ok {
Expand Down
12 changes: 10 additions & 2 deletions pkg/controller.v1/pytorch/pytorchjob_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ var _ = Describe("PyTorchJob controller", func() {
}
job.Spec.PyTorchReplicaSpecs = map[commonv1.ReplicaType]*commonv1.ReplicaSpec{
pytorchv1.PyTorchReplicaTypeWorker: {
Replicas: int32Ptr(1),
Replicas: int32Ptr(2),
Template: corev1.PodTemplateSpec{
Spec: corev1.PodSpec{
Containers: []corev1.Container{
Expand Down Expand Up @@ -275,8 +275,15 @@ var _ = Describe("PyTorchJob controller", func() {
BlockOwnerDeletion: &trueVal,
}))

// Test job status.
// Set the worker 0 succeeded.
pod.Status.Phase = corev1.PodSucceeded
pod.Status.ContainerStatuses = make([]corev1.ContainerStatus, 1)
pod.Status.ContainerStatuses[0].Name = pytorchv1.DefaultContainerName
pod.Status.ContainerStatuses[0].State = corev1.ContainerState{
Terminated: &corev1.ContainerStateTerminated{
ExitCode: 0,
},
}
pod.ResourceVersion = ""
Expect(testK8sClient.Status().Update(ctx, pod)).Should(Succeed())
Eventually(func() bool {
Expand All @@ -289,6 +296,7 @@ var _ = Describe("PyTorchJob controller", func() {
}, timeout, interval).Should(BeTrue())
// Check if the job is succeeded.
cond := getCondition(created.Status, commonv1.JobSucceeded)
Expect(cond).NotTo(BeNil())
Expect(cond.Status).To(Equal(corev1.ConditionTrue))
By("Deleting the PyTorchJob")
Expect(testK8sClient.Delete(ctx, job)).Should(Succeed())
Expand Down
14 changes: 7 additions & 7 deletions pkg/controller.v1/tensorflow/tfjob_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -394,12 +394,6 @@ func (r *TFJobReconciler) UpdateJobStatus(job interface{}, replicas map[commonv1

logger := commonutil.LoggerForJob(tfJob)

worker0Completed, err := r.IsWorker0Completed(tfJob, replicas)
if err != nil {
logger.Warnf("check if worker 0 completed error %v", err)
return err
}

// Set StartTime.
if jobStatus.StartTime == nil {
now := metav1.Now()
Expand Down Expand Up @@ -469,6 +463,12 @@ func (r *TFJobReconciler) UpdateJobStatus(job interface{}, replicas map[commonv1
}
} else {
if rtype == tensorflowv1.TFReplicaTypeWorker {
worker0Completed, err := r.IsWorker0Completed(tfJob, replicas)
if err != nil {
logger.Warnf("check if worker 0 completed error %v", err)
return err
}

// Leave a succeeded condition for the following two cases:
// 1. If default success policy is used and worker 0 has completed.
// 2. If `SuccessPolicyAllWorkers` success policy is used and all workers are succeeded.
Expand Down Expand Up @@ -640,7 +640,7 @@ func (r *TFJobReconciler) IsWorker0Completed(tfjob *tensorflowv1.TFJob, replicas
for index, podSlice := range podSlices {
if len(podSlice) == 1 {
pod := podSlice[0]
exitCode := getContainerExitCode(pod)
exitCode := util.GetContainerExitCode(pod, tfv1.DefaultContainerName)
if index == 0 && exitCode == 0 && pod.Status.Phase == v1.PodSucceeded {
worker0Completed = true
}
Expand Down
12 changes: 0 additions & 12 deletions pkg/controller.v1/tensorflow/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,18 +46,6 @@ func ContainsChiefOrMasterSpec(replicas map[commonv1.ReplicaType]*commonv1.Repli
return false
}

// originally from pkg/controller.v1/tensorflow/pod.go (deleted)
func getContainerExitCode(pod *corev1.Pod) int32 {
var exitCode int32 = 0xbeef // magic number
for _, status := range pod.Status.ContainerStatuses {
state := status.State
if status.Name == tfv1.DefaultContainerName && state.Terminated != nil {
exitCode = state.Terminated.ExitCode
}
}
return exitCode
}

// originally from pkg/controller.v1/tensorflow/pod.go (deleted)
func setRestartPolicy(podTemplateSpec *corev1.PodTemplateSpec, spec *commonv1.ReplicaSpec) {
// This is necessary since restartPolicyExitCode is not supported in v1.PodTemplateSpec
Expand Down