diff --git a/AGENTS.md b/AGENTS.md index 9ddd3b3..503e902 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -4,57 +4,42 @@ This file provides guidance to AI coding agents when working with code in this r ## Project Overview -Kubernetes node agent for Civo cloud that monitors cluster nodes and triggers automatic hard reboots via the Civo API when nodes become NotReady or lose expected GPU capacity. Deployed as a single-replica Deployment in kube-system via Helm. +`civo-node-agent` monitors Kubernetes nodes in a Civo cluster and triggers automatic recovery actions (currently hard reboot via the Civo API) when nodes fail health checks. + +### Deployment + +`civo-node-agent` is designed to run as a daemon process on the control plane VM, which is the preferred deployment. A Helm chart (`charts/`) is also provided so it can run as a single-replica Deployment in `kube-system` if needed. + +By default the agent runs in **monitor-only mode** (logs recovery actions without executing them). Set `CIVO_NODE_AGENT_MONITOR_ONLY=false` to enable actual reboots. ## Build & Test Commands ```bash -# Build -go build -o node-agent ./ +# Build (CGO disabled — no C dependencies; required for static binary on the CP VM) +CGO_ENABLED=0 go build -o node-agent ./ # Run all tests go test ./... -# Run a single test -go test ./pkg/watcher/ -run TestName - -# Build Docker image (dry-run) -goreleaser release --snapshot --skip=publish --clean +# Before completing any task, always run: +go fmt ./... +go vet ./... +go test ./... ``` No linter is configured in CI. ## Architecture -**Entrypoint** (`main.go`): Reads env vars, sets up JSON structured logging (slog), creates a Watcher, and runs it with graceful SIGTERM/SIGINT shutdown. - -**Core package** (`pkg/watcher/`): -- `watcher.go` — Main loop polls every 10 seconds. For each node matching the node pool label (`kubernetes.civo.com/civo-node-pool={nodePoolID}`), checks if the node is NotReady or has fewer GPUs than desired. If a reboot is warranted (and cooldown window hasn't elapsed), calls `HardRebootInstance` via the Civo API. -- `options.go` — Functional options pattern (`WithKubernetesClient`, `WithCivoClient`, etc.) for dependency injection and configuration. -- `fake.go` — `FakeClient` implementing `civogo.Clienter` for testing. -- `watcher_test.go` — Tests use fake Kubernetes client (`k8s.io/client-go/kubernetes/fake`) and `FakeClient` for Civo API. - -**Reboot safeguards**: Tracks last reboot time per node in a `sync.Map`. Skips reboot if the node's Ready/NotReady condition transitioned recently or a reboot command was sent within the configurable time window (default 40 minutes). - -## Required Environment Variables - -`CIVO_API_KEY`, `CIVO_REGION`, `CIVO_CLUSTER_ID`, `CIVO_NODE_POOL_ID` — see `.env.example`. - -Optional: `CIVO_API_URL`, `CIVO_NODE_DESIRED_GPU_COUNT`, `CIVO_NODE_REBOOT_TIME_WINDOW_MINUTES`. - -## Deployment - -Helm chart in `charts/`. Secrets are expected in `civo-node-agent` and `civo-api-access` Kubernetes secrets. - -```bash -helm upgrade -n kube-system --install node-agent ./charts -``` +**Entrypoint** (`main.go`): Reads env vars + `--kubeconfig` flag, sets up JSON structured logging (slog), registers Prometheus metrics, starts the metrics HTTP server, constructs an Executor + Checkers + Watcher, and runs the watcher with graceful SIGTERM/SIGINT shutdown. -## Key Dependencies +### Packages -- `github.com/civo/civogo` — Civo cloud API client -- `k8s.io/client-go` — Kubernetes client (in-cluster config by default) +- **`pkg/watcher/`** — Orchestrator. Sets up a Node Informer (filtered by optional node pool label selector) and runs a 10s ticker reconcile loop driving the state machine (`Unknown → Healthy → Unhealthy → WaitingReboot → Failed`). +- **`pkg/health/`** — Health checkers. +- **`pkg/operation/`** — Recovery executors (Civo API reboot; nop executor used as safe default). +- **`pkg/metrics/`** — Prometheus metrics (all `civo_` prefixed). Defined once in `metrics.go`. ## Release -Tags matching `v*.*.*` trigger `.github/workflows/release-image.yaml`, which builds multi-arch Docker images via goreleaser and publishes to Docker Hub. +Tags matching `v*.*.*` trigger `.github/workflows/release-image.yaml`, which builds multi-arch Docker images via goreleaser and publishes to Docker Hub. The same binary is also uploaded to Civo object storage for CP VM installations (handled outside this repository). diff --git a/README.md b/README.md index fb740b7..ff50953 100644 --- a/README.md +++ b/README.md @@ -1,64 +1,60 @@ # Node Agent -`node-agent` monitors the health of Kubernetes nodes and can automatically restart VM instances when necessary. It triggers a restart under the following conditions: +`node-agent` monitors the health of Kubernetes nodes and can automatically reboot VM instances when necessary. A reboot is triggered when a node fails one or more health checks (e.g. `NodeReady`, GPU count, Cilium, DiskPressure) for a configured threshold. -- A node enters the **NotReady** state. -- The number of available GPUs per node falls below a configured threshold. +By default it runs in **monitor-only** mode, logging recovery actions without executing them. Set `monitorOnly=false` to enable actual reboots. +## Prerequisites: `civo-api-access` Secret -## Set Your `civo-node-agent` Secret +The `civo-api-access` secret is automatically provisioned by Civo in the `kube-system` namespace of every Civo Kubernetes cluster. It contains the API credentials and cluster identity used by `node-agent`: -``` -export CIVO_DESIRED_GPU_COUNT="8" -export CIVO_NODE_POOL_ID="xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxxxxxx" -export CIVO_API_KEY="xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx" -export CIVO_NODE_REBOOT_TIME_WINDOW_MINUTES="xxxx" -kubectl -n kube-system delete secret civo-node-agent --ignore-not-found -kubectl -n kube-system create secret generic civo-node-agent -kubectl -n kube-system patch secret civo-node-agent -n kube-system --type='merge' \ - -p='{"stringData": {"civo-api-key": "'"$CIVO_API_KEY"'", "node-pool-id": "'"$CIVO_NODE_POOL_ID"'", "desired-gpu-count": "'"$CIVO_DESIRED_GPU_COUNT"'", "time-window": "'"$CIVO_NODE_REBOOT_TIME_WINDOW_MINUTES"'" }}' -``` +| Key | Description | +|-----|-------------| +| `api-key` | Civo API key used for reboot operations. | +| `api-url` | Civo API URL. | +| `cluster-id` | The ID of this Civo Kubernetes cluster. | +| `region` | The Civo region this cluster runs in. | -## Nvidia Device Plugin Install +No manual setup is required — `node-agent` reads these values directly from the existing secret. -```bash -kubectl create ns gpu-operator -kubectl label namespace gpu-operator pod-security.kubernetes.io/enforce=privileged -kubectl label namespace gpu-operator pod-security.kubernetes.io/warn=privileged -kubectl label namespace gpu-operator pod-security.kubernetes.io/audit=privileged -``` +## NVIDIA GPU Operator (GPU clusters only) -```bash -helm repo add nvdp https://nvidia.github.io/k8s-device-plugin \ -&& helm repo update -``` +The GPU health check relies on the `nvidia.com/gpu.count` label added by the NVIDIA GPU Feature Discovery component. Follow the Civo documentation to install the NVIDIA GPU Operator on your cluster: -```bash -helm install --namespace gpu-operator nvidia-device-plugin nvdp/nvidia-device-plugin --create-namespace \ - --version=0.17.0 \ - --set gfd.enabled=true \ - --set devicePlugin.enabled=true \ - --set dcgm.enabled=true \ - --set nfd.enableNodeFeatureApi=true \ - --wait -``` +[Installing the NVIDIA GPU Operator](https://www.civo.com/docs/kubernetes/advanced/gpu-config#installing-the-nvidia-gpu-operator) ## Install `node-agent` chart -You will need to clone this repository in order to have access to the charts directory that is used for installation. In your terminal, please change directory to your cloned `node-agent` repo directory, and then run: +You will need to clone this repository in order to have access to the charts directory. In your terminal, change directory to your cloned `node-agent` repo directory, then run: ```bash helm upgrade -n kube-system --install node-agent ./charts ``` -## Configuration Details +To enable active recovery (actually reboot nodes): + +```bash +helm upgrade -n kube-system --install node-agent ./charts --set monitorOnly=false +``` -The following configurations are stored in the `node-agent` secret in the `kube-system` namespace. +## Configuration -`node-pool-id`: The ID of your Kubernetes node pool which you want monitored. To collect this value, go to the [civo kubernetes dashboard](https://dashboard.civo.com/kubernetes), select your cluster, and click copy next to your pool id. +### Helm values (`values.yaml`) -`desired-gpu-count`: This value is intended to match the number of GPUs per node. If you had a 2-node cluster with 8 GPU total, you would set this value to 4 to represent the number of GPUs per node. +| Value | Default | Description | +|-------|---------|-------------| +| `nodePoolIDs` | `""` | Comma-separated node pool IDs to watch. Empty means all nodes. | +| `rebootWaitMinutes` | `10` | Minutes to wait after rebooting a standard node before retrying. | +| `gpuRebootWaitMinutes` | `40` | Minutes to wait after rebooting a GPU node before retrying. | +| `maxRebootRetries` | `5` | Maximum reboot attempts before the node transitions to `Failed` (no further reboots). | +| `monitorOnly` | `true` | If `true`, log recovery actions without executing them. Set `false` to enable reboots. | +| `metricsPort` | `9625` | Port for the Prometheus metrics endpoint. | -`civo-api-key`: The civo api key to use when automatically rebooting nodes. To collect this value, go to toue [civo settings security tab](https://dashboard.civo.com/security). +### Health checkers -`time-window`: The time-window is the time we need to give a node after a reboot happens +| Checker | Condition | Threshold | +|---------|-----------|-----------| +| `NodeReady` | `NodeReady == True` | 5 min | +| `DiskPressure` | `DiskPressure != True` | 30 min | +| `CiliumAgent` | `NetworkUnavailable == False` with reason `CiliumIsUp` (skipped for non-Cilium CNI) | 10 min | +| `GPU` | `allocatable["nvidia.com/gpu"]` equals `nvidia.com/gpu.count` label (skipped for non-GPU nodes) | 10 min | diff --git a/charts/templates/deployment.yaml b/charts/templates/deployment.yaml index a6f6e07..3c2626a 100644 --- a/charts/templates/deployment.yaml +++ b/charts/templates/deployment.yaml @@ -39,8 +39,8 @@ spec: - name: CIVO_API_KEY valueFrom: secretKeyRef: - name: civo-node-agent - key: civo-api-key + name: civo-api-access + key: api-key - name: CIVO_API_URL valueFrom: secretKeyRef: @@ -56,27 +56,30 @@ spec: secretKeyRef: name: civo-api-access key: region - - name: CIVO_NODE_POOL_ID - valueFrom: - secretKeyRef: - name: civo-node-agent - key: node-pool-id - - name: CIVO_NODE_DESIRED_GPU_COUNT - valueFrom: - secretKeyRef: - name: civo-node-agent - key: desired-gpu-count - - name: CIVO_NODE_REBOOT_TIME_WINDOW_MINUTES - valueFrom: - secretKeyRef: - name: civo-node-agent - key: time-window + - name: CIVO_NODE_POOL_IDS + value: {{ .Values.nodePoolIDs | quote }} + - name: CIVO_NODE_REBOOT_WAIT_MINUTES + value: {{ .Values.rebootWaitMinutes | quote }} + - name: CIVO_GPU_NODE_REBOOT_WAIT_MINUTES + value: {{ .Values.gpuRebootWaitMinutes | quote }} + - name: CIVO_NODE_MAX_REBOOT_RETRIES + value: {{ .Values.maxRebootRetries | quote }} + - name: CIVO_NODE_AGENT_MONITOR_ONLY + value: {{ .Values.monitorOnly | quote }} + - name: CIVO_NODE_AGENT_METRICS_PORT + value: {{ .Values.metricsPort | quote }} {{- with .Values.securityContext }} securityContext: {{- toYaml . | nindent 12 }} {{- end }} image: "{{ .Values.image.repository }}:{{ .Values.image.tag | default .Chart.AppVersion }}" imagePullPolicy: {{ .Values.image.pullPolicy }} + args: + - "--kubeconfig=" + ports: + - name: metrics + containerPort: {{ .Values.metricsPort | default 9625 }} + protocol: TCP {{- with .Values.resources }} resources: {{- toYaml . | nindent 12 }} diff --git a/charts/templates/rbac.yaml b/charts/templates/rbac.yaml index e3a295f..0bb2b6a 100644 --- a/charts/templates/rbac.yaml +++ b/charts/templates/rbac.yaml @@ -16,5 +16,6 @@ subjects: name: {{ .Chart.Name }} namespace: kube-system roleRef: - kind: ClusterRole + kind: ClusterRole name: {{ .Chart.Name }} + apiGroup: rbac.authorization.k8s.io diff --git a/charts/values.yaml b/charts/values.yaml index 92ec8b1..3551e88 100644 --- a/charts/values.yaml +++ b/charts/values.yaml @@ -6,6 +6,24 @@ image: pullPolicy: IfNotPresent tag: "6b8426a" +# Comma-separated node pool IDs to watch (empty = all nodes). +nodePoolIDs: "" + +# Reboot wait time for standard nodes (minutes). +rebootWaitMinutes: 10 + +# Reboot wait time for GPU nodes (minutes). +gpuRebootWaitMinutes: 40 + +# Maximum number of reboot attempts before a node transitions to PhaseFailed. +maxRebootRetries: 5 + +# Monitor-only mode: log recovery actions without executing them. +monitorOnly: true + +# Port for Prometheus metrics endpoint. +metricsPort: 9625 + imagePullSecrets: [] nameOverride: "" fullnameOverride: "" diff --git a/go.mod b/go.mod index 21997ac..bf87abb 100644 --- a/go.mod +++ b/go.mod @@ -4,12 +4,15 @@ go 1.24.0 require ( github.com/civo/civogo v0.3.94 + github.com/prometheus/client_golang v1.23.2 k8s.io/api v0.32.2 k8s.io/apimachinery v0.32.2 k8s.io/client-go v0.32.2 ) require ( + github.com/beorn7/perks v1.0.1 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/emicklei/go-restful/v3 v3.11.0 // indirect github.com/fxamacker/cbor/v2 v2.7.0 // indirect @@ -20,7 +23,7 @@ require ( github.com/gogo/protobuf v1.3.2 // indirect github.com/golang/protobuf v1.5.4 // indirect github.com/google/gnostic-models v0.6.8 // indirect - github.com/google/go-cmp v0.6.0 // indirect + github.com/google/go-cmp v0.7.0 // indirect github.com/google/go-querystring v1.1.0 // indirect github.com/google/gofuzz v1.2.0 // indirect github.com/google/uuid v1.6.0 // indirect @@ -31,16 +34,20 @@ require ( github.com/modern-go/reflect2 v1.0.2 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/pkg/errors v0.9.1 // indirect + github.com/prometheus/client_model v0.6.2 // indirect + github.com/prometheus/common v0.66.1 // indirect + github.com/prometheus/procfs v0.16.1 // indirect github.com/spf13/pflag v1.0.5 // indirect github.com/x448/float16 v0.8.4 // indirect - golang.org/x/mod v0.20.0 // indirect - golang.org/x/net v0.38.0 // indirect - golang.org/x/oauth2 v0.27.0 // indirect - golang.org/x/sys v0.31.0 // indirect - golang.org/x/term v0.30.0 // indirect - golang.org/x/text v0.23.0 // indirect + go.yaml.in/yaml/v2 v2.4.2 // indirect + golang.org/x/mod v0.26.0 // indirect + golang.org/x/net v0.43.0 // indirect + golang.org/x/oauth2 v0.30.0 // indirect + golang.org/x/sys v0.35.0 // indirect + golang.org/x/term v0.34.0 // indirect + golang.org/x/text v0.28.0 // indirect golang.org/x/time v0.7.0 // indirect - google.golang.org/protobuf v1.35.1 // indirect + google.golang.org/protobuf v1.36.8 // indirect gopkg.in/evanphx/json-patch.v4 v4.12.0 // indirect gopkg.in/inf.v0 v0.9.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/go.sum b/go.sum index efaae30..fb4ea1d 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,7 @@ +github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= +github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/civo/civogo v0.3.94 h1:VhdqaJ2m4z8Jz8arzyzVjokRnO8JQ3lGjLKLshJ1eJI= github.com/civo/civogo v0.3.94/go.mod h1:LaEbkszc+9nXSh4YNG0sYXFGYqdQFmXXzQg0gESs2hc= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= @@ -29,8 +33,8 @@ github.com/google/gnostic-models v0.6.8 h1:yo/ABAfM5IMRsS1VnXjTBvUb61tFIHozhlYvR github.com/google/gnostic-models v0.6.8/go.mod h1:5n7qKqH0f5wFt+aWF8CW6pZLLNOfYuF5OpfBSENuI8U= github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= -github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= -github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8= github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= @@ -46,6 +50,8 @@ github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnr github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= +github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= @@ -53,6 +59,8 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= +github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= @@ -71,6 +79,14 @@ github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINE github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o= +github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg= +github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk= +github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE= +github.com/prometheus/common v0.66.1 h1:h5E0h5/Y8niHc5DlaLlWLArTQI7tMrsfQjHV+d9ZoGs= +github.com/prometheus/common v0.66.1/go.mod h1:gcaUsgf3KfRSwHY4dIMXLPV0K/Wg1oZ8+SbZk/HH/dA= +github.com/prometheus/procfs v0.16.1 h1:hZ15bTNuirocR6u0JZ6BAHHmwS1p8B4P6MRqxtzMyRg= +github.com/prometheus/procfs v0.16.1/go.mod h1:teAbpZRB1iIAJYREa1LsoWUXykVXA1KlTmWl8x/U+Is= github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= @@ -82,55 +98,59 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= -github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= -github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= +go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= +go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI= +go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= -golang.org/x/mod v0.20.0 h1:utOm6MM3R3dnawAiJgn0y+xvuYRsm1RKM/4giyfDgV0= -golang.org/x/mod v0.20.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/mod v0.26.0 h1:EGMPT//Ezu+ylkCijjPc+f4Aih7sZvaAr+O3EHBxvZg= +golang.org/x/mod v0.26.0/go.mod h1:/j6NAhSk8iQ723BGAUyoAcn7SlD7s15Dp9Nd/SfeaFQ= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= -golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8= -golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= -golang.org/x/oauth2 v0.27.0 h1:da9Vo7/tDv5RH/7nZDz1eMGS/q1Vv1N/7FCrBhI9I3M= -golang.org/x/oauth2 v0.27.0/go.mod h1:onh5ek6nERTohokkhCD/y2cV4Do3fxFHFuAejCkRWT8= +golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE= +golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg= +golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI= +golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= -golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= -golang.org/x/term v0.30.0 h1:PQ39fJZ+mfadBm0y5WlL4vlM7Sx1Hgf13sMIY2+QS9Y= -golang.org/x/term v0.30.0/go.mod h1:NYYFdzHoI5wRh/h5tDMdMqCqPJZEuNqVR5xJLd/n67g= +golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= +golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/term v0.34.0 h1:O/2T7POpk0ZZ7MAzMeWFSg6S5IpWd/RXDlM9hgM3DR4= +golang.org/x/term v0.34.0/go.mod h1:5jC53AEywhIVebHgPVeg0mj8OD3VO9OzclacVrqpaAw= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY= -golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4= +golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng= +golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU= golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ= golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= -golang.org/x/tools v0.26.0 h1:v/60pFQmzmT9ExmjDv2gGIfi3OqfKoEP6I5+umXlbnQ= -golang.org/x/tools v0.26.0/go.mod h1:TPVVj70c7JJ3WCazhD8OdXcZg/og+b9+tH/KxylGwH0= +golang.org/x/tools v0.35.0 h1:mBffYraMEf7aa0sB+NuKnuCy8qI/9Bughn8dC2Gu5r0= +golang.org/x/tools v0.35.0/go.mod h1:NKdj5HkL/73byiZSJjqJgKn3ep7KjFkBOkR/Hps3VPw= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -google.golang.org/protobuf v1.35.1 h1:m3LfL6/Ca+fqnjnlqQXNpFPABW1UD7mjh8KO2mKFytA= -google.golang.org/protobuf v1.35.1/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= +google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc= +google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= diff --git a/main.go b/main.go index 94f678b..22f0fcc 100644 --- a/main.go +++ b/main.go @@ -2,35 +2,88 @@ package main import ( "context" + "errors" "flag" + "fmt" "log/slog" + "net/http" "os" "os/signal" + "strconv" "strings" "syscall" + "time" + "github.com/civo/node-agent/pkg/health" + "github.com/civo/node-agent/pkg/metrics" + "github.com/civo/node-agent/pkg/operation" "github.com/civo/node-agent/pkg/watcher" ) -var versionInfo = flag.Bool("version", false, "Print the driver version") +var ( + version = "0.2.0" + versionInfo = flag.Bool("version", false, "Print the driver version") + kubeconfigPath = flag.String("kubeconfig", "/etc/rancher/k3s/k3s.yaml", "Path to kubeconfig file (empty for in-cluster config)") +) var ( - apiURL = strings.TrimSpace(os.Getenv("CIVO_API_URL")) - apiKey = strings.TrimSpace(os.Getenv("CIVO_API_KEY")) - region = strings.TrimSpace(os.Getenv("CIVO_REGION")) - clusterID = strings.TrimSpace(os.Getenv("CIVO_CLUSTER_ID")) - nodePoolID = strings.TrimSpace(os.Getenv("CIVO_NODE_POOL_ID")) - nodeDesiredGPUCount = strings.TrimSpace(os.Getenv("CIVO_NODE_DESIRED_GPU_COUNT")) - rebootTimeWindowMinutes = strings.TrimSpace(os.Getenv("CIVO_NODE_REBOOT_TIME_WINDOW_MINUTES")) + apiURL = strings.TrimSpace(os.Getenv("CIVO_API_URL")) + apiKey = strings.TrimSpace(os.Getenv("CIVO_API_KEY")) + region = strings.TrimSpace(os.Getenv("CIVO_REGION")) + clusterID = strings.TrimSpace(os.Getenv("CIVO_CLUSTER_ID")) + nodePoolIDs = strings.TrimSpace(os.Getenv("CIVO_NODE_POOL_IDS")) + rebootWaitMinutes = strings.TrimSpace(os.Getenv("CIVO_NODE_REBOOT_WAIT_MINUTES")) + gpuRebootWaitMinutes = strings.TrimSpace(os.Getenv("CIVO_GPU_NODE_REBOOT_WAIT_MINUTES")) + maxRebootRetries = strings.TrimSpace(os.Getenv("CIVO_NODE_MAX_REBOOT_RETRIES")) + monitorOnly = strings.TrimSpace(os.Getenv("CIVO_NODE_AGENT_MONITOR_ONLY")) + metricsPort = strings.TrimSpace(os.Getenv("CIVO_NODE_AGENT_METRICS_PORT")) +) + +const ( + defaultMetricsPort = 9625 ) func run(ctx context.Context) error { ctx, stop := signal.NotifyContext(ctx, os.Interrupt, syscall.SIGTERM) defer stop() - w, err := watcher.NewWatcher(ctx, apiURL, apiKey, region, clusterID, nodePoolID, - watcher.WithRebootTimeWindowMinutes(rebootTimeWindowMinutes), - watcher.WithDesiredGPUCount(nodeDesiredGPUCount), + executor, err := operation.NewCivoExecutor(clusterID, + operation.WithAPIConfig(apiKey, apiURL, region, version)) + if err != nil { + return fmt.Errorf("failed to initialise executor: %w", err) + } + checkers := health.NewDefaultCheckers() + + metrics.Register() + metrics.Info.WithLabelValues(version, clusterID).Set(1) + metricsServer := &http.Server{ + Addr: ":" + metricsPortValue(metricsPort), + Handler: metrics.Handler(), + } + go func() { + slog.Info("Starting metrics server", "addr", metricsServer.Addr) + if err := metricsServer.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) { + slog.Error("Metrics server failed", "error", err) + stop() + } + }() + defer func() { + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := metricsServer.Shutdown(shutdownCtx); err != nil { + slog.Error("Metrics server shutdown failed", "error", err) + } + }() + + w, err := watcher.NewWatcher( + watcher.WithNodePoolIDs(nodePoolIDs), + watcher.WithKubernetesClientConfigPath(*kubeconfigPath), + watcher.WithExecutor(executor), + watcher.WithCheckers(checkers), + watcher.WithMonitorOnly(monitorOnly), + watcher.WithRebootWaitMinutes(rebootWaitMinutes), + watcher.WithGPURebootWaitMinutes(gpuRebootWaitMinutes), + watcher.WithMaxRebootRetries(maxRebootRetries), ) if err != nil { return err @@ -41,14 +94,14 @@ func run(ctx context.Context) error { func main() { flag.Parse() if *versionInfo { - slog.Info("node-agent", "version", watcher.Version) + slog.Info("node-agent", "version", version) return } slog.SetDefault(slog.New(slog.NewJSONHandler(os.Stdout, nil).WithAttrs([]slog.Attr{ slog.String("clusterID", clusterID), slog.String("region", region), - slog.String("nodePoolID", nodePoolID), + slog.String("nodePoolIDs", nodePoolIDs), }))) if err := run(context.Background()); err != nil { @@ -56,3 +109,10 @@ func main() { os.Exit(1) } } + +func metricsPortValue(s string) string { + if v, err := strconv.Atoi(s); err == nil && v >= 1024 && v <= 65535 { + return s + } + return strconv.Itoa(defaultMetricsPort) +} diff --git a/pkg/health/cilium.go b/pkg/health/cilium.go new file mode 100644 index 0000000..44ce71d --- /dev/null +++ b/pkg/health/cilium.go @@ -0,0 +1,32 @@ +package health + +import ( + "time" + + corev1 "k8s.io/api/core/v1" +) + +const ( + ciliumReadyReason = "CiliumIsUp" + ciliumThreshold = 10 * time.Minute +) + +// ciliumChecker reports healthy when the Cilium-managed NetworkUnavailable +// condition is False. If the condition's reason is not "CiliumIsUp" +// (i.e. a different CNI manages the condition), the check is skipped. +type ciliumChecker struct{} + +func (c *ciliumChecker) Name() string { return "CiliumAgent" } +func (c *ciliumChecker) Threshold() time.Duration { return ciliumThreshold } + +func (c *ciliumChecker) Check(node *corev1.Node) (bool, string) { + for _, cond := range node.Status.Conditions { + if cond.Type == corev1.NodeNetworkUnavailable { + if cond.Reason != ciliumReadyReason { + return true, cond.Reason + } + return cond.Status == corev1.ConditionFalse, cond.Reason + } + } + return true, "NetworkUnavailable condition not found" +} diff --git a/pkg/health/cilium_test.go b/pkg/health/cilium_test.go new file mode 100644 index 0000000..a146788 --- /dev/null +++ b/pkg/health/cilium_test.go @@ -0,0 +1,99 @@ +package health + +import ( + "testing" + "time" + + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +func TestCiliumChecker_Threshold(t *testing.T) { + c := &ciliumChecker{} + if got := c.Threshold(); got != 10*time.Minute { + t.Errorf("got %v, want %v", got, 10*time.Minute) + } +} + +func TestCiliumChecker_Name(t *testing.T) { + c := &ciliumChecker{} + if got := c.Name(); got != "CiliumAgent" { + t.Errorf("got %q, want %q", got, "CiliumAgent") + } +} + +func TestCiliumChecker_Check(t *testing.T) { + tests := []struct { + description string + node *corev1.Node + want bool + }{ + { + description: "returns true when NetworkUnavailable is False with CiliumIsUp", + node: &corev1.Node{ + ObjectMeta: metav1.ObjectMeta{Name: "node-01"}, + Status: corev1.NodeStatus{ + Conditions: []corev1.NodeCondition{ + { + Type: corev1.NodeNetworkUnavailable, + Status: corev1.ConditionFalse, + Reason: ciliumReadyReason, + }, + }, + }, + }, + want: true, + }, + { + description: "returns false when NetworkUnavailable is True with CiliumIsUp", + node: &corev1.Node{ + ObjectMeta: metav1.ObjectMeta{Name: "node-01"}, + Status: corev1.NodeStatus{ + Conditions: []corev1.NodeCondition{ + { + Type: corev1.NodeNetworkUnavailable, + Status: corev1.ConditionTrue, + Reason: ciliumReadyReason, + }, + }, + }, + }, + want: false, + }, + { + description: "skips check when NetworkUnavailable has non-Cilium reason", + node: &corev1.Node{ + ObjectMeta: metav1.ObjectMeta{Name: "node-01"}, + Status: corev1.NodeStatus{ + Conditions: []corev1.NodeCondition{ + { + Type: corev1.NodeNetworkUnavailable, + Status: corev1.ConditionFalse, + Reason: "FlannelIsUp", + }, + }, + }, + }, + want: true, + }, + { + description: "returns true when condition is absent", + node: &corev1.Node{ + ObjectMeta: metav1.ObjectMeta{Name: "node-01"}, + Status: corev1.NodeStatus{ + Conditions: []corev1.NodeCondition{}, + }, + }, + want: true, + }, + } + + c := &ciliumChecker{} + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { + if got, _ := c.Check(test.node); got != test.want { + t.Errorf("got %v, want %v", got, test.want) + } + }) + } +} diff --git a/pkg/health/disk_pressure.go b/pkg/health/disk_pressure.go new file mode 100644 index 0000000..d092aa3 --- /dev/null +++ b/pkg/health/disk_pressure.go @@ -0,0 +1,24 @@ +package health + +import ( + "time" + + corev1 "k8s.io/api/core/v1" +) + +const diskPressureThreshold = 30 * time.Minute + +// diskPressureChecker reports healthy when the node does not have disk pressure. +type diskPressureChecker struct{} + +func (c *diskPressureChecker) Name() string { return "DiskPressure" } +func (c *diskPressureChecker) Threshold() time.Duration { return diskPressureThreshold } + +func (c *diskPressureChecker) Check(node *corev1.Node) (bool, string) { + for _, cond := range node.Status.Conditions { + if cond.Type == corev1.NodeDiskPressure { + return cond.Status != corev1.ConditionTrue, cond.Reason + } + } + return true, "DiskPressure condition not found" +} diff --git a/pkg/health/disk_pressure_test.go b/pkg/health/disk_pressure_test.go new file mode 100644 index 0000000..bfc66aa --- /dev/null +++ b/pkg/health/disk_pressure_test.go @@ -0,0 +1,87 @@ +package health + +import ( + "testing" + "time" + + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +func TestDiskPressureChecker_Threshold(t *testing.T) { + c := &diskPressureChecker{} + if got := c.Threshold(); got != 30*time.Minute { + t.Errorf("got %v, want %v", got, 30*time.Minute) + } +} + +func TestDiskPressureChecker_Name(t *testing.T) { + c := &diskPressureChecker{} + if got := c.Name(); got != "DiskPressure" { + t.Errorf("got %q, want %q", got, "DiskPressure") + } +} + +func TestDiskPressureChecker_Check(t *testing.T) { + tests := []struct { + description string + node *corev1.Node + want bool + }{ + { + description: "returns true when DiskPressure is False (no pressure)", + node: &corev1.Node{ + ObjectMeta: metav1.ObjectMeta{Name: "node-01"}, + Status: corev1.NodeStatus{ + Conditions: []corev1.NodeCondition{ + {Type: corev1.NodeDiskPressure, Status: corev1.ConditionFalse}, + }, + }, + }, + want: true, + }, + { + description: "returns false when DiskPressure is True (under pressure)", + node: &corev1.Node{ + ObjectMeta: metav1.ObjectMeta{Name: "node-01"}, + Status: corev1.NodeStatus{ + Conditions: []corev1.NodeCondition{ + {Type: corev1.NodeDiskPressure, Status: corev1.ConditionTrue}, + }, + }, + }, + want: false, + }, + { + description: "returns true when no conditions present", + node: &corev1.Node{ + ObjectMeta: metav1.ObjectMeta{Name: "node-01"}, + Status: corev1.NodeStatus{ + Conditions: []corev1.NodeCondition{}, + }, + }, + want: true, + }, + { + description: "returns true when only non-DiskPressure conditions present", + node: &corev1.Node{ + ObjectMeta: metav1.ObjectMeta{Name: "node-01"}, + Status: corev1.NodeStatus{ + Conditions: []corev1.NodeCondition{ + {Type: corev1.NodeReady, Status: corev1.ConditionTrue}, + }, + }, + }, + want: true, + }, + } + + c := &diskPressureChecker{} + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { + if got, _ := c.Check(test.node); got != test.want { + t.Errorf("got %v, want %v", got, test.want) + } + }) + } +} diff --git a/pkg/health/gpu.go b/pkg/health/gpu.go new file mode 100644 index 0000000..48bd351 --- /dev/null +++ b/pkg/health/gpu.go @@ -0,0 +1,65 @@ +package health + +import ( + "strconv" + "time" + + corev1 "k8s.io/api/core/v1" +) + +const ( + gpuResourceName = "nvidia.com/gpu" + gpuCountLabel = "nvidia.com/gpu.count" + gpuThreshold = 10 * time.Minute +) + +// gpuChecker reports healthy when the node's allocatable GPU count +// matches the expected count from the nvidia.com/gpu.count label. +// If the label is not present, the node is not a GPU node and the check is skipped. +type gpuChecker struct{} + +func (c *gpuChecker) Name() string { return "GPU" } +func (c *gpuChecker) Threshold() time.Duration { return gpuThreshold } + +func (c *gpuChecker) Check(node *corev1.Node) (bool, string) { + expected, ok := expectedGPUCount(node) + if !ok || expected == 0 { + return true, "NonGPUNode" + } + + quantity, exists := node.Status.Allocatable[gpuResourceName] + if !exists || quantity.IsZero() { + return false, "GPUCountMismatch" + } + + actual, ok := quantity.AsInt64() + if !ok { + return false, "NoAllocatableGPU" + } + + if actual == int64(expected) { + return true, "GPUCountMatch" + } + return false, "GPUCountMismatch" +} + +// HasGPU returns true if the node has the nvidia.com/gpu.count label +// with a positive value, indicating it is a GPU node regardless of +// current GPU health. +func HasGPU(node *corev1.Node) bool { + n, ok := expectedGPUCount(node) + return ok && n > 0 +} + +// expectedGPUCount reads the nvidia.com/gpu.count label from the node. +func expectedGPUCount(node *corev1.Node) (int, bool) { + v, exists := node.Labels[gpuCountLabel] + if !exists { + return 0, false + } + n, err := strconv.Atoi(v) + if err != nil || n < 0 { + return 0, false + } + return n, true +} diff --git a/pkg/health/gpu_test.go b/pkg/health/gpu_test.go new file mode 100644 index 0000000..7a48c3f --- /dev/null +++ b/pkg/health/gpu_test.go @@ -0,0 +1,203 @@ +package health + +import ( + "testing" + "time" + + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/resource" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +func TestGPUChecker_Threshold(t *testing.T) { + c := &gpuChecker{} + if got := c.Threshold(); got != 10*time.Minute { + t.Errorf("got %v, want %v", got, 10*time.Minute) + } +} + +func TestHasGPU(t *testing.T) { + tests := []struct { + description string + node *corev1.Node + want bool + }{ + { + description: "returns true when gpu.count label is positive", + node: &corev1.Node{ + ObjectMeta: metav1.ObjectMeta{ + Name: "node-01", + Labels: map[string]string{gpuCountLabel: "8"}, + }, + }, + want: true, + }, + { + description: "returns false when gpu.count label is absent", + node: &corev1.Node{ + ObjectMeta: metav1.ObjectMeta{Name: "node-01"}, + }, + want: false, + }, + { + description: "returns false when gpu.count label is 0", + node: &corev1.Node{ + ObjectMeta: metav1.ObjectMeta{ + Name: "node-01", + Labels: map[string]string{gpuCountLabel: "0"}, + }, + }, + want: false, + }, + { + description: "returns false when gpu.count label is invalid", + node: &corev1.Node{ + ObjectMeta: metav1.ObjectMeta{ + Name: "node-01", + Labels: map[string]string{gpuCountLabel: "invalid"}, + }, + }, + want: false, + }, + } + + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { + if got := HasGPU(test.node); got != test.want { + t.Errorf("got %v, want %v", got, test.want) + } + }) + } +} + +func TestGPUChecker_Name(t *testing.T) { + c := &gpuChecker{} + if got := c.Name(); got != "GPU" { + t.Errorf("got %q, want %q", got, "GPU") + } +} + +func TestGPUChecker_Check(t *testing.T) { + tests := []struct { + description string + node *corev1.Node + want bool + }{ + { + description: "returns true when allocatable matches label count", + node: &corev1.Node{ + ObjectMeta: metav1.ObjectMeta{ + Name: "node-01", + Labels: map[string]string{gpuCountLabel: "8"}, + }, + Status: corev1.NodeStatus{ + Allocatable: corev1.ResourceList{ + gpuResourceName: resource.MustParse("8"), + }, + }, + }, + want: true, + }, + { + description: "returns true when gpu.count label is absent (non-GPU node)", + node: &corev1.Node{ + ObjectMeta: metav1.ObjectMeta{Name: "node-01"}, + Status: corev1.NodeStatus{ + Allocatable: corev1.ResourceList{}, + }, + }, + want: true, + }, + { + description: "returns true when gpu.count label is 0", + node: &corev1.Node{ + ObjectMeta: metav1.ObjectMeta{ + Name: "node-01", + Labels: map[string]string{gpuCountLabel: "0"}, + }, + Status: corev1.NodeStatus{ + Allocatable: corev1.ResourceList{}, + }, + }, + want: true, + }, + { + description: "returns false when allocatable is less than label count", + node: &corev1.Node{ + ObjectMeta: metav1.ObjectMeta{ + Name: "node-01", + Labels: map[string]string{gpuCountLabel: "8"}, + }, + Status: corev1.NodeStatus{ + Allocatable: corev1.ResourceList{ + gpuResourceName: resource.MustParse("7"), + }, + }, + }, + want: false, + }, + { + description: "returns false when allocatable GPU is zero", + node: &corev1.Node{ + ObjectMeta: metav1.ObjectMeta{ + Name: "node-01", + Labels: map[string]string{gpuCountLabel: "8"}, + }, + Status: corev1.NodeStatus{ + Allocatable: corev1.ResourceList{ + gpuResourceName: resource.MustParse("0"), + }, + }, + }, + want: false, + }, + { + description: "returns false when allocatable GPU resource is missing", + node: &corev1.Node{ + ObjectMeta: metav1.ObjectMeta{ + Name: "node-01", + Labels: map[string]string{gpuCountLabel: "8"}, + }, + Status: corev1.NodeStatus{ + Allocatable: corev1.ResourceList{}, + }, + }, + want: false, + }, + { + description: "returns true when gpu.count label is invalid", + node: &corev1.Node{ + ObjectMeta: metav1.ObjectMeta{ + Name: "node-01", + Labels: map[string]string{gpuCountLabel: "invalid"}, + }, + Status: corev1.NodeStatus{ + Allocatable: corev1.ResourceList{}, + }, + }, + want: true, + }, + } + + c := &gpuChecker{} + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { + if got, _ := c.Check(test.node); got != test.want { + t.Errorf("got %v, want %v", got, test.want) + } + }) + } +} + +func TestNewDefaultCheckers(t *testing.T) { + checkers := NewDefaultCheckers() + if len(checkers) != 4 { + t.Fatalf("expected 4 checkers, got %d", len(checkers)) + } + expected := []string{"NodeReady", "DiskPressure", "CiliumAgent", "GPU"} + for i, name := range expected { + if checkers[i].Name() != name { + t.Errorf("checkers[%d]: expected %q, got %q", i, name, checkers[i].Name()) + } + } +} diff --git a/pkg/health/health.go b/pkg/health/health.go new file mode 100644 index 0000000..58ff759 --- /dev/null +++ b/pkg/health/health.go @@ -0,0 +1,32 @@ +package health + +import ( + "time" + + corev1 "k8s.io/api/core/v1" +) + +// HealthChecker determines whether a single aspect of a node is healthy. +type HealthChecker interface { + // Name returns a human-readable identifier for this checker (e.g. "NodeReady"). + Name() string + // Check returns whether the node is healthy and a reason string. + // On success the reason is empty. On failure it describes what went wrong. + Check(node *corev1.Node) (healthy bool, reason string) + // Threshold returns how long this checker must continuously fail + // before a recovery action is triggered. + // A zero value means "trigger immediately on failure" (no wait period). + Threshold() time.Duration +} + +// NewDefaultCheckers returns the enabled health checkers. +// GPU checker auto-skips non-GPU nodes by checking the nvidia.com/gpu.count label. +// Cilium checker auto-skips nodes without CiliumAgentIsReady condition. +func NewDefaultCheckers() []HealthChecker { + return []HealthChecker{ + &nodeReadyChecker{}, + &diskPressureChecker{}, + &ciliumChecker{}, + &gpuChecker{}, + } +} diff --git a/pkg/health/node_ready.go b/pkg/health/node_ready.go new file mode 100644 index 0000000..edaf2a5 --- /dev/null +++ b/pkg/health/node_ready.go @@ -0,0 +1,24 @@ +package health + +import ( + "time" + + corev1 "k8s.io/api/core/v1" +) + +const nodeReadyThreshold = 5 * time.Minute + +// nodeReadyChecker reports healthy when the node's NodeReady condition is True. +type nodeReadyChecker struct{} + +func (c *nodeReadyChecker) Name() string { return "NodeReady" } +func (c *nodeReadyChecker) Threshold() time.Duration { return nodeReadyThreshold } + +func (c *nodeReadyChecker) Check(node *corev1.Node) (bool, string) { + for _, cond := range node.Status.Conditions { + if cond.Type == corev1.NodeReady { + return cond.Status == corev1.ConditionTrue, cond.Reason + } + } + return false, "NodeReady condition not found" +} diff --git a/pkg/health/node_ready_test.go b/pkg/health/node_ready_test.go new file mode 100644 index 0000000..9b7747d --- /dev/null +++ b/pkg/health/node_ready_test.go @@ -0,0 +1,87 @@ +package health + +import ( + "testing" + "time" + + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +func TestNodeReadyChecker_Threshold(t *testing.T) { + c := &nodeReadyChecker{} + if got := c.Threshold(); got != 5*time.Minute { + t.Errorf("got %v, want %v", got, 5*time.Minute) + } +} + +func TestNodeReadyChecker_Name(t *testing.T) { + c := &nodeReadyChecker{} + if got := c.Name(); got != "NodeReady" { + t.Errorf("got %q, want %q", got, "NodeReady") + } +} + +func TestNodeReadyChecker_Check(t *testing.T) { + tests := []struct { + description string + node *corev1.Node + want bool + }{ + { + description: "returns true when NodeReady condition is True", + node: &corev1.Node{ + ObjectMeta: metav1.ObjectMeta{Name: "node-01"}, + Status: corev1.NodeStatus{ + Conditions: []corev1.NodeCondition{ + {Type: corev1.NodeReady, Status: corev1.ConditionTrue}, + }, + }, + }, + want: true, + }, + { + description: "returns false when NodeReady condition is False", + node: &corev1.Node{ + ObjectMeta: metav1.ObjectMeta{Name: "node-01"}, + Status: corev1.NodeStatus{ + Conditions: []corev1.NodeCondition{ + {Type: corev1.NodeReady, Status: corev1.ConditionFalse}, + }, + }, + }, + want: false, + }, + { + description: "returns false when no conditions present", + node: &corev1.Node{ + ObjectMeta: metav1.ObjectMeta{Name: "node-01"}, + Status: corev1.NodeStatus{ + Conditions: []corev1.NodeCondition{}, + }, + }, + want: false, + }, + { + description: "returns false when only non-NodeReady conditions present", + node: &corev1.Node{ + ObjectMeta: metav1.ObjectMeta{Name: "node-01"}, + Status: corev1.NodeStatus{ + Conditions: []corev1.NodeCondition{ + {Type: corev1.NodeDiskPressure, Status: corev1.ConditionFalse}, + }, + }, + }, + want: false, + }, + } + + c := &nodeReadyChecker{} + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { + if got, _ := c.Check(test.node); got != test.want { + t.Errorf("got %v, want %v", got, test.want) + } + }) + } +} diff --git a/pkg/metrics/metrics.go b/pkg/metrics/metrics.go new file mode 100644 index 0000000..c099678 --- /dev/null +++ b/pkg/metrics/metrics.go @@ -0,0 +1,98 @@ +package metrics + +import ( + "net/http" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promhttp" +) + +var ( + // HealthCheckTotal counts the number of health check executions per node, + // checker, and result (pass/fail). + HealthCheckTotal = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: "civo_node_agent_health_check_total", + Help: "Total number of health check executions.", + }, + []string{"node", "checker", "result"}, + ) + + // RecoveryActionsTotal counts the number of recovery actions performed + // per node, action type (reboot), and mode (report/active). + RecoveryActionsTotal = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: "civo_node_agent_recovery_actions_total", + Help: "Total number of recovery actions performed.", + }, + []string{"node", "action", "mode"}, + ) + + // RecoveryFailuresTotal counts the number of recovery actions that failed + // (e.g. Civo API errors). + RecoveryFailuresTotal = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: "civo_node_agent_recovery_failures_total", + Help: "Total number of recovery actions that failed.", + }, + []string{"node", "action"}, + ) + + // ReconcileErrorsTotal counts errors encountered during the reconcile loop, + // labeled by reason (e.g. "list_nodes"). + ReconcileErrorsTotal = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: "civo_node_agent_reconcile_errors_total", + Help: "Total number of errors encountered during the reconcile loop.", + }, + []string{"reason"}, + ) + + // NodeUnhealthyDurationSeconds tracks how long each node has been + // continuously unhealthy, in seconds. + NodeUnhealthyDurationSeconds = prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Name: "civo_node_agent_node_unhealthy_duration_seconds", + Help: "Duration in seconds a node has been continuously unhealthy.", + }, + []string{"node"}, + ) + + // RecoveryPhase reports the current recovery phase for each node. + // The value is the numeric NodePhase (0=Healthy, 1=Unhealthy, etc.). + RecoveryPhase = prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Name: "civo_node_agent_recovery_phase", + Help: "Current recovery phase of a node.", + }, + []string{"node", "phase"}, + ) + + // Info exposes build and cluster identity as a constant gauge (value is always 1). + // Use PromQL joins (group_left) to enrich other metrics with version/cluster_id. + Info = prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Name: "civo_node_agent_info", + Help: "Build and cluster identity for the node-agent.", + }, + []string{"version", "cluster_id"}, + ) +) + +// Register registers all node-agent metrics with the default Prometheus registerer. +func Register() { + prometheus.MustRegister( + HealthCheckTotal, + RecoveryActionsTotal, + RecoveryFailuresTotal, + ReconcileErrorsTotal, + NodeUnhealthyDurationSeconds, + RecoveryPhase, + Info, + ) +} + +// Handler returns an http.Handler that serves Prometheus metrics. +func Handler() http.Handler { + return promhttp.Handler() +} diff --git a/pkg/operation/civo.go b/pkg/operation/civo.go new file mode 100644 index 0000000..ef9cd67 --- /dev/null +++ b/pkg/operation/civo.go @@ -0,0 +1,70 @@ +package operation + +import ( + "context" + "fmt" + "log/slog" + + "github.com/civo/civogo" +) + +// civoExecutor implements Executor using the Civo API. +type civoExecutor struct { + civoClient civogo.Clienter + clusterID string + + apiKey string + apiURL string + region string + version string +} + +// NewCivoExecutor creates an Executor that performs recovery actions via the Civo API. +func NewCivoExecutor(clusterID string, opts ...Option) (Executor, error) { + e := &civoExecutor{clusterID: clusterID} + for _, opt := range opts { + opt(e) + } + + if clusterID == "" { + return nil, fmt.Errorf("cluster ID must not be empty") + } + + if e.civoClient != nil { + return e, nil + } + + if e.apiKey == "" { + return nil, fmt.Errorf("API key must not be empty") + } + if e.apiURL == "" { + return nil, fmt.Errorf("API URL must not be empty") + } + + client, err := civogo.NewClientWithURL(e.apiKey, e.apiURL, e.region) + if err != nil { + return nil, fmt.Errorf("failed to initialise civo client: %w", err) + } + client.SetUserAgent(&civogo.Component{ + ID: clusterID, + Name: "node-agent", + Version: e.version, + }) + e.civoClient = client + return e, nil +} + +func (e *civoExecutor) Reboot(_ context.Context, nodeName string) error { + instance, err := e.civoClient.FindKubernetesClusterInstance(e.clusterID, nodeName) + if err != nil { + return fmt.Errorf("failed to find instance, clusterID: %s, nodeName: %s: %w", e.clusterID, nodeName, err) + } + + _, err = e.civoClient.HardRebootInstance(instance.ID) + if err != nil { + return fmt.Errorf("failed to reboot instance, clusterID: %s, instanceID: %s: %w", e.clusterID, instance.ID, err) + } + + slog.Info("Instance is rebooting", "instanceID", instance.ID, "node", nodeName) + return nil +} diff --git a/pkg/operation/civo_test.go b/pkg/operation/civo_test.go new file mode 100644 index 0000000..4dcf06d --- /dev/null +++ b/pkg/operation/civo_test.go @@ -0,0 +1,148 @@ +package operation + +import ( + "errors" + "testing" + + "github.com/civo/civogo" +) + +// fakeClient overrides the Civo API methods needed by CivoExecutor. +type fakeClient struct { + findFunc func(clusterID, search string) (*civogo.Instance, error) + rebootFunc func(id string) (*civogo.SimpleResponse, error) + + *civogo.FakeClient +} + +func (f *fakeClient) FindKubernetesClusterInstance(clusterID, search string) (*civogo.Instance, error) { + if f.findFunc != nil { + return f.findFunc(clusterID, search) + } + return f.FakeClient.FindKubernetesClusterInstance(clusterID, search) +} + +func (f *fakeClient) HardRebootInstance(id string) (*civogo.SimpleResponse, error) { + if f.rebootFunc != nil { + return f.rebootFunc(id) + } + return f.FakeClient.HardRebootInstance(id) +} + +var _ civogo.Clienter = (*fakeClient)(nil) + +func TestCivoExecutor_Reboot(t *testing.T) { + tests := []struct { + description string + nodeName string + setupClient func(t *testing.T) *fakeClient + wantErr bool + }{ + { + description: "returns nil on successful find and reboot", + nodeName: "node-01", + setupClient: func(t *testing.T) *fakeClient { + return &fakeClient{ + findFunc: func(clusterID, search string) (*civogo.Instance, error) { + return &civogo.Instance{ID: "instance-01"}, nil + }, + rebootFunc: func(id string) (*civogo.SimpleResponse, error) { + if id != "instance-01" { + t.Errorf("instanceID mismatch: got %s, want instance-01", id) + } + return new(civogo.SimpleResponse), nil + }, + } + }, + }, + { + description: "returns error when instance lookup fails", + nodeName: "node-01", + setupClient: func(t *testing.T) *fakeClient { + return &fakeClient{ + findFunc: func(_, _ string) (*civogo.Instance, error) { + return nil, errors.New("not found") + }, + } + }, + wantErr: true, + }, + { + description: "returns error when hard reboot fails", + nodeName: "node-01", + setupClient: func(t *testing.T) *fakeClient { + return &fakeClient{ + findFunc: func(_, _ string) (*civogo.Instance, error) { + return &civogo.Instance{ID: "instance-01"}, nil + }, + rebootFunc: func(_ string) (*civogo.SimpleResponse, error) { + return nil, errors.New("reboot failed") + }, + } + }, + wantErr: true, + }, + } + + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { + client := test.setupClient(t) + exec, err := NewCivoExecutor("test-cluster", WithClient(client)) + if err != nil { + t.Fatal(err) + } + err = exec.Reboot(t.Context(), test.nodeName) + if (err != nil) != test.wantErr { + t.Errorf("error = %v, wantErr %v", err, test.wantErr) + } + }) + } +} + +func TestNewCivoExecutor_Validation(t *testing.T) { + tests := []struct { + description string + id string + opts []Option + wantErr bool + }{ + { + description: "returns no error with injected client", + id: "test-cluster", + opts: []Option{WithClient(&fakeClient{})}, + }, + { + description: "returns error when clusterID is empty without injected client", + id: "", + opts: []Option{WithAPIConfig("key", "https://api.civo.com", "lon1", "0.0.1")}, + wantErr: true, + }, + { + description: "returns error when apiKey is empty", + id: "test-cluster", + opts: []Option{WithAPIConfig("", "https://api.civo.com", "lon1", "0.0.1")}, + wantErr: true, + }, + { + description: "returns error when apiURL is empty", + id: "test-cluster", + opts: []Option{WithAPIConfig("key", "", "lon1", "0.0.1")}, + wantErr: true, + }, + { + description: "returns error when clusterID is empty even with injected client", + id: "", + opts: []Option{WithClient(&fakeClient{})}, + wantErr: true, + }, + } + + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { + _, err := NewCivoExecutor(test.id, test.opts...) + if (err != nil) != test.wantErr { + t.Errorf("error = %v, wantErr %v", err, test.wantErr) + } + }) + } +} diff --git a/pkg/operation/operation.go b/pkg/operation/operation.go new file mode 100644 index 0000000..1498a6b --- /dev/null +++ b/pkg/operation/operation.go @@ -0,0 +1,19 @@ +package operation + +import "context" + +// Executor performs recovery operations on cluster nodes. +type Executor interface { + Reboot(ctx context.Context, nodeName string) error +} + +// nopExecutor is a no-op Executor that does nothing. +// Used as a safe default to prevent nil pointer dereference. +type nopExecutor struct{} + +func (e *nopExecutor) Reboot(_ context.Context, _ string) error { return nil } + +// NewNopExecutor returns an Executor that performs no operations. +func NewNopExecutor() Executor { + return &nopExecutor{} +} diff --git a/pkg/operation/options.go b/pkg/operation/options.go new file mode 100644 index 0000000..49ca08a --- /dev/null +++ b/pkg/operation/options.go @@ -0,0 +1,24 @@ +package operation + +import "github.com/civo/civogo" + +// Option represents a configuration function that modifies civoExecutor. +type Option func(*civoExecutor) + +// WithAPIConfig returns Option to configure the Civo API credentials and version. +// The client is created internally using these values. +func WithAPIConfig(apiKey, apiURL, region, version string) Option { + return func(e *civoExecutor) { + e.apiKey = apiKey + e.apiURL = apiURL + e.region = region + e.version = version + } +} + +// WithClient returns Option to inject a pre-built Civo client (for testing). +func WithClient(client civogo.Clienter) Option { + return func(e *civoExecutor) { + e.civoClient = client + } +} diff --git a/pkg/watcher/fake.go b/pkg/watcher/fake.go deleted file mode 100644 index 40c37d9..0000000 --- a/pkg/watcher/fake.go +++ /dev/null @@ -1,28 +0,0 @@ -package watcher - -import "github.com/civo/civogo" - -// FakeClient is a test client used for more flexible behavior control -// when FakeClient alone is not sufficient. -type FakeClient struct { - HardRebootInstanceFunc func(id string) (*civogo.SimpleResponse, error) - FindKubernetesClusterInstanceFunc func(clusterID, search string) (*civogo.Instance, error) - - *civogo.FakeClient -} - -func (f *FakeClient) HardRebootInstance(id string) (*civogo.SimpleResponse, error) { - if f.HardRebootInstanceFunc != nil { - return f.HardRebootInstanceFunc(id) - } - return f.FakeClient.HardRebootInstance(id) -} - -func (f *FakeClient) FindKubernetesClusterInstance(clusterID, search string) (*civogo.Instance, error) { - if f.FindKubernetesClusterInstanceFunc != nil { - return f.FindKubernetesClusterInstanceFunc(clusterID, search) - } - return f.FakeClient.FindKubernetesClusterInstance(clusterID, search) -} - -var _ civogo.Clienter = (*FakeClient)(nil) diff --git a/pkg/watcher/options.go b/pkg/watcher/options.go index 904634d..f1edcd4 100644 --- a/pkg/watcher/options.go +++ b/pkg/watcher/options.go @@ -3,18 +3,25 @@ package watcher import ( "log/slog" "strconv" + "strings" "time" - "github.com/civo/civogo" + "github.com/civo/node-agent/pkg/health" + "github.com/civo/node-agent/pkg/operation" "k8s.io/client-go/kubernetes" + listerscorev1 "k8s.io/client-go/listers/core/v1" ) // Option represents a configuration function that modifies watcher object. type Option func(*watcher) var defaultOptions = []Option{ - WithRebootTimeWindowMinutes("40"), - WithDesiredGPUCount("0"), + WithMonitorOnly("true"), + WithExecutor(operation.NewNopExecutor()), + WithRebootWaitMinutes("10"), + WithGPURebootWaitMinutes("40"), + WithMaxRebootRetries("5"), + WithMaxRebootFailures("30"), } // WithKubernetesClient returns Option to set Kubernetes API client. @@ -26,7 +33,7 @@ func WithKubernetesClient(client kubernetes.Interface) Option { } } -// WithKubernetesClient returns Option to set Kubernetes config path. +// WithKubernetesClientConfigPath returns Option to set Kubernetes config path. func WithKubernetesClientConfigPath(path string) Option { return func(w *watcher) { if path != "" { @@ -35,35 +42,117 @@ func WithKubernetesClientConfigPath(path string) Option { } } -// WithCivoClient returns Option to set Civo API client. -func WithCivoClient(client civogo.Clienter) Option { +// WithNodePoolIDs returns Option to append node pool IDs to watch. +// Accepts a comma-separated string (e.g. "pool-1,pool-2"). +// Can be called multiple times to accumulate IDs. +// Empty string is a no-op. If no IDs are provided across all calls, all nodes are watched. +func WithNodePoolIDs(s string) Option { return func(w *watcher) { - if client != nil { - w.civoClient = client + for _, id := range strings.Split(s, ",") { + if v := strings.TrimSpace(id); v != "" { + w.nodePoolIDs = append(w.nodePoolIDs, v) + } } } } -// WithRebootTimeWindowMinutes returns Option to set reboot time window. -func WithRebootTimeWindowMinutes(s string) Option { +// WithRebootWaitMinutes returns Option to set the reboot wait time for standard nodes. +func WithRebootWaitMinutes(s string) Option { return func(w *watcher) { n, err := strconv.Atoi(s) if err == nil && n > 0 { - w.rebootTimeWindowMinutes = time.Duration(n) + w.rebootWaitMinutes = time.Duration(n) * time.Minute } else { - slog.Info("RebootTimeWindowMinutes is invalid", "value", s) + slog.Info("RebootWaitMinutes is invalid", "value", s) } } } -// WithDesiredGPUCount returns Option to set desired GPU count . -func WithDesiredGPUCount(s string) Option { +// WithGPURebootWaitMinutes returns Option to set the reboot wait time for GPU nodes. +func WithGPURebootWaitMinutes(s string) Option { return func(w *watcher) { n, err := strconv.Atoi(s) - if err == nil && n >= 0 { - w.nodeDesiredGPUCount = n + if err == nil && n > 0 { + w.gpuRebootWaitMinutes = time.Duration(n) * time.Minute + } else { + slog.Info("GPURebootWaitMinutes is invalid", "value", s) + } + } +} + +// WithMaxRebootRetries returns Option to set the maximum number of reboot +// attempts before a node transitions to PhaseFailed. +func WithMaxRebootRetries(s string) Option { + return func(w *watcher) { + n, err := strconv.Atoi(s) + if err == nil && n > 0 { + w.maxRebootRetries = n + } else { + slog.Info("MaxRebootRetries is invalid", "value", s) + } + } +} + +// WithMaxRebootFailures returns Option to set the maximum number of reboot +// call failures tolerated before a node transitions to PhaseFailed. +// +// Intentionally not exposed as an env var: reboot call failures are not +// followed by a wait window, so a high value would let the agent hammer +// the Civo API on sustained failures. The bound is controlled here via +// the default option to cap the blast radius. +func WithMaxRebootFailures(s string) Option { + return func(w *watcher) { + n, err := strconv.Atoi(s) + if err == nil && n > 0 { + w.maxRebootFailures = n + } else { + slog.Info("MaxRebootFailures is invalid", "value", s) + } + } +} + +// WithMonitorOnly returns Option to enable or disable monitor-only mode. +// Accepts a string parsable by strconv.ParseBool (e.g. "true", "false", "1", "0"). +// Empty or unparsable values are ignored (default: true). +func WithMonitorOnly(s string) Option { + return func(w *watcher) { + if v, err := strconv.ParseBool(s); err == nil { + w.monitorOnly = v } else { - slog.Info("DesiredGPUCount is invalid", "value", s) + slog.Info("MonitorOnly is invalid", "value", s) } } } + +// WithCheckers returns Option to set the health checkers. +func WithCheckers(checkers []health.HealthChecker) Option { + return func(w *watcher) { + w.checkers = checkers + } +} + +// WithExecutor returns Option to set the recovery executor. +func WithExecutor(exec operation.Executor) Option { + return func(w *watcher) { + if exec != nil { + w.executor = exec + } + } +} + +// withNowFunc returns Option to override the time source (for testing). +func withNowFunc(fn func() time.Time) Option { + return func(w *watcher) { + if fn != nil { + w.nowFunc = fn + } + } +} + +// withNodeLister returns Option to inject a node lister (for testing). +// When set, the informer setup is skipped. +func withNodeLister(lister listerscorev1.NodeLister) Option { + return func(w *watcher) { + w.nodeLister = lister + } +} diff --git a/pkg/watcher/state.go b/pkg/watcher/state.go new file mode 100644 index 0000000..5b643da --- /dev/null +++ b/pkg/watcher/state.go @@ -0,0 +1,225 @@ +package watcher + +import ( + "sync" + "time" +) + +// NodePhase represents the current recovery phase of a node. +type NodePhase int + +const ( + PhaseUnknown NodePhase = iota // 0 - unknown/uninitialized + PhaseHealthy // 1 - node is healthy + PhaseUnhealthy // 2 - checker(s) failing, waiting for threshold + PhaseWaitingReboot // 3 - waiting for reboot to take effect + PhaseDrain // 4 - future: draining pods + PhaseReplace // 5 - future: replace issued + PhaseFailed // 6 - recovery gave up (exceeded retries); awaits manual intervention or natural recovery +) + +// String returns the string representation of a NodePhase. +func (p NodePhase) String() string { + switch p { + case PhaseUnknown: + return "Unknown" + case PhaseHealthy: + return "Healthy" + case PhaseUnhealthy: + return "Unhealthy" + case PhaseWaitingReboot: + return "WaitingReboot" + case PhaseDrain: + return "Drain" + case PhaseReplace: + return "Replace" + case PhaseFailed: + return "Failed" + default: + return "Unknown" + } +} + +// NodeState holds the recovery state for a single node. +// All fields are private; read via getters, mutate via StateStore methods. +type NodeState struct { + mu sync.RWMutex + phase NodePhase + + unhealthySince time.Time + + lastRebootTime time.Time + rebootCount int + failedRebootCount int + + failedCheckers []string + isGPUNode bool +} + +func (s *NodeState) Phase() NodePhase { + s.mu.RLock() + defer s.mu.RUnlock() + return s.phase +} +func (s *NodeState) UnhealthySince() time.Time { + s.mu.RLock() + defer s.mu.RUnlock() + return s.unhealthySince +} +func (s *NodeState) LastRebootTime() time.Time { + s.mu.RLock() + defer s.mu.RUnlock() + return s.lastRebootTime +} +func (s *NodeState) RebootCount() int { + s.mu.RLock() + defer s.mu.RUnlock() + return s.rebootCount +} +func (s *NodeState) FailedRebootCount() int { + s.mu.RLock() + defer s.mu.RUnlock() + return s.failedRebootCount +} +func (s *NodeState) IsGPUNode() bool { + s.mu.RLock() + defer s.mu.RUnlock() + return s.isGPUNode +} + +// StateStore is a concurrency-safe store for per-node recovery state. +type StateStore struct { + mu sync.RWMutex + nodes map[string]*NodeState +} + +// NewStateStore creates a new empty StateStore. +func NewStateStore() *StateStore { + return &StateStore{ + nodes: make(map[string]*NodeState), + } +} + +// GetOrCreate returns the NodeState for the given node name, +// creating a new one (PhaseHealthy) if it does not exist. +func (s *StateStore) GetOrCreate(name string) *NodeState { + s.mu.Lock() + defer s.mu.Unlock() + if st, ok := s.nodes[name]; ok { + return st + } + st := &NodeState{phase: PhaseHealthy} + s.nodes[name] = st + return st +} + +// Get returns the NodeState for the given node name and whether it was found. +func (s *StateStore) Get(name string) (*NodeState, bool) { + s.mu.RLock() + defer s.mu.RUnlock() + st, ok := s.nodes[name] + return st, ok +} + +// Delete removes the state entry for the given node name. +func (s *StateStore) Delete(name string) { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.nodes, name) +} + +// Range calls fn for each node state entry. If fn returns false, iteration stops. +func (s *StateStore) Range(fn func(name string, state *NodeState) bool) { + s.mu.RLock() + snapshot := make(map[string]*NodeState, len(s.nodes)) + for name, state := range s.nodes { + snapshot[name] = state + } + s.mu.RUnlock() + for name, state := range snapshot { + if !fn(name, state) { + return + } + } +} + +// UpdateCheckerInfo updates the failed checker names and GPU flag for a node. +func (s *StateStore) UpdateCheckerInfo(name string, failedCheckers []string, isGPUNode bool) { + st, ok := s.Get(name) + if !ok { + return + } + st.mu.Lock() + st.failedCheckers = failedCheckers + st.isGPUNode = isGPUNode + st.mu.Unlock() +} + +// MarkUnhealthy transitions a node to PhaseUnhealthy and records when it became unhealthy. +func (s *StateStore) MarkUnhealthy(name string, now time.Time) { + st, ok := s.Get(name) + if !ok { + return + } + st.mu.Lock() + st.phase = PhaseUnhealthy + st.unhealthySince = now + st.mu.Unlock() +} + +// MarkWaitingReboot transitions a node to PhaseWaitingReboot, records the +// reboot time, and increments the reboot counter. +func (s *StateStore) MarkWaitingReboot(name string, now time.Time) { + st, ok := s.Get(name) + if !ok { + return + } + st.mu.Lock() + st.phase = PhaseWaitingReboot + st.lastRebootTime = now + st.rebootCount++ + st.mu.Unlock() +} + +// RecordRebootFailure increments the failed reboot counter for a node. +// The node's phase is not changed; the caller decides whether to transition. +func (s *StateStore) RecordRebootFailure(name string) { + st, ok := s.Get(name) + if !ok { + return + } + st.mu.Lock() + st.failedRebootCount++ + st.mu.Unlock() +} + +// MarkFailed transitions a node to PhaseFailed after recovery attempts were exhausted. +func (s *StateStore) MarkFailed(name string) { + st, ok := s.Get(name) + if !ok { + return + } + st.mu.Lock() + st.phase = PhaseFailed + st.mu.Unlock() +} + +// Reset replaces the node's state with a fresh PhaseHealthy entry. +func (s *StateStore) Reset(name string) { + s.mu.Lock() + defer s.mu.Unlock() + if _, ok := s.nodes[name]; ok { + s.nodes[name] = &NodeState{phase: PhaseHealthy} + } +} + +// Cleanup removes state entries for nodes that are not in the activeNodes set. +func (s *StateStore) Cleanup(activeNodes map[string]struct{}) { + s.mu.Lock() + defer s.mu.Unlock() + for name := range s.nodes { + if _, ok := activeNodes[name]; !ok { + delete(s.nodes, name) + } + } +} diff --git a/pkg/watcher/state_test.go b/pkg/watcher/state_test.go new file mode 100644 index 0000000..b21e1ee --- /dev/null +++ b/pkg/watcher/state_test.go @@ -0,0 +1,321 @@ +package watcher + +import ( + "testing" + "time" +) + +func TestNodePhaseString(t *testing.T) { + tests := []struct { + phase NodePhase + want string + }{ + {PhaseUnknown, "Unknown"}, + {PhaseHealthy, "Healthy"}, + {PhaseUnhealthy, "Unhealthy"}, + {PhaseWaitingReboot, "WaitingReboot"}, + {PhaseDrain, "Drain"}, + {PhaseReplace, "Replace"}, + {PhaseFailed, "Failed"}, + {NodePhase(99), "Unknown"}, + } + + for _, test := range tests { + t.Run(test.want, func(t *testing.T) { + if got := test.phase.String(); got != test.want { + t.Errorf("got %q, want %q", got, test.want) + } + }) + } +} + +func TestNodePhaseZeroValue(t *testing.T) { + var phase NodePhase + if phase != PhaseUnknown { + t.Errorf("zero value of NodePhase should be PhaseUnknown, got %v", phase) + } +} + +func TestStateStoreGetOrCreate(t *testing.T) { + s := NewStateStore() + + st := s.GetOrCreate("node-01") + if st.Phase() != PhaseHealthy { + t.Errorf("new state should be PhaseHealthy, got %v", st.Phase()) + } + + st2 := s.GetOrCreate("node-01") + if st != st2 { + t.Error("GetOrCreate should return the same pointer for existing node") + } +} + +func TestStateStoreGet(t *testing.T) { + s := NewStateStore() + + _, ok := s.Get("nonexistent") + if ok { + t.Error("Get should return false for nonexistent node") + } + + s.GetOrCreate("node-01") + st, ok := s.Get("node-01") + if !ok { + t.Error("Get should return true for existing node") + } + if st.Phase() != PhaseHealthy { + t.Errorf("got phase %v, want PhaseHealthy", st.Phase()) + } +} + +func TestStateStoreDelete(t *testing.T) { + s := NewStateStore() + s.GetOrCreate("node-01") + + s.Delete("node-01") + _, ok := s.Get("node-01") + if ok { + t.Error("node should be deleted") + } + + // Deleting nonexistent node should not panic. + s.Delete("nonexistent") +} + +func TestStateStoreRange(t *testing.T) { + s := NewStateStore() + s.GetOrCreate("node-01") + s.GetOrCreate("node-02") + s.GetOrCreate("node-03") + + visited := make(map[string]bool) + s.Range(func(name string, _ *NodeState) bool { + visited[name] = true + return true + }) + + if len(visited) != 3 { + t.Errorf("Range should visit 3 nodes, visited %d", len(visited)) + } +} + +func TestStateStoreRangeEarlyStop(t *testing.T) { + s := NewStateStore() + s.GetOrCreate("node-01") + s.GetOrCreate("node-02") + s.GetOrCreate("node-03") + + count := 0 + s.Range(func(_ string, _ *NodeState) bool { + count++ + return false + }) + + if count != 1 { + t.Errorf("Range should stop after first call when fn returns false, visited %d", count) + } +} + +func TestStateStoreMarkUnhealthy(t *testing.T) { + s := NewStateStore() + s.GetOrCreate("node-01") + + now := time.Date(2026, 4, 13, 12, 0, 0, 0, time.UTC) + s.MarkUnhealthy("node-01", now) + + st, _ := s.Get("node-01") + if st.Phase() != PhaseUnhealthy { + t.Errorf("got phase %v, want PhaseUnhealthy", st.Phase()) + } + if !st.UnhealthySince().Equal(now) { + t.Errorf("got unhealthySince %v, want %v", st.UnhealthySince(), now) + } +} + +func TestStateStoreMarkUnhealthyNonexistent(t *testing.T) { + s := NewStateStore() + // Should not panic. + s.MarkUnhealthy("nonexistent", time.Now()) +} + +func TestStateStoreMarkWaitingReboot(t *testing.T) { + s := NewStateStore() + s.GetOrCreate("node-01") + + now := time.Date(2026, 4, 13, 12, 0, 0, 0, time.UTC) + s.MarkWaitingReboot("node-01", now) + + st, _ := s.Get("node-01") + if st.Phase() != PhaseWaitingReboot { + t.Errorf("got phase %v, want PhaseWaitingReboot", st.Phase()) + } + if !st.LastRebootTime().Equal(now) { + t.Errorf("got lastRebootTime %v, want %v", st.LastRebootTime(), now) + } + if st.RebootCount() != 1 { + t.Errorf("got rebootCount %d, want 1", st.RebootCount()) + } + + // Retry increments count. + later := now.Add(time.Hour) + s.MarkWaitingReboot("node-01", later) + + st, _ = s.Get("node-01") + if st.RebootCount() != 2 { + t.Errorf("got rebootCount %d after retry, want 2", st.RebootCount()) + } + if !st.LastRebootTime().Equal(later) { + t.Errorf("got lastRebootTime %v after retry, want %v", st.LastRebootTime(), later) + } +} + +func TestStateStoreMarkWaitingRebootNonexistent(t *testing.T) { + s := NewStateStore() + // Should not panic. + s.MarkWaitingReboot("nonexistent", time.Now()) +} + +func TestStateStoreRecordRebootFailure(t *testing.T) { + s := NewStateStore() + s.GetOrCreate("node-01") + + s.RecordRebootFailure("node-01") + st, _ := s.Get("node-01") + if st.FailedRebootCount() != 1 { + t.Errorf("got failedRebootCount %d, want 1", st.FailedRebootCount()) + } + if st.Phase() != PhaseHealthy { + t.Errorf("RecordRebootFailure should not change phase, got %v", st.Phase()) + } + + s.RecordRebootFailure("node-01") + if st.FailedRebootCount() != 2 { + t.Errorf("got failedRebootCount %d after second call, want 2", st.FailedRebootCount()) + } +} + +func TestStateStoreRecordRebootFailureNonexistent(t *testing.T) { + s := NewStateStore() + // Should not panic. + s.RecordRebootFailure("nonexistent") +} + +func TestStateStoreMarkFailed(t *testing.T) { + s := NewStateStore() + s.GetOrCreate("node-01") + s.MarkWaitingReboot("node-01", time.Now()) + + s.MarkFailed("node-01") + + st, _ := s.Get("node-01") + if st.Phase() != PhaseFailed { + t.Errorf("got phase %v, want PhaseFailed", st.Phase()) + } +} + +func TestStateStoreMarkFailedNonexistent(t *testing.T) { + s := NewStateStore() + // Should not panic. + s.MarkFailed("nonexistent") +} + +func TestStateStoreUpdateCheckerInfo(t *testing.T) { + s := NewStateStore() + s.GetOrCreate("node-01") + + checkers := []string{"NodeReady", "GPU"} + s.UpdateCheckerInfo("node-01", checkers, true) + + st, _ := s.Get("node-01") + if !st.IsGPUNode() { + t.Error("expected isGPUNode to be true") + } +} + +func TestStateStoreUpdateCheckerInfoNonexistent(t *testing.T) { + s := NewStateStore() + // Should not panic. + s.UpdateCheckerInfo("nonexistent", []string{"NodeReady"}, false) +} + +func TestStateStoreReset(t *testing.T) { + s := NewStateStore() + s.GetOrCreate("node-01") + + now := time.Date(2026, 4, 13, 12, 0, 0, 0, time.UTC) + s.MarkUnhealthy("node-01", now) + s.UpdateCheckerInfo("node-01", []string{"NodeReady"}, true) + s.MarkWaitingReboot("node-01", now) + s.RecordRebootFailure("node-01") + + s.Reset("node-01") + + st, ok := s.Get("node-01") + if !ok { + t.Fatal("node should still exist after Reset") + } + if st.Phase() != PhaseHealthy { + t.Errorf("got phase %v, want PhaseHealthy", st.Phase()) + } + if st.RebootCount() != 0 { + t.Errorf("got rebootCount %d, want 0", st.RebootCount()) + } + if st.FailedRebootCount() != 0 { + t.Errorf("got failedRebootCount %d, want 0", st.FailedRebootCount()) + } + if !st.UnhealthySince().IsZero() { + t.Error("unhealthySince should be zero after Reset") + } + if !st.LastRebootTime().IsZero() { + t.Error("lastRebootTime should be zero after Reset") + } + if st.IsGPUNode() { + t.Error("isGPUNode should be false after Reset") + } +} + +func TestStateStoreResetNonexistent(t *testing.T) { + s := NewStateStore() + // Should not panic and should not create an entry. + s.Reset("nonexistent") + + _, ok := s.Get("nonexistent") + if ok { + t.Error("Reset on nonexistent node should not create an entry") + } +} + +func TestStateStoreResetReplacesPointer(t *testing.T) { + s := NewStateStore() + old := s.GetOrCreate("node-01") + + s.Reset("node-01") + + current, _ := s.Get("node-01") + if old == current { + t.Error("Reset should replace the map entry with a new pointer") + } +} + +func TestStateStoreCleanup(t *testing.T) { + s := NewStateStore() + s.GetOrCreate("node-01") + s.GetOrCreate("node-02") + s.GetOrCreate("node-03") + + active := map[string]struct{}{ + "node-01": {}, + "node-03": {}, + } + s.Cleanup(active) + + if _, ok := s.Get("node-01"); !ok { + t.Error("node-01 should still exist") + } + if _, ok := s.Get("node-02"); ok { + t.Error("node-02 should be removed") + } + if _, ok := s.Get("node-03"); !ok { + t.Error("node-03 should still exist") + } +} diff --git a/pkg/watcher/watcher.go b/pkg/watcher/watcher.go index 5f28054..c388289 100644 --- a/pkg/watcher/watcher.go +++ b/pkg/watcher/watcher.go @@ -4,25 +4,23 @@ import ( "context" "fmt" "log/slog" - "strconv" - "sync" "time" - "github.com/civo/civogo" - corev1 "k8s.io/api/core/v1" + "github.com/civo/node-agent/pkg/health" + "github.com/civo/node-agent/pkg/metrics" + "github.com/civo/node-agent/pkg/operation" + "github.com/prometheus/client_golang/prometheus" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/labels" + "k8s.io/client-go/informers" "k8s.io/client-go/kubernetes" + listerscorev1 "k8s.io/client-go/listers/core/v1" "k8s.io/client-go/rest" + "k8s.io/client-go/tools/cache" "k8s.io/client-go/tools/clientcmd" ) -// Version is the current version of the this watcher -var Version string = "0.0.1" - -const ( - nodePoolLabelKey = "kubernetes.civo.com/civo-node-pool" - gpuResourceName = "nvidia.com/gpu" -) +const nodePoolLabelKey = "kubernetes.civo.com/civo-node-pool" type Watcher interface { Run(ctx context.Context) error @@ -30,66 +28,50 @@ type Watcher interface { type watcher struct { client kubernetes.Interface - civoClient civogo.Clienter clientCfgPath string - clusterID string - region string - apiKey string - apiURL string - nodeDesiredGPUCount int - rebootTimeWindowMinutes time.Duration + nodePoolIDs []string + rebootWaitMinutes time.Duration // Standard nodes (default: 10) + gpuRebootWaitMinutes time.Duration // GPU nodes (default: 40) + maxRebootRetries int // Give up and transition to PhaseFailed after this many reboots + maxRebootFailures int // Give up and transition to PhaseFailed after this many reboot call failures - // NOTE: This is only effective when running with a single node-agent. If we want to run multiple instances, additional logic modifications will be required. - lastRebootCmdTimes sync.Map + nodeLabelSelector *metav1.LabelSelector + nodeLister listerscorev1.NodeLister - nodeSelector *metav1.LabelSelector + monitorOnly bool + checkers []health.HealthChecker + executor operation.Executor + states *StateStore + nowFunc func() time.Time } -func NewWatcher(ctx context.Context, apiURL, apiKey, region, clusterID, nodePoolID string, opts ...Option) (Watcher, error) { +func NewWatcher(opts ...Option) (Watcher, error) { w := &watcher{ - clusterID: clusterID, - apiKey: apiKey, - apiURL: apiURL, - region: region, + monitorOnly: true, + states: NewStateStore(), + nowFunc: func() time.Time { return time.Now().UTC() }, } for _, opt := range append(defaultOptions, opts...) { opt(w) } - if clusterID == "" { - return nil, fmt.Errorf("CIVO_CLUSTER_ID not set") - } - if nodePoolID == "" { - return nil, fmt.Errorf("CIVO_NODE_POOL_ID not set") - } - if w.civoClient == nil && apiKey == "" { - return nil, fmt.Errorf("CIVO_API_KEY not set") - } - - w.nodeSelector = &metav1.LabelSelector{ - MatchLabels: map[string]string{ - nodePoolLabelKey: nodePoolID, - }, - } + w.nodeLabelSelector = buildNodeSelector(w.nodePoolIDs) if err := w.setupKubernetesClient(); err != nil { return nil, err } - if err := w.setupCivoClient(); err != nil { - return nil, err - } return w, nil } // setupKubernetesClient creates Kubernetes client based on the kubeconfig path. // If kubeconfig path is not empty, the client will be created using that path. -// Otherwise, if the kubeconfig path is empty, the client will be created using the in-clustetr config. -func (w *watcher) setupKubernetesClient() (err error) { +// Otherwise, if the kubeconfig path is empty, the client will be created using the in-cluster config. +func (w *watcher) setupKubernetesClient() error { if w.clientCfgPath != "" && w.client == nil { cfg, err := clientcmd.BuildConfigFromFlags("", w.clientCfgPath) if err != nil { - return fmt.Errorf("failed to build kubeconfig from path %q: %w", cfg, err) + return fmt.Errorf("failed to build kubeconfig from path %q: %w", w.clientCfgPath, err) } w.client, err = kubernetes.NewForConfig(cfg) if err != nil { @@ -111,35 +93,48 @@ func (w *watcher) setupKubernetesClient() (err error) { return nil } -func (w *watcher) setupCivoClient() error { - if w.civoClient != nil { +func (w *watcher) setupInformer(ctx context.Context) error { + if w.nodeLister != nil { return nil } - client, err := civogo.NewClientWithURL(w.apiKey, w.apiURL, w.region) - if err != nil { - return fmt.Errorf("failed to initialise civo client: %w", err) + var informerOpts []informers.SharedInformerOption + if w.nodeLabelSelector != nil { + labelSelector := metav1.FormatLabelSelector(w.nodeLabelSelector) + slog.Info("Using node label selector", "selector", labelSelector) + informerOpts = append(informerOpts, informers.WithTweakListOptions(func(opts *metav1.ListOptions) { + opts.LabelSelector = labelSelector + })) + } else { + slog.Info("No node label selector configured, watching all nodes") } + factory := informers.NewSharedInformerFactoryWithOptions(w.client, 0, informerOpts...) + + nodeInformer := factory.Core().V1().Nodes() + w.nodeLister = nodeInformer.Lister() + + factory.Start(ctx.Done()) - userAgent := &civogo.Component{ - ID: w.clusterID, - Name: "node-agent", - Version: Version, + if !cache.WaitForCacheSync(ctx.Done(), nodeInformer.Informer().HasSynced) { + return fmt.Errorf("failed to sync node informer cache") } - client.SetUserAgent(userAgent) - w.civoClient = client + slog.Info("Node informer cache synced") return nil } func (w *watcher) Run(ctx context.Context) error { + if err := w.setupInformer(ctx); err != nil { + return err + } + ticker := time.NewTicker(10 * time.Second) defer ticker.Stop() + slog.Info("Watcher reconcile loop started") for { select { case <-ticker.C: - slog.Info("Started the watcher process...") if err := w.run(ctx); err != nil { slog.Error("An error occurred while running the watcher process", "error", err) } @@ -150,139 +145,233 @@ func (w *watcher) Run(ctx context.Context) error { } func (w *watcher) run(ctx context.Context) error { - nodes, err := w.client.CoreV1().Nodes().List(ctx, metav1.ListOptions{ - LabelSelector: metav1.FormatLabelSelector(w.nodeSelector), - }) + nodes, err := w.nodeLister.List(labels.Everything()) if err != nil { + metrics.ReconcileErrorsTotal.WithLabelValues("list_nodes").Inc() return err } - thresholdTime := time.Now().Add(-w.rebootTimeWindowMinutes * time.Minute) - - for _, node := range nodes.Items { - if !isNodeDesiredGPU(&node, w.nodeDesiredGPUCount) || !isNodeReady(&node) { - - // LTT: LastTransitionTime of node. - // LRCT: LastRebootCmdTimes - // 60: Threshold time (example) - // - LTT > 60 , LRCT < 60 dont reboot - // - LTT < 60 , LRCT < 60 dont reboot - // - LTT < 60 , LRCT > 60 dont reboot - // - LTT > 60, LRCT >. 60 reboot - slog.Info("Node is not ready, attempting to reboot", "node", node.GetName()) - if isReadyOrNotReadyStatusChangedAfter(&node, thresholdTime) { - slog.Info("Skipping reboot because Ready/NotReady status was updated recently", "node", node.GetName()) - continue - } - if w.isLastRebootCommandTimeAfter(node.GetName(), thresholdTime) { - slog.Info("Skipping reboot because Reboot command was executed recently", "node", node.GetName()) - continue - } - if err := w.rebootNode(node.GetName()); err != nil { - slog.Error("Failed to reboot Node", "node", node.GetName(), "error", err) - return fmt.Errorf("failed to reboot node: %w", err) + now := w.nowFunc() + activeNodes := make(map[string]struct{}, len(nodes)) + + for _, node := range nodes { + nodeName := node.GetName() + activeNodes[nodeName] = struct{}{} + + // Run all health checkers and collect failures. + var failedCheckers []string + var minThreshold time.Duration + for _, checker := range w.checkers { + healthy, _ := checker.Check(node) + result := "pass" + if !healthy { + result = "fail" + failedCheckers = append(failedCheckers, checker.Name()) + if minThreshold == 0 || checker.Threshold() < minThreshold { + minThreshold = checker.Threshold() + } } + metrics.HealthCheckTotal.WithLabelValues(nodeName, checker.Name(), result).Inc() } - } - return nil -} -func isReadyOrNotReadyStatusChangedAfter(node *corev1.Node, thresholdTime time.Time) bool { - var lastChangedTime time.Time - for _, cond := range node.Status.Conditions { - if cond.Type == corev1.NodeReady { - if cond.LastTransitionTime.After(lastChangedTime) { - lastChangedTime = cond.LastTransitionTime.Time + state := w.states.GetOrCreate(nodeName) + + // All checkers pass → node is healthy. + if len(failedCheckers) == 0 { + metrics.RecoveryPhase.WithLabelValues(nodeName, PhaseHealthy.String()).Set(1) + if prevPhase := state.Phase(); prevPhase != PhaseHealthy { + slog.Info("Node recovered", + "node", nodeName, + "previousPhase", prevPhase.String()) + metrics.NodeUnhealthyDurationSeconds.WithLabelValues(nodeName).Set(0) + metrics.RecoveryPhase.WithLabelValues(nodeName, prevPhase.String()).Set(0) + w.states.Reset(nodeName) } + continue } - } - slog.Info("Checking if Ready/NotReady status has changed recently", - "node", node.GetName(), - "lastTransitionTime", lastChangedTime.String(), - "thresholdTime", thresholdTime.String()) - - if lastChangedTime.IsZero() { - slog.Error("Node is in an invalid state, NodeReady condition not found", "node", node.GetName()) - return false - } - return lastChangedTime.After(thresholdTime) -} - -// isLastRebootCommandTimeAfter checks if the last reboot command time for the specified node -// is after the given threshold time. In case of delays in reboot, the -// LastTransitionTime of node might not be updated, so it compares the latest reboot -// command time to prevent sending reboot commands multiple times. -// NOTE: This is only effective when running with a single node-agent. If we want to run multiple instances, additional logic modifications will be required. -func (w *watcher) isLastRebootCommandTimeAfter(nodeName string, thresholdTime time.Time) bool { - v, ok := w.lastRebootCmdTimes.Load(nodeName) - if !ok { - slog.Info("LastRebootCommandTime not found", "node", nodeName) - return false - } - lastRebootCmdTime, ok := v.(time.Time) - if !ok { - slog.Info("LastRebootCommandTime is invalid, so it will be removed from the records", "node", nodeName, "value", v) - w.lastRebootCmdTimes.Delete(nodeName) - return false - } - - slog.Info("Checking if LastRebootCommandTime has changed recently", - "node", nodeName, - "lastRebootCommandTime", lastRebootCmdTime.String(), - "thresholdTime", thresholdTime.String()) + // At least one checker failed — enter the recovery judgment phase. + // The state machine decides the next action (wait, reboot, retry) + // regardless of which specific checker(s) failed. + isGPUNode := health.HasGPU(node) + w.states.UpdateCheckerInfo(nodeName, failedCheckers, isGPUNode) + + switch state.Phase() { + // Healthy → Unhealthy: health check failed for the first time, start tracking. + case PhaseHealthy: + w.states.MarkUnhealthy(nodeName, now) + slog.Info("Node unhealthy detected", + "node", nodeName, + "failedCheckers", failedCheckers) + metrics.NodeUnhealthyDurationSeconds.WithLabelValues(nodeName).Set(0) + metrics.RecoveryPhase.WithLabelValues(nodeName, PhaseHealthy.String()).Set(0) + metrics.RecoveryPhase.WithLabelValues(nodeName, PhaseUnhealthy.String()).Set(1) + + // Unhealthy → WaitingReboot: health check still failing and threshold exceeded, issue reboot. + case PhaseUnhealthy: + metrics.NodeUnhealthyDurationSeconds.WithLabelValues(nodeName).Set( + now.Sub(state.UnhealthySince()).Seconds()) + if now.Sub(state.UnhealthySince()) < minThreshold { + slog.Info("Waiting for unhealthy threshold", + "node", nodeName, + "elapsed", now.Sub(state.UnhealthySince()).String(), + "threshold", minThreshold.String(), + "failedCheckers", failedCheckers) + continue + } + // Reboot call failure budget exhausted before the first successful + // reboot → give up without attempting another reboot. + if state.FailedRebootCount() >= w.maxRebootFailures { + slog.Warn("Reboot call failure limit exceeded, giving up", + "node", nodeName, + "failedRebootCount", state.FailedRebootCount(), + "maxRebootFailures", w.maxRebootFailures, + "failedCheckers", failedCheckers) + metrics.RecoveryPhase.WithLabelValues(nodeName, PhaseUnhealthy.String()).Set(0) + metrics.RecoveryPhase.WithLabelValues(nodeName, PhaseFailed.String()).Set(1) + w.states.MarkFailed(nodeName) + continue + } + if !w.monitorOnly { + if err := w.executor.Reboot(ctx, nodeName); err != nil { + slog.Error("Failed to reboot node", "node", nodeName, "error", err) + metrics.RecoveryFailuresTotal.WithLabelValues(nodeName, "reboot").Inc() + w.states.RecordRebootFailure(nodeName) + continue + } + } + mode := modeLabel(w.monitorOnly) + slog.Info("Reboot initiated", + "node", nodeName, + "mode", mode, + "failedCheckers", failedCheckers) + metrics.RecoveryActionsTotal.WithLabelValues(nodeName, "reboot", mode).Inc() + metrics.RecoveryPhase.WithLabelValues(nodeName, PhaseUnhealthy.String()).Set(0) + metrics.RecoveryPhase.WithLabelValues(nodeName, PhaseWaitingReboot.String()).Set(1) + w.states.MarkWaitingReboot(nodeName, now) + + // WaitingReboot: health check still failing after reboot, retry after wait window. + case PhaseWaitingReboot: + metrics.NodeUnhealthyDurationSeconds.WithLabelValues(nodeName).Set( + now.Sub(state.UnhealthySince()).Seconds()) + rebootWait := w.rebootWaitMinutes + if state.IsGPUNode() { + rebootWait = w.gpuRebootWaitMinutes + } + if now.Sub(state.LastRebootTime()) < rebootWait { + // In monitor-only mode no reboot actually happened, so logging + // "waiting for reboot effect" every tick would be noisy. + // The "Reboot retry" log still fires once per rebootWait cycle as a liveness signal. + if !w.monitorOnly { + slog.Info("Waiting for reboot effect", + "node", nodeName, + "elapsed", now.Sub(state.LastRebootTime()).String(), + "rebootWait", rebootWait.String(), + "rebootCount", state.RebootCount(), + "isGPUNode", state.IsGPUNode()) + } + continue + } - return lastRebootCmdTime.After(thresholdTime) -} + // Retry budget exhausted → give up and transition to PhaseFailed. + // The node stays in Failed until it naturally recovers (all checkers pass). + // TODO: Standard nodes could transition to PhaseDrain → PhaseReplace here + // once that flow is wired up. GPU nodes must stay in Failed (never replaced). + if state.RebootCount() >= w.maxRebootRetries { + slog.Warn("Reboot retry limit exceeded, giving up", + "node", nodeName, + "rebootCount", state.RebootCount(), + "maxRebootRetries", w.maxRebootRetries, + "isGPUNode", state.IsGPUNode(), + "failedCheckers", failedCheckers) + metrics.RecoveryPhase.WithLabelValues(nodeName, PhaseWaitingReboot.String()).Set(0) + metrics.RecoveryPhase.WithLabelValues(nodeName, PhaseFailed.String()).Set(1) + w.states.MarkFailed(nodeName) + continue + } + // Reboot call failure budget exhausted → give up to cap Civo API load. + if state.FailedRebootCount() >= w.maxRebootFailures { + slog.Warn("Reboot call failure limit exceeded, giving up", + "node", nodeName, + "failedRebootCount", state.FailedRebootCount(), + "maxRebootFailures", w.maxRebootFailures, + "failedCheckers", failedCheckers) + metrics.RecoveryPhase.WithLabelValues(nodeName, PhaseWaitingReboot.String()).Set(0) + metrics.RecoveryPhase.WithLabelValues(nodeName, PhaseFailed.String()).Set(1) + w.states.MarkFailed(nodeName) + continue + } -func isNodeReady(node *corev1.Node) bool { - for _, cond := range node.Status.Conditions { - if cond.Type == corev1.NodeReady { - slog.Info("Current Node status", "node", node.GetName(), "type", corev1.NodeReady, "status", cond.Status) - return cond.Status == corev1.ConditionTrue + if !w.monitorOnly { + if err := w.executor.Reboot(ctx, nodeName); err != nil { + slog.Error("Failed to reboot node (retry)", "node", nodeName, "error", err) + metrics.RecoveryFailuresTotal.WithLabelValues(nodeName, "reboot").Inc() + w.states.RecordRebootFailure(nodeName) + continue + } + } + w.states.MarkWaitingReboot(nodeName, now) + mode := modeLabel(w.monitorOnly) + slog.Info("Reboot retry", + "node", nodeName, + "mode", mode, + "rebootCount", state.RebootCount(), + "failedCheckers", failedCheckers) + metrics.RecoveryActionsTotal.WithLabelValues(nodeName, "reboot", mode).Inc() + + // Failed: recovery attempts exhausted. Wait for natural recovery (all checkers pass). + // If the node recovers the "all checkers pass" branch above will Reset it back to Healthy. + case PhaseFailed: + metrics.NodeUnhealthyDurationSeconds.WithLabelValues(nodeName).Set( + now.Sub(state.UnhealthySince()).Seconds()) } } - slog.Info("NodeReady condition not found", "node", node.GetName()) - return false -} -func isNodeDesiredGPU(node *corev1.Node, desired int) bool { - if desired == 0 { - slog.Info("Desired GPU count is set to 0, so the GPU count check is skipped", "node", node.GetName()) + // Clean up state and metrics for nodes no longer in the cluster. + w.states.Range(func(name string, _ *NodeState) bool { + if _, ok := activeNodes[name]; !ok { + metrics.NodeUnhealthyDurationSeconds.DeleteLabelValues(name) + metrics.HealthCheckTotal.DeletePartialMatch(prometheus.Labels{"node": name}) + metrics.RecoveryActionsTotal.DeletePartialMatch(prometheus.Labels{"node": name}) + metrics.RecoveryFailuresTotal.DeletePartialMatch(prometheus.Labels{"node": name}) + metrics.RecoveryPhase.DeletePartialMatch(prometheus.Labels{"node": name}) + } return true - } - - quantity, exists := node.Status.Allocatable[gpuResourceName] - if !exists || quantity.IsZero() { - slog.Info("Allocatable GPU not found", "node", node.GetName()) - return false - } - - gpuCount, ok := quantity.AsInt64() - if !ok { - slog.Info("Failed to convert allocatable GPU quantity to int64", "node", node.GetName(), "quantity", quantity.String()) - return false - } - - slog.Info("Checking actual GPU count with desired", - "node", node.GetName(), - "actual", gpuCount, - "desired", strconv.Itoa(desired)) - - return gpuCount == int64(desired) + }) + w.states.Cleanup(activeNodes) + return nil } -func (w *watcher) rebootNode(name string) error { - instance, err := w.civoClient.FindKubernetesClusterInstance(w.clusterID, name) - if err != nil { - return fmt.Errorf("failed to find instance, clusterID: %s, nodeName: %s: %w", w.clusterID, name, err) +func modeLabel(monitorOnly bool) string { + if monitorOnly { + return "monitor" } + return "active" +} - _, err = w.civoClient.HardRebootInstance(instance.ID) - if err != nil { - return fmt.Errorf("failed to reboot instance, clusterID: %s, instanceID: %s: %w", w.clusterID, instance.ID, err) +// buildNodeSelector builds a LabelSelector based on the given node pool IDs. +// - empty: no selector (all nodes) +// - single: MatchLabels exact match +// - multiple: MatchExpressions In operator +func buildNodeSelector(nodePoolIDs []string) *metav1.LabelSelector { + switch len(nodePoolIDs) { + case 0: + return nil + case 1: + return &metav1.LabelSelector{ + MatchLabels: map[string]string{ + nodePoolLabelKey: nodePoolIDs[0], + }, + } + default: + return &metav1.LabelSelector{ + MatchExpressions: []metav1.LabelSelectorRequirement{ + { + Key: nodePoolLabelKey, + Operator: metav1.LabelSelectorOpIn, + Values: nodePoolIDs, + }, + }, + } } - slog.Info("Instance is rebooting", "instanceID", instance.ID, "node", name) - w.lastRebootCmdTimes.Store(name, time.Now()) - return nil } diff --git a/pkg/watcher/watcher_test.go b/pkg/watcher/watcher_test.go index c69d17c..3367e29 100644 --- a/pkg/watcher/watcher_test.go +++ b/pkg/watcher/watcher_test.go @@ -1,171 +1,180 @@ package watcher import ( - "errors" + "context" "fmt" "strconv" "testing" "time" - "github.com/civo/civogo" + "github.com/civo/node-agent/pkg/health" corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/resource" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/labels" "k8s.io/client-go/kubernetes/fake" - k8stesting "k8s.io/client-go/testing" ) +// --- Test helpers --- + +// fakeNodeLister implements listerscorev1.NodeLister for testing. +type fakeNodeLister struct { + nodes []*corev1.Node + err error +} + +func (l *fakeNodeLister) List(selector labels.Selector) ([]*corev1.Node, error) { + if l.err != nil { + return nil, l.err + } + return l.nodes, nil +} + +func (l *fakeNodeLister) Get(name string) (*corev1.Node, error) { + for _, n := range l.nodes { + if n.Name == name { + return n, nil + } + } + return nil, fmt.Errorf("node %q not found", name) +} + +// mockExecutor implements operation.Executor for testing. +type mockExecutor struct { + rebootFunc func(ctx context.Context, nodeName string) error + calls []string +} + +func (m *mockExecutor) Reboot(ctx context.Context, nodeName string) error { + m.calls = append(m.calls, nodeName) + if m.rebootFunc != nil { + return m.rebootFunc(ctx, nodeName) + } + return nil +} + +// alwaysFailChecker is a HealthChecker that always reports unhealthy. +type alwaysFailChecker struct { + name string + threshold time.Duration +} + +func (c *alwaysFailChecker) Name() string { return c.name } +func (c *alwaysFailChecker) Check(*corev1.Node) (bool, string) { return false, "always fail" } +func (c *alwaysFailChecker) Threshold() time.Duration { return c.threshold } + +// --- Test variables --- + var ( - testClusterID = "test-cluster-123" - testRegion = "lon1" - testApiKey = "test-api-key" - testApiURL = "https://test.civo.com" - testNodePoolID = "test-node-pool" - testNodeDesiredGPUCount = "8" - testRebootTimeWindowMinutes = time.Duration(40) + testNodePoolID = "test-node-pool" + testRebootWaitMinutes = time.Duration(10) * time.Minute ) +// newTestNode creates a node for testing with common defaults. +func newTestNode(name string, ready corev1.ConditionStatus, gpuCount int) *corev1.Node { + node := &corev1.Node{ + ObjectMeta: metav1.ObjectMeta{ + Name: name, + Labels: map[string]string{ + nodePoolLabelKey: testNodePoolID, + }, + }, + Status: corev1.NodeStatus{ + Conditions: []corev1.NodeCondition{ + {Type: corev1.NodeReady, Status: ready}, + }, + }, + } + if gpuCount > 0 { + node.Labels["nvidia.com/gpu.count"] = strconv.Itoa(gpuCount) + node.Status.Allocatable = corev1.ResourceList{ + "nvidia.com/gpu": resource.MustParse(strconv.Itoa(gpuCount)), + } + } + return node +} + +// newTestWatcher creates a watcher with sensible test defaults and the given options. +func newTestWatcher(t *testing.T, opts ...Option) *watcher { + t.Helper() + baseOpts := []Option{ + WithKubernetesClient(fake.NewSimpleClientset()), + WithExecutor(&mockExecutor{}), + } + w, err := NewWatcher(append(baseOpts, opts...)...) + if err != nil { + t.Fatal(err) + } + return w.(*watcher) +} + +// --- TestNew --- + func TestNew(t *testing.T) { type args struct { - clusterID string - region string - apiKey string - apiURL string - nodePoolID string - opts []Option + opts []Option } type test struct { - name string - args args - checkFunc func(*watcher) error - wantErr bool + description string + args args + checkFunc func(*watcher) error + wantErr bool } tests := []test{ { - name: "Returns no error when given valid input", + description: "returns no error when given valid input", args: args{ - clusterID: testClusterID, - region: testRegion, - apiKey: testApiKey, - apiURL: testApiURL, - nodePoolID: testNodePoolID, opts: []Option{ WithKubernetesClient(fake.NewSimpleClientset()), - WithCivoClient(&FakeClient{}), - WithDesiredGPUCount(testNodeDesiredGPUCount), + WithExecutor(&mockExecutor{}), + WithNodePoolIDs(testNodePoolID), }, }, checkFunc: func(w *watcher) error { - if w.clusterID != testClusterID { - return fmt.Errorf("clusterID mismatch: got %s, want %s", w.clusterID, testClusterID) - } - if w.region != testRegion { - return fmt.Errorf("region mismatch: got %s, want %s", w.region, testRegion) - } - if w.apiKey != testApiKey { - return fmt.Errorf("apiKey mismatch: got %s, want %s", w.apiKey, testApiKey) - } - if w.apiURL != testApiURL { - return fmt.Errorf("apiURL mismatch: got %s, want %s", w.apiURL, testApiURL) - } - - cnt, err := strconv.Atoi(testNodeDesiredGPUCount) - if err != nil { - return err - } - if w.nodeDesiredGPUCount != cnt { - return fmt.Errorf("nodeDesiredGPUCount mismatch: got %d, want %s", w.nodeDesiredGPUCount, testNodeDesiredGPUCount) - } - if w.nodeSelector == nil || w.nodeSelector.MatchLabels[nodePoolLabelKey] != testNodePoolID { - return fmt.Errorf("nodeSelector mismatch: got %v, want %s", w.nodeSelector, testNodePoolID) + if w.nodeLabelSelector == nil || w.nodeLabelSelector.MatchLabels[nodePoolLabelKey] != testNodePoolID { + return fmt.Errorf("nodeLabelSelector mismatch: got %v, want %s", w.nodeLabelSelector, testNodePoolID) } if w.client == nil { return fmt.Errorf("client is nil") } - if w.civoClient == nil { - return fmt.Errorf("civoClient is nil") + if w.rebootWaitMinutes != testRebootWaitMinutes { + return fmt.Errorf("rebootTimeWindowMinutes mismatch: got %v, want %v", w.rebootWaitMinutes, testRebootWaitMinutes) } - if w.rebootTimeWindowMinutes != testRebootTimeWindowMinutes { - return fmt.Errorf("w.rebootTimeWindowMinutes mismatch: got %v, want %s", w.nodeSelector, testNodePoolID) + if !w.monitorOnly { + return fmt.Errorf("monitorOnly should default to true") } - return nil - }, - }, - { - name: "Returns no error when input is invalid, but default value is set", - args: args{ - clusterID: testClusterID, - region: testRegion, - apiKey: testApiKey, - apiURL: testApiURL, - nodePoolID: testNodePoolID, - opts: []Option{ - WithKubernetesClient(fake.NewSimpleClientset()), - WithCivoClient(&FakeClient{}), - WithDesiredGPUCount("invalid"), // It is invalid, but the default count (0) will be used. - WithDesiredGPUCount("-1"), // It is invalid, but the default count (0) will be used. - WithRebootTimeWindowMinutes("invalid time"), // It is invalid, but the default time (40) will be used. - WithRebootTimeWindowMinutes("0"), // It is invalid, but the default time (40) will be used. - }, - }, - checkFunc: func(w *watcher) error { - if w.nodeDesiredGPUCount != 0 { - return fmt.Errorf("w.nodeDesiredGPUCount mismatch: got %d, want %d", w.nodeDesiredGPUCount, 0) + if w.states == nil { + return fmt.Errorf("states is nil") } - if w.rebootTimeWindowMinutes != testRebootTimeWindowMinutes { - return fmt.Errorf("w.rebootTimeWindowMinutes mismatch: got %v, want %s", w.nodeSelector, testNodePoolID) + if w.nowFunc == nil { + return fmt.Errorf("nowFunc is nil") } return nil }, }, { - name: "Returns no error when nodeDesiredGPUCount is 0", + description: "returns no error when input is invalid, but default value is set", args: args{ - clusterID: testClusterID, - region: testRegion, - apiKey: testApiKey, - apiURL: testApiURL, - nodePoolID: testNodePoolID, opts: []Option{ WithKubernetesClient(fake.NewSimpleClientset()), - WithCivoClient(&FakeClient{}), - WithDesiredGPUCount("0"), + WithExecutor(&mockExecutor{}), + WithRebootWaitMinutes("invalid time"), + WithRebootWaitMinutes("0"), }, }, checkFunc: func(w *watcher) error { - if w.nodeDesiredGPUCount != 0 { - return fmt.Errorf("w.nodeDesiredGPUCount mismatch: got %d, want %d", w.nodeDesiredGPUCount, 0) + if w.rebootWaitMinutes != testRebootWaitMinutes { + return fmt.Errorf("rebootTimeWindowMinutes mismatch: got %v, want %v", w.rebootWaitMinutes, testRebootWaitMinutes) } return nil }, }, - { - name: "Returns an error when clusterID is missing", - args: args{ - region: testRegion, - apiKey: testApiKey, - apiURL: testApiURL, - nodePoolID: testNodePoolID, - opts: []Option{ - WithKubernetesClient(fake.NewSimpleClientset()), - WithCivoClient(&FakeClient{}), - }, - }, - wantErr: true, - }, } for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - w, err := NewWatcher(t.Context(), - test.args.apiURL, - test.args.apiKey, - test.args.region, - test.args.clusterID, - test.args.nodePoolID, - test.args.opts...) + t.Run(test.description, func(t *testing.T) { + w, err := NewWatcher(test.args.opts...) if (err != nil) != test.wantErr { t.Errorf("error = %v, wantErr %v", err, test.wantErr) } @@ -186,803 +195,585 @@ func TestNew(t *testing.T) { } } -func TestRun(t *testing.T) { - type args struct { - opts []Option - nodePoolID string +// --- State machine transition tests --- + +func TestRun_HealthyNodeStaysHealthy(t *testing.T) { + node := newTestNode("node-01", corev1.ConditionTrue, 8) + w := newTestWatcher(t, + withNodeLister(&fakeNodeLister{nodes: []*corev1.Node{node}}), + WithCheckers(health.NewDefaultCheckers()), + ) + + if err := w.run(t.Context()); err != nil { + t.Fatal(err) } - type test struct { - name string - args args - beforeFunc func(*watcher) - wantErr bool + + state, ok := w.states.Get("node-01") + if !ok { + t.Fatal("state should exist for node-01") + } + if state.Phase() != PhaseHealthy { + t.Errorf("got phase %v, want PhaseHealthy", state.Phase()) } +} - tests := []test{ - { - name: "Returns nil when node GPU count is 8 and no reboot needed", - args: args{ - opts: []Option{ - WithKubernetesClient(fake.NewSimpleClientset()), - WithCivoClient(&FakeClient{}), - WithDesiredGPUCount(testNodeDesiredGPUCount), - }, - nodePoolID: testNodePoolID, - }, - beforeFunc: func(w *watcher) { - t.Helper() - client := w.client.(*fake.Clientset) - - nodes := &corev1.NodeList{ - Items: []corev1.Node{ - { - ObjectMeta: metav1.ObjectMeta{ - Name: "node-01", - Labels: map[string]string{ - nodePoolLabelKey: testNodePoolID, - }, - }, - Status: corev1.NodeStatus{ - Conditions: []corev1.NodeCondition{ - { - Type: corev1.NodeReady, - Status: corev1.ConditionTrue, - }, - { - Type: corev1.NodeReady, - Status: corev1.ConditionFalse, - }, - }, - Allocatable: corev1.ResourceList{ - gpuResourceName: resource.MustParse("8"), - }, - }, - }, - }, - } - client.Fake.PrependReactor("list", "nodes", func(action k8stesting.Action) (handled bool, ret runtime.Object, err error) { - return true, nodes, nil - }) - }, - }, - { - name: "Returns nil and triggers reboot when GPU count drops below desired (7 GPUs available)", - args: args{ - opts: []Option{ - WithKubernetesClient(fake.NewSimpleClientset()), - WithCivoClient(&FakeClient{}), - WithDesiredGPUCount(testNodeDesiredGPUCount), - }, - nodePoolID: testNodePoolID, - }, - beforeFunc: func(w *watcher) { - t.Helper() - client := w.client.(*fake.Clientset) - - nodes := &corev1.NodeList{ - Items: []corev1.Node{ - { - ObjectMeta: metav1.ObjectMeta{ - Name: "node-01", - Labels: map[string]string{ - nodePoolLabelKey: testNodePoolID, - }, - }, - Status: corev1.NodeStatus{ - Conditions: []corev1.NodeCondition{ - { - Type: corev1.NodeReady, - Status: corev1.ConditionTrue, - }, - { - Type: corev1.NodeReady, - Status: corev1.ConditionFalse, - }, - }, - Allocatable: corev1.ResourceList{ - gpuResourceName: resource.MustParse("7"), - }, - }, - }, - }, - } - client.Fake.PrependReactor("list", "nodes", func(action k8stesting.Action) (handled bool, ret runtime.Object, err error) { - return true, nodes, nil - }) +func TestRun_UnhealthyDetection(t *testing.T) { + now := time.Date(2026, 4, 13, 12, 0, 0, 0, time.UTC) + node := newTestNode("node-01", corev1.ConditionFalse, 8) + w := newTestWatcher(t, + withNodeLister(&fakeNodeLister{nodes: []*corev1.Node{node}}), + WithCheckers(health.NewDefaultCheckers()), + withNowFunc(func() time.Time { return now }), + ) + + if err := w.run(t.Context()); err != nil { + t.Fatal(err) + } - civoClient := w.civoClient.(*FakeClient) - instance := &civogo.Instance{ - ID: "instance-01", - } - civoClient.FindKubernetesClusterInstanceFunc = func(clusterID, search string) (*civogo.Instance, error) { - return instance, nil - } - civoClient.HardRebootInstanceFunc = func(id string) (*civogo.SimpleResponse, error) { - return new(civogo.SimpleResponse), nil - } - }, - }, - { - name: "Returns nil and triggers reboot when GPU count matches desired but node is not ready", - args: args{ - opts: []Option{ - WithKubernetesClient(fake.NewSimpleClientset()), - WithCivoClient(&FakeClient{}), - WithDesiredGPUCount(testNodeDesiredGPUCount), - }, - nodePoolID: testNodePoolID, - }, - beforeFunc: func(w *watcher) { - t.Helper() - client := w.client.(*fake.Clientset) - - nodes := &corev1.NodeList{ - Items: []corev1.Node{ - { - ObjectMeta: metav1.ObjectMeta{ - Name: "node-01", - Labels: map[string]string{ - nodePoolLabelKey: testNodePoolID, - }, - }, - Status: corev1.NodeStatus{ - Conditions: []corev1.NodeCondition{ - { - Type: corev1.NodeReady, - Status: corev1.ConditionFalse, - }, - }, - Allocatable: corev1.ResourceList{ - gpuResourceName: resource.MustParse("8"), - }, - }, - }, - }, - } - client.Fake.PrependReactor("list", "nodes", func(action k8stesting.Action) (handled bool, ret runtime.Object, err error) { - return true, nodes, nil - }) + state, _ := w.states.Get("node-01") + if state.Phase() != PhaseUnhealthy { + t.Errorf("got phase %v, want PhaseUnhealthy", state.Phase()) + } + if !state.UnhealthySince().Equal(now) { + t.Errorf("got unhealthySince %v, want %v", state.UnhealthySince(), now) + } +} - civoClient := w.civoClient.(*FakeClient) - instance := &civogo.Instance{ - ID: "instance-01", - } - civoClient.FindKubernetesClusterInstanceFunc = func(clusterID, search string) (*civogo.Instance, error) { - return instance, nil - } - civoClient.HardRebootInstanceFunc = func(id string) (*civogo.SimpleResponse, error) { - return new(civogo.SimpleResponse), nil - } - }, - }, - { - name: "Returns nil and skips reboot when GPU count matches desired but node is not ready, and LastTransitionTime is more recent than thresholdTime", - args: args{ - opts: []Option{ - WithKubernetesClient(fake.NewSimpleClientset()), - WithCivoClient(&FakeClient{}), - WithDesiredGPUCount(testNodeDesiredGPUCount), - }, - nodePoolID: testNodePoolID, - }, - beforeFunc: func(w *watcher) { - t.Helper() - client := w.client.(*fake.Clientset) - - w.lastRebootCmdTimes.Store("node-01", time.Now()) - - nodes := &corev1.NodeList{ - Items: []corev1.Node{ - { - ObjectMeta: metav1.ObjectMeta{ - Name: "node-01", - Labels: map[string]string{ - nodePoolLabelKey: testNodePoolID, - }, - }, - Status: corev1.NodeStatus{ - Conditions: []corev1.NodeCondition{ - { - Type: corev1.NodeReady, - Status: corev1.ConditionFalse, - }, - }, - Allocatable: corev1.ResourceList{ - gpuResourceName: resource.MustParse("8"), - }, - }, - }, - }, - } - client.Fake.PrependReactor("list", "nodes", func(action k8stesting.Action) (handled bool, ret runtime.Object, err error) { - return true, nodes, nil - }) - }, - }, - { - name: "Returns nil and skips reboot when GPU count matches desired but node is not ready, and LastRebootCmdTime is more recent than thresholdTime", - args: args{ - opts: []Option{ - WithKubernetesClient(fake.NewSimpleClientset()), - WithCivoClient(&FakeClient{}), - WithDesiredGPUCount(testNodeDesiredGPUCount), - }, - nodePoolID: testNodePoolID, - }, - beforeFunc: func(w *watcher) { - t.Helper() - client := w.client.(*fake.Clientset) - - nodes := &corev1.NodeList{ - Items: []corev1.Node{ - { - ObjectMeta: metav1.ObjectMeta{ - Name: "node-01", - Labels: map[string]string{ - nodePoolLabelKey: testNodePoolID, - }, - }, - Status: corev1.NodeStatus{ - Conditions: []corev1.NodeCondition{ - { - Type: corev1.NodeReady, - Status: corev1.ConditionFalse, - LastTransitionTime: metav1.NewTime(time.Now()), - }, - }, - Allocatable: corev1.ResourceList{ - gpuResourceName: resource.MustParse("8"), - }, - }, - }, - }, - } - client.Fake.PrependReactor("list", "nodes", func(action k8stesting.Action) (handled bool, ret runtime.Object, err error) { - return true, nodes, nil - }) - }, - }, - { - name: "Returns an error when unable to list nodes", - args: args{ - opts: []Option{ - WithKubernetesClient(fake.NewSimpleClientset()), - WithCivoClient(&FakeClient{}), - WithDesiredGPUCount(testNodeDesiredGPUCount), - }, - nodePoolID: testNodePoolID, - }, - beforeFunc: func(w *watcher) { - t.Helper() - client := w.client.(*fake.Clientset) +func TestRun_RebootTriggerActiveMode(t *testing.T) { + now := time.Date(2026, 4, 13, 12, 0, 0, 0, time.UTC) + node := newTestNode("node-01", corev1.ConditionFalse, 8) + exec := &mockExecutor{} + w := newTestWatcher(t, + withNodeLister(&fakeNodeLister{nodes: []*corev1.Node{node}}), + WithCheckers(health.NewDefaultCheckers()), + WithExecutor(exec), + WithMonitorOnly("false"), + withNowFunc(func() time.Time { return now }), + ) + + // First run: detect unhealthy. + if err := w.run(t.Context()); err != nil { + t.Fatal(err) + } - client.Fake.PrependReactor("list", "nodes", func(action k8stesting.Action) (handled bool, ret runtime.Object, err error) { - return true, &corev1.NodeList{}, errors.New("invalid error") - }) - }, - wantErr: true, - }, + // Advance past threshold. + now = now.Add(11 * time.Minute) - { - name: "Returns an error when finding the Kubernetes cluster instance fails during reboot", - args: args{ - opts: []Option{ - WithKubernetesClient(fake.NewSimpleClientset()), - WithCivoClient(&FakeClient{}), - WithDesiredGPUCount(testNodeDesiredGPUCount), - }, - nodePoolID: testNodePoolID, - }, - beforeFunc: func(w *watcher) { - t.Helper() - client := w.client.(*fake.Clientset) - - nodes := &corev1.NodeList{ - Items: []corev1.Node{ - { - ObjectMeta: metav1.ObjectMeta{ - Name: "node-01", - Labels: map[string]string{ - nodePoolLabelKey: testNodePoolID, - }, - }, - Status: corev1.NodeStatus{ - Conditions: []corev1.NodeCondition{ - { - Type: corev1.NodeReady, - Status: corev1.ConditionFalse, - }, - }, - Allocatable: corev1.ResourceList{ - gpuResourceName: resource.MustParse("8"), - }, - }, - }, - }, - } - client.Fake.PrependReactor("list", "nodes", func(action k8stesting.Action) (handled bool, ret runtime.Object, err error) { - return true, nodes, nil - }) + // Second run: should trigger reboot. + if err := w.run(t.Context()); err != nil { + t.Fatal(err) + } - civoClient := w.civoClient.(*FakeClient) - civoClient.FindKubernetesClusterInstanceFunc = func(clusterID, search string) (*civogo.Instance, error) { - return nil, errors.New("invalid error") - } - }, - wantErr: true, - }, + state, _ := w.states.Get("node-01") + if state.Phase() != PhaseWaitingReboot { + t.Errorf("got phase %v, want PhaseWaitingReboot", state.Phase()) + } + if state.RebootCount() != 1 { + t.Errorf("got rebootCount %d, want 1", state.RebootCount()) } + if len(exec.calls) != 1 || exec.calls[0] != "node-01" { + t.Errorf("expected 1 reboot call for node-01, got %v", exec.calls) + } +} - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - w, err := NewWatcher(t.Context(), - testApiURL, testApiKey, testRegion, testClusterID, test.args.nodePoolID, test.args.opts...) - if err != nil { - t.Fatal(err) - } +func TestRun_RebootSkippedInReportMode(t *testing.T) { + now := time.Date(2026, 4, 13, 12, 0, 0, 0, time.UTC) + node := newTestNode("node-01", corev1.ConditionFalse, 8) + exec := &mockExecutor{} + w := newTestWatcher(t, + withNodeLister(&fakeNodeLister{nodes: []*corev1.Node{node}}), + WithCheckers(health.NewDefaultCheckers()), + WithExecutor(exec), + WithMonitorOnly("true"), + withNowFunc(func() time.Time { return now }), + ) + + // First run: detect unhealthy. + if err := w.run(t.Context()); err != nil { + t.Fatal(err) + } - obj := w.(*watcher) - if test.beforeFunc != nil { - test.beforeFunc(obj) - } + // Advance past threshold. + now = now.Add(11 * time.Minute) - err = obj.run(t.Context()) - if (err != nil) != test.wantErr { - t.Errorf("error = %v, wantErr %v", err, test.wantErr) - } - }) + // Second run: should transition to WaitingReboot but NOT call executor. + if err := w.run(t.Context()); err != nil { + t.Fatal(err) + } + + state, _ := w.states.Get("node-01") + if state.Phase() != PhaseWaitingReboot { + t.Errorf("got phase %v, want PhaseWaitingReboot", state.Phase()) + } + if len(exec.calls) != 0 { + t.Errorf("expected no reboot calls in report mode, got %v", exec.calls) } } -func TestIsReadyOrNotReadyStatusChangedAfter(t *testing.T) { - type test struct { - name string - node *corev1.Node - thresholdTime time.Time - want bool +func TestRun_RecoveryAfterReboot(t *testing.T) { + now := time.Date(2026, 4, 13, 12, 0, 0, 0, time.UTC) + node := newTestNode("node-01", corev1.ConditionFalse, 8) + w := newTestWatcher(t, + withNodeLister(&fakeNodeLister{nodes: []*corev1.Node{node}}), + WithCheckers(health.NewDefaultCheckers()), + WithMonitorOnly("false"), + withNowFunc(func() time.Time { return now }), + ) + + // Run 1: detect unhealthy. + if err := w.run(t.Context()); err != nil { + t.Fatal(err) + } + // Run 2: trigger reboot. + now = now.Add(11 * time.Minute) + if err := w.run(t.Context()); err != nil { + t.Fatal(err) } - tests := []test{ - { - name: "Returns true when NodeReady condition is true (Ready) and last transition time is after threshold", - node: &corev1.Node{ - ObjectMeta: metav1.ObjectMeta{ - Name: "node-01", - }, - Status: corev1.NodeStatus{ - Conditions: []corev1.NodeCondition{ - { - Type: corev1.NodeReady, - Status: corev1.ConditionTrue, - LastTransitionTime: metav1.NewTime(time.Now()), - }, - }, - }, - }, - thresholdTime: time.Now().Add(-time.Hour), - want: true, - }, - { - name: "Returns true when NodeReady condition is false (NotReady) and last transition time is after threshold", - node: &corev1.Node{ - ObjectMeta: metav1.ObjectMeta{ - Name: "node-01", - }, - Status: corev1.NodeStatus{ - Conditions: []corev1.NodeCondition{ - { - Type: corev1.NodeReady, - Status: corev1.ConditionFalse, - LastTransitionTime: metav1.NewTime(time.Now()), - }, - }, - }, - }, - thresholdTime: time.Now().Add(-time.Hour), - want: true, - }, - { - name: "Returns false when the latest NodeReady condition is older than thresholdTime", - node: &corev1.Node{ - ObjectMeta: metav1.ObjectMeta{ - Name: "node-01", - }, - Status: corev1.NodeStatus{ - Conditions: []corev1.NodeCondition{ - { - Type: corev1.NodeReady, - Status: corev1.ConditionFalse, - LastTransitionTime: metav1.NewTime(time.Now().Add(-time.Hour)), - }, - }, - }, - }, - thresholdTime: time.Now(), - want: false, - }, - { - name: "Returns false when no conditions are present on the node", - node: &corev1.Node{ - ObjectMeta: metav1.ObjectMeta{ - Name: "node-01", - }, - Status: corev1.NodeStatus{ - Conditions: []corev1.NodeCondition{}, - }, - }, - thresholdTime: time.Now().Add(-time.Hour), - want: false, - }, - { - name: "Returns false when there is only NodeDiskPressure condition", - node: &corev1.Node{ - ObjectMeta: metav1.ObjectMeta{ - Name: "node-01", - }, - Status: corev1.NodeStatus{ - Conditions: []corev1.NodeCondition{ - { - Type: corev1.NodeDiskPressure, - Status: corev1.ConditionFalse, - LastHeartbeatTime: metav1.NewTime(time.Now()), - }, - }, - }, - }, - thresholdTime: time.Now().Add(-time.Hour), - want: false, - }, + // Node recovers. + node.Status.Conditions[0].Status = corev1.ConditionTrue + now = now.Add(5 * time.Minute) + + // Run 3: should detect recovery. + if err := w.run(t.Context()); err != nil { + t.Fatal(err) } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - got := isReadyOrNotReadyStatusChangedAfter(test.node, test.thresholdTime) - if got != test.want { - t.Errorf("got = %v, want %v", got, test.want) - } - }) + state, _ := w.states.Get("node-01") + if state.Phase() != PhaseHealthy { + t.Errorf("got phase %v, want PhaseHealthy", state.Phase()) + } + if state.RebootCount() != 0 { + t.Errorf("got rebootCount %d, want 0 after recovery", state.RebootCount()) } } -func TestIsLastRebootCommandTimeAfter(t *testing.T) { - type test struct { - name string - nodeName string - opts []Option - thresholdTime time.Time - beforeFunc func(*watcher) - want bool +func TestRun_RebootRetry(t *testing.T) { + now := time.Date(2026, 4, 13, 12, 0, 0, 0, time.UTC) + node := newTestNode("node-01", corev1.ConditionFalse, 8) + exec := &mockExecutor{} + w := newTestWatcher(t, + withNodeLister(&fakeNodeLister{nodes: []*corev1.Node{node}}), + WithCheckers(health.NewDefaultCheckers()), + WithExecutor(exec), + WithMonitorOnly("false"), + WithGPURebootWaitMinutes("40"), + withNowFunc(func() time.Time { return now }), + ) + + // Run 1: detect unhealthy. + if err := w.run(t.Context()); err != nil { + t.Fatal(err) + } + // Run 2: trigger first reboot (GPU checker threshold is 10min). + now = now.Add(11 * time.Minute) + if err := w.run(t.Context()); err != nil { + t.Fatal(err) } - tests := []test{ - { - name: "Return true when last reboot command time is after threshold", - opts: []Option{ - WithKubernetesClient(fake.NewSimpleClientset()), - WithCivoClient(&FakeClient{}), - }, - nodeName: "node-01", - thresholdTime: time.Now().Add(-time.Hour), - beforeFunc: func(w *watcher) { - w.lastRebootCmdTimes.Store("node-01", time.Now()) - }, - want: true, - }, - { - name: "Return false when last reboot command time is before threshold", - opts: []Option{ - WithKubernetesClient(fake.NewSimpleClientset()), - WithCivoClient(&FakeClient{}), - }, - nodeName: "node-01", - thresholdTime: time.Now().Add(-time.Hour), - beforeFunc: func(w *watcher) { - w.lastRebootCmdTimes.Store("nodde-01", time.Now().Add(-2*time.Hour)) - }, - want: false, - }, - { - name: "Return false when last reboot command time not found", - opts: []Option{ - WithKubernetesClient(fake.NewSimpleClientset()), - WithCivoClient(&FakeClient{}), - }, - nodeName: "node-01", - thresholdTime: time.Now().Add(-time.Hour), - want: false, - }, - { - name: "Return false when type of last reboot command time is invalid", - opts: []Option{ - WithKubernetesClient(fake.NewSimpleClientset()), - WithCivoClient(&FakeClient{}), - }, - nodeName: "node-01", - thresholdTime: time.Now().Add(-time.Hour), - beforeFunc: func(w *watcher) { - w.lastRebootCmdTimes.Store("nodde-01", "invalid-type") - }, - want: false, - }, + // Still unhealthy, but within reboot window → no retry. + now = now.Add(30 * time.Minute) + if err := w.run(t.Context()); err != nil { + t.Fatal(err) + } + if len(exec.calls) != 1 { + t.Fatalf("expected 1 reboot call before window expires, got %d", len(exec.calls)) } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - w, err := NewWatcher(t.Context(), - testApiURL, testApiKey, testRegion, testClusterID, testNodePoolID, test.opts...) - if err != nil { - t.Fatal(err) - } + // Advance past reboot window → retry. + now = now.Add(11 * time.Minute) + if err := w.run(t.Context()); err != nil { + t.Fatal(err) + } - obj := w.(*watcher) - if test.beforeFunc != nil { - test.beforeFunc(obj) - } - got := obj.isLastRebootCommandTimeAfter(test.nodeName, test.thresholdTime) - if got != test.want { - t.Errorf("got = %v, want %v", got, test.want) - } - }) + state, _ := w.states.Get("node-01") + if state.Phase() != PhaseWaitingReboot { + t.Errorf("got phase %v, want PhaseWaitingReboot", state.Phase()) + } + if state.RebootCount() != 2 { + t.Errorf("got rebootCount %d, want 2", state.RebootCount()) + } + if len(exec.calls) != 2 { + t.Errorf("expected 2 reboot calls, got %d", len(exec.calls)) } } -func TestIsNodeReady(t *testing.T) { - type test struct { - name string - node *corev1.Node - want bool +func TestRun_GPUMismatchTriggersUnhealthy(t *testing.T) { + now := time.Date(2026, 4, 13, 12, 0, 0, 0, time.UTC) + node := newTestNode("node-01", corev1.ConditionTrue, 8) + // Simulate GPU failure: label says 8 but only 7 allocatable. + node.Status.Allocatable["nvidia.com/gpu"] = resource.MustParse("7") + w := newTestWatcher(t, + withNodeLister(&fakeNodeLister{nodes: []*corev1.Node{node}}), + WithCheckers(health.NewDefaultCheckers()), + withNowFunc(func() time.Time { return now }), + ) + + if err := w.run(t.Context()); err != nil { + t.Fatal(err) } - tests := []test{ - { - name: "Returns true when Node is ready state", - node: &corev1.Node{ - ObjectMeta: metav1.ObjectMeta{ - Name: "node-01", - }, - Status: corev1.NodeStatus{ - Conditions: []corev1.NodeCondition{ - { - Type: corev1.NodeReady, - Status: corev1.ConditionTrue, - }, - { - Type: corev1.NodeReady, - Status: corev1.ConditionFalse, - }, - }, - }, - }, - want: true, - }, - { - name: "Returns false when Node is not ready state", - node: &corev1.Node{ - ObjectMeta: metav1.ObjectMeta{ - Name: "node-01", - }, - Status: corev1.NodeStatus{ - Conditions: []corev1.NodeCondition{ - { - Type: corev1.NodeReady, - Status: corev1.ConditionFalse, - }, - }, - }, - }, - }, - { - name: "Returns false when no conditions for the node", - node: &corev1.Node{ - ObjectMeta: metav1.ObjectMeta{ - Name: "node-01", - }, - Status: corev1.NodeStatus{ - Conditions: []corev1.NodeCondition{}, - }, - }, + state, _ := w.states.Get("node-01") + if state.Phase() != PhaseUnhealthy { + t.Errorf("got phase %v, want PhaseUnhealthy", state.Phase()) + } + if !state.IsGPUNode() { + t.Error("expected isGPUNode to be true for node with 7 GPUs") + } +} + +func TestRun_RebootErrorContinuesProcessing(t *testing.T) { + now := time.Date(2026, 4, 13, 12, 0, 0, 0, time.UTC) + node := newTestNode("node-01", corev1.ConditionFalse, 0) + exec := &mockExecutor{ + rebootFunc: func(_ context.Context, _ string) error { + return fmt.Errorf("reboot API error") }, } + w := newTestWatcher(t, + withNodeLister(&fakeNodeLister{nodes: []*corev1.Node{node}}), + WithCheckers(health.NewDefaultCheckers()), + WithExecutor(exec), + WithMonitorOnly("false"), + withNowFunc(func() time.Time { return now }), + ) + + // Run 1: detect unhealthy. + if err := w.run(t.Context()); err != nil { + t.Fatal(err) + } + // Run 2: threshold exceeded, reboot fails → should not error out, stays PhaseUnhealthy. + now = now.Add(11 * time.Minute) + if err := w.run(t.Context()); err != nil { + t.Fatal("run should not return error on reboot failure") + } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - got := isNodeReady(test.node) - if got != test.want { - t.Errorf("got = %v, want %v", got, test.want) - } - }) + state, _ := w.states.Get("node-01") + if state.Phase() != PhaseUnhealthy { + t.Errorf("got phase %v, want PhaseUnhealthy (reboot failed, no transition)", state.Phase()) } } -func TestIsNodeDesiredGPU(t *testing.T) { - type test struct { - name string - node *corev1.Node - desired int - want bool +func TestRun_NodeListError(t *testing.T) { + w := newTestWatcher(t, + withNodeLister(&fakeNodeLister{err: fmt.Errorf("list error")}), + WithCheckers(health.NewDefaultCheckers()), + ) + + if err := w.run(t.Context()); err == nil { + t.Error("expected error from node list failure") } +} - tests := []test{ - { - name: "Returns true when GPU count matches desired value", - node: &corev1.Node{ - ObjectMeta: metav1.ObjectMeta{ - Name: "node-01", - }, - Status: corev1.NodeStatus{ - Allocatable: corev1.ResourceList{ - gpuResourceName: resource.MustParse("8"), - }, - }, - }, - desired: 8, - want: true, - }, - { - name: "Returns true when desired GPU count is 0, so count check is skipped", - node: &corev1.Node{ - ObjectMeta: metav1.ObjectMeta{ - Name: "node-01", - }, - Status: corev1.NodeStatus{ - Allocatable: corev1.ResourceList{}, - }, - }, - desired: 0, - want: true, - }, - { - name: "Returns false when GPU count is 0", - node: &corev1.Node{ - ObjectMeta: metav1.ObjectMeta{ - Name: "node-01", - }, - Status: corev1.NodeStatus{ - Allocatable: corev1.ResourceList{ - gpuResourceName: resource.MustParse("0"), - }, - }, - }, - desired: 8, - want: false, - }, - { - name: "Returns false when GPU count is less than desired value", - node: &corev1.Node{ - ObjectMeta: metav1.ObjectMeta{ - Name: "node-01", - }, - Status: corev1.NodeStatus{ - Allocatable: corev1.ResourceList{ - gpuResourceName: resource.MustParse("7"), - }, - }, - }, - desired: 8, - want: false, +func TestRun_StaleStateCleanup(t *testing.T) { + now := time.Date(2026, 4, 13, 12, 0, 0, 0, time.UTC) + node := newTestNode("node-01", corev1.ConditionFalse, 0) + lister := &fakeNodeLister{nodes: []*corev1.Node{node}} + w := newTestWatcher(t, + withNodeLister(lister), + WithCheckers(health.NewDefaultCheckers()), + withNowFunc(func() time.Time { return now }), + ) + + // Run 1: detect node-01 unhealthy. + if err := w.run(t.Context()); err != nil { + t.Fatal(err) + } + if _, ok := w.states.Get("node-01"); !ok { + t.Fatal("state should exist for node-01") + } + + // Node removed from cluster. + lister.nodes = nil + + if err := w.run(t.Context()); err != nil { + t.Fatal(err) + } + + if _, ok := w.states.Get("node-01"); ok { + t.Error("state for node-01 should be cleaned up after removal") + } +} + +func TestRun_UnhealthyWithinThresholdNoReboot(t *testing.T) { + now := time.Date(2026, 4, 13, 12, 0, 0, 0, time.UTC) + node := newTestNode("node-01", corev1.ConditionFalse, 0) + exec := &mockExecutor{} + w := newTestWatcher(t, + withNodeLister(&fakeNodeLister{nodes: []*corev1.Node{node}}), + WithCheckers(health.NewDefaultCheckers()), + WithExecutor(exec), + WithMonitorOnly("false"), + withNowFunc(func() time.Time { return now }), + ) + + // Run 1: detect unhealthy. + if err := w.run(t.Context()); err != nil { + t.Fatal(err) + } + + // Run 2: still within threshold → no reboot. + now = now.Add(3 * time.Minute) + if err := w.run(t.Context()); err != nil { + t.Fatal(err) + } + + state, _ := w.states.Get("node-01") + if state.Phase() != PhaseUnhealthy { + t.Errorf("got phase %v, want PhaseUnhealthy", state.Phase()) + } + if len(exec.calls) != 0 { + t.Errorf("expected no reboot calls within threshold, got %v", exec.calls) + } +} + +func TestRun_RebootRetryLimitExceeded_TransitionsToFailed(t *testing.T) { + now := time.Date(2026, 4, 13, 12, 0, 0, 0, time.UTC) + node := newTestNode("node-01", corev1.ConditionFalse, 0) + exec := &mockExecutor{} + w := newTestWatcher(t, + withNodeLister(&fakeNodeLister{nodes: []*corev1.Node{node}}), + WithCheckers(health.NewDefaultCheckers()), + WithExecutor(exec), + WithMonitorOnly("false"), + WithRebootWaitMinutes("10"), + WithMaxRebootRetries("3"), + withNowFunc(func() time.Time { return now }), + ) + + // Run 1: detect unhealthy. + if err := w.run(t.Context()); err != nil { + t.Fatal(err) + } + // Run 2: threshold exceeded → first reboot (rebootCount=1). + now = now.Add(6 * time.Minute) + if err := w.run(t.Context()); err != nil { + t.Fatal(err) + } + // Reboot retries 2 and 3. + for i := 0; i < 2; i++ { + now = now.Add(11 * time.Minute) + if err := w.run(t.Context()); err != nil { + t.Fatal(err) + } + } + + state, _ := w.states.Get("node-01") + if state.RebootCount() != 3 { + t.Fatalf("expected rebootCount=3 after 3 reboots, got %d", state.RebootCount()) + } + if state.Phase() != PhaseWaitingReboot { + t.Fatalf("expected PhaseWaitingReboot after %d reboots, got %v", state.RebootCount(), state.Phase()) + } + + // Next retry should exceed the limit → PhaseFailed. + now = now.Add(11 * time.Minute) + if err := w.run(t.Context()); err != nil { + t.Fatal(err) + } + + state, _ = w.states.Get("node-01") + if state.Phase() != PhaseFailed { + t.Errorf("got phase %v, want PhaseFailed", state.Phase()) + } + if len(exec.calls) != 3 { + t.Errorf("expected exactly 3 reboot calls (no further reboots after Failed), got %d", len(exec.calls)) + } +} + +func TestRun_RebootFailureLimitExceeded_TransitionsToFailed(t *testing.T) { + now := time.Date(2026, 4, 13, 12, 0, 0, 0, time.UTC) + node := newTestNode("node-01", corev1.ConditionFalse, 0) + exec := &mockExecutor{ + rebootFunc: func(_ context.Context, _ string) error { + return fmt.Errorf("reboot API error") }, } + w := newTestWatcher(t, + withNodeLister(&fakeNodeLister{nodes: []*corev1.Node{node}}), + WithCheckers(health.NewDefaultCheckers()), + WithExecutor(exec), + WithMonitorOnly("false"), + WithRebootWaitMinutes("10"), + WithMaxRebootRetries("100"), + WithMaxRebootFailures("3"), + withNowFunc(func() time.Time { return now }), + ) + + // Run 1: detect unhealthy. + if err := w.run(t.Context()); err != nil { + t.Fatal(err) + } + // Advance past the unhealthy threshold; reboots will now be attempted and fail. + now = now.Add(11 * time.Minute) + + // Runs 2-4: reboot fails 3 times → failedRebootCount=3, still PhaseUnhealthy. + for i := 0; i < 3; i++ { + if err := w.run(t.Context()); err != nil { + t.Fatal(err) + } + } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - got := isNodeDesiredGPU(test.node, test.desired) - if got != test.want { - t.Errorf("got = %v, want %v", got, test.want) - } - }) + state, _ := w.states.Get("node-01") + if state.FailedRebootCount() != 3 { + t.Fatalf("expected failedRebootCount=3 after 3 failures, got %d", state.FailedRebootCount()) + } + if state.Phase() != PhaseUnhealthy { + t.Fatalf("expected still PhaseUnhealthy at limit, got %v", state.Phase()) + } + + // Next run: failedRebootCount=3 >= max=3 → PhaseFailed. + if err := w.run(t.Context()); err != nil { + t.Fatal(err) + } + + state, _ = w.states.Get("node-01") + if state.Phase() != PhaseFailed { + t.Errorf("got phase %v, want PhaseFailed", state.Phase()) + } + if len(exec.calls) != 3 { + t.Errorf("expected exactly 3 reboot calls (no further reboots after Failed), got %d", len(exec.calls)) } } -func TestRebootNode(t *testing.T) { - type args struct { - nodeName string - opts []Option +func TestRun_MonitorOnlySimulatesFullLifecycle(t *testing.T) { + now := time.Date(2026, 4, 13, 12, 0, 0, 0, time.UTC) + node := newTestNode("node-01", corev1.ConditionFalse, 0) + exec := &mockExecutor{} + w := newTestWatcher(t, + withNodeLister(&fakeNodeLister{nodes: []*corev1.Node{node}}), + WithCheckers(health.NewDefaultCheckers()), + WithExecutor(exec), + WithMonitorOnly("true"), + WithRebootWaitMinutes("10"), + WithMaxRebootRetries("3"), + withNowFunc(func() time.Time { return now }), + ) + + // Drive the state machine through detection + three reboot cycles + retry-limit check. + if err := w.run(t.Context()); err != nil { + t.Fatal(err) } - type test struct { - name string - args args - beforeFunc func(*testing.T, *watcher) - wantErr bool + for i := 0; i < 5; i++ { + now = now.Add(11 * time.Minute) + if err := w.run(t.Context()); err != nil { + t.Fatal(err) + } } - tests := []test{ - { - name: "Returns nil when there is no error finding and rebooting the instance", - args: args{ - nodeName: "node-01", - opts: []Option{ - WithKubernetesClient(fake.NewSimpleClientset()), - WithCivoClient(&FakeClient{}), - WithDesiredGPUCount(testNodeDesiredGPUCount), - }, - }, - beforeFunc: func(t *testing.T, w *watcher) { - t.Helper() - client := w.civoClient.(*FakeClient) + state, _ := w.states.Get("node-01") + if state.Phase() != PhaseFailed { + t.Errorf("expected PhaseFailed after monitor-only simulation, got %v", state.Phase()) + } + if state.RebootCount() != 3 { + t.Errorf("expected rebootCount=3 (maxRebootRetries), got %d", state.RebootCount()) + } + if len(exec.calls) != 0 { + t.Errorf("expected no executor calls in monitor-only mode, got %d", len(exec.calls)) + } +} - instance := &civogo.Instance{ - ID: "instance-01", - } +func TestRun_RecoverFromFailed(t *testing.T) { + now := time.Date(2026, 4, 13, 12, 0, 0, 0, time.UTC) + node := newTestNode("node-01", corev1.ConditionFalse, 0) + exec := &mockExecutor{} + w := newTestWatcher(t, + withNodeLister(&fakeNodeLister{nodes: []*corev1.Node{node}}), + WithCheckers(health.NewDefaultCheckers()), + WithExecutor(exec), + WithMonitorOnly("false"), + WithRebootWaitMinutes("10"), + WithMaxRebootRetries("1"), + withNowFunc(func() time.Time { return now }), + ) + + // Drive to Failed: detect → first reboot (rebootCount=1) → retry exceeds limit. + if err := w.run(t.Context()); err != nil { + t.Fatal(err) + } + now = now.Add(6 * time.Minute) + if err := w.run(t.Context()); err != nil { + t.Fatal(err) + } + now = now.Add(11 * time.Minute) + if err := w.run(t.Context()); err != nil { + t.Fatal(err) + } - client.FindKubernetesClusterInstanceFunc = func(clusterID, search string) (*civogo.Instance, error) { - return instance, nil - } - client.HardRebootInstanceFunc = func(id string) (*civogo.SimpleResponse, error) { - if instance.ID != id { - t.Errorf("instanceId dose not match. want: %s, but got: %s", instance.ID, id) - } - return new(civogo.SimpleResponse), nil - } - }, + if st, _ := w.states.Get("node-01"); st.Phase() != PhaseFailed { + t.Fatalf("expected PhaseFailed, got %v", st.Phase()) + } + + // Node recovers. + node.Status.Conditions[0].Status = corev1.ConditionTrue + now = now.Add(time.Minute) + if err := w.run(t.Context()); err != nil { + t.Fatal(err) + } + + st, _ := w.states.Get("node-01") + if st.Phase() != PhaseHealthy { + t.Errorf("expected PhaseHealthy after recovery, got %v", st.Phase()) + } + if st.RebootCount() != 0 { + t.Errorf("rebootCount should reset to 0 after recovery, got %d", st.RebootCount()) + } +} + +func TestBuildNodeSelector(t *testing.T) { + tests := []struct { + description string + nodePoolIDs []string + wantNil bool + wantLabels map[string]string + wantInExpr bool + }{ + { + description: "returns nil for empty IDs", + wantNil: true, }, { - name: "Returns an error when instance lookup fails", - args: args{ - nodeName: "node-01", - opts: []Option{ - WithKubernetesClient(fake.NewSimpleClientset()), - WithCivoClient(&FakeClient{}), - WithDesiredGPUCount(testNodeDesiredGPUCount), - }, - }, - beforeFunc: func(t *testing.T, w *watcher) { - t.Helper() - client := w.civoClient.(*FakeClient) - - client.FindKubernetesClusterInstanceFunc = func(clusterID, search string) (*civogo.Instance, error) { - return nil, errors.New("invalid error") - } - }, - wantErr: true, + description: "returns MatchLabels for single ID", + nodePoolIDs: []string{"pool-1"}, + wantLabels: map[string]string{nodePoolLabelKey: "pool-1"}, }, { - name: "Returns an error when instance reboot fails", - args: args{ - nodeName: "node-01", - opts: []Option{ - WithKubernetesClient(fake.NewSimpleClientset()), - WithCivoClient(&FakeClient{}), - WithDesiredGPUCount(testNodeDesiredGPUCount), - }, - }, - beforeFunc: func(t *testing.T, w *watcher) { - t.Helper() - client := w.civoClient.(*FakeClient) - - instance := &civogo.Instance{ - ID: "instance-01", - } - - client.FindKubernetesClusterInstanceFunc = func(clusterID, search string) (*civogo.Instance, error) { - return instance, nil - } - client.HardRebootInstanceFunc = func(id string) (*civogo.SimpleResponse, error) { - if instance.ID != id { - t.Errorf("instanceId dose not match. want: %s, but got: %s", instance.ID, id) - } - return nil, errors.New("invalid error") - } - }, - wantErr: true, + description: "returns MatchExpressions In for multiple IDs", + nodePoolIDs: []string{"pool-1", "pool-2"}, + wantInExpr: true, }, } for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - w, err := NewWatcher(t.Context(), - testApiURL, testApiKey, testRegion, testClusterID, testNodePoolID, test.args.opts...) - if err != nil { - t.Fatal(err) + t.Run(test.description, func(t *testing.T) { + sel := buildNodeSelector(test.nodePoolIDs) + if test.wantNil { + if sel != nil { + t.Errorf("expected nil selector, got %v", sel) + } + return } - - obj := w.(*watcher) - if test.beforeFunc != nil { - test.beforeFunc(t, obj) + if sel == nil { + t.Fatal("expected non-nil selector") } - - err = obj.rebootNode(test.args.nodeName) - if (err != nil) != test.wantErr { - t.Errorf("error = %v, wantErr %v", err, test.wantErr) + if test.wantLabels != nil { + for k, v := range test.wantLabels { + if sel.MatchLabels[k] != v { + t.Errorf("MatchLabels[%s] = %q, want %q", k, sel.MatchLabels[k], v) + } + } + } + if test.wantInExpr { + if len(sel.MatchExpressions) != 1 { + t.Fatalf("expected 1 MatchExpression, got %d", len(sel.MatchExpressions)) + } + expr := sel.MatchExpressions[0] + if expr.Key != nodePoolLabelKey { + t.Errorf("key = %q, want %q", expr.Key, nodePoolLabelKey) + } + if len(expr.Values) != len(test.nodePoolIDs) { + t.Errorf("values count = %d, want %d", len(expr.Values), len(test.nodePoolIDs)) + } } }) }