diff --git a/cmd/hostagent/subcmds/serve.go b/cmd/hostagent/subcmds/serve.go index 3f1c86e62..f4c271250 100644 --- a/cmd/hostagent/subcmds/serve.go +++ b/cmd/hostagent/subcmds/serve.go @@ -28,6 +28,7 @@ import ( "github.com/nvidia/doca-platform/internal/provisioning/hostagent" "github.com/nvidia/doca-platform/internal/provisioning/hostagent/networkmanager" "github.com/nvidia/doca-platform/internal/provisioning/hostagent/nodemanager" + "github.com/nvidia/doca-platform/internal/provisioning/hostagent/phase/reboot" "github.com/nvidia/doca-platform/internal/provisioning/hostagent/service" "github.com/spf13/cobra" @@ -103,11 +104,13 @@ var serveCmd = &cobra.Command{ os.Exit(1) } - if err := service.NewInstallationService(unCachedClient, nm).Start(true); err != nil { + rh := reboot.NewHandler(mgr.GetClient(), dpuNodeManager.GetNodeName, nm.GetDevice) + + if err := service.NewInstallationService(unCachedClient, nm, rh).Start(true); err != nil { klog.Fatalf("failed to start installation service: %v", err) } - reconciler := hostagent.NewHostAgentReconciler(mgr.GetClient(), opts.BFBRegistryAddress, dpuNodeManager, nm) + reconciler := hostagent.NewHostAgentReconciler(mgr.GetClient(), opts.BFBRegistryAddress, dpuNodeManager, nm, rh) if err = reconciler.SetupWithManager(mgr); err != nil { setupLog.Error(err, "unable to create controller", "controller", "DPU") os.Exit(1) diff --git a/internal/provisioning/controllers/util/dms/util.go b/internal/provisioning/controllers/util/dms/util.go index a2cc4c09e..bd0b368d3 100644 --- a/internal/provisioning/controllers/util/dms/util.go +++ b/internal/provisioning/controllers/util/dms/util.go @@ -219,6 +219,10 @@ func CreateHostAgentPod(ctx context.Context, client client.Client, node *corev1. Name: "run-udev", MountPath: "/run/udev", }, + { + Name: "etc-udev-rules", + MountPath: "/etc/udev/rules.d", + }, { Name: "systemd-network", MountPath: "/usr/lib/systemd/network", @@ -307,6 +311,15 @@ func CreateHostAgentPod(ctx context.Context, client client.Client, node *corev1. }, }, }, + { + Name: "etc-udev-rules", + VolumeSource: corev1.VolumeSource{ + HostPath: &corev1.HostPathVolumeSource{ + Path: "/etc/udev/rules.d", + Type: ptr.To(corev1.HostPathDirectoryOrCreate), + }, + }, + }, { Name: "systemd-network", VolumeSource: corev1.VolumeSource{ diff --git a/internal/provisioning/hostagent/controller.go b/internal/provisioning/hostagent/controller.go index b7e4f9dd5..9de03935a 100644 --- a/internal/provisioning/hostagent/controller.go +++ b/internal/provisioning/hostagent/controller.go @@ -56,7 +56,8 @@ type HostAgentReconciler struct { func NewHostAgentReconciler(client client.Client, bfbRegistryAddress string, nodeManager nodemanager.Interface, - networkManager networkmanager.Interface) *HostAgentReconciler { + networkManager networkmanager.Interface, + rebootHandler *reboot.Handler) *HostAgentReconciler { r := &HostAgentReconciler{ Client: client, NodeManager: nodeManager, @@ -70,7 +71,7 @@ func NewHostAgentReconciler(client client.Client, provisioningv1.DPUInitializeInterface: interfaceinit.NewHandler(client, r.NetworkManager.GetDevice), provisioningv1.DPUConfigFWParameters: configfw.NewHandler(client, r.NetworkManager.GetDevice), provisioningv1.DPUOSInstalling: install.NewHandler(client, bfbRegistry, r.NetworkManager.GetDevice), - provisioningv1.DPURebooting: reboot.NewHandler(client, r.NodeManager.GetNodeName, r.NetworkManager.GetDevice), + provisioningv1.DPURebooting: rebootHandler, provisioningv1.DPUHostNetworkConfiguration: network.NewHandler(r.NetworkManager.AddNetworkRequest), } return r diff --git a/internal/provisioning/hostagent/networkmanager/network_manager.go b/internal/provisioning/hostagent/networkmanager/network_manager.go index 4b932137c..aaaee1309 100644 --- a/internal/provisioning/hostagent/networkmanager/network_manager.go +++ b/internal/provisioning/hostagent/networkmanager/network_manager.go @@ -51,8 +51,9 @@ type Interface interface { Start() error // GetDevice returns the PCI device by serial number GetDevice(serialNumber string) (hostutil.Device, bool) - // AddNetworkRequest adds a network request for a DPU - AddNetworkRequest(dpu *provisioningv1.DPU) error + // AddNetworkRequest adds a network request for a DPU. + // If vfCount is non-nil it overrides the value derived from the DPUFlavor. + AddNetworkRequest(dpu *provisioningv1.DPU, vfCount *int) error } type NetworkManager struct { @@ -185,6 +186,12 @@ func (nm *NetworkManager) processNetworkRequest(nr NetworkRequest) error { return nil } operations := []networkOperation{ + { + name: "DisableNMForVFs", + f: func(nr NetworkRequest) error { + return nm.netBackend.EnsureVFsUnmanaged() + }, + }, { name: "CreateP0VF", f: func(nr NetworkRequest) error { @@ -247,7 +254,7 @@ func (nm *NetworkManager) processNetworkRequest(nr NetworkRequest) error { return nil } -func (nm *NetworkManager) AddNetworkRequest(dpu *provisioningv1.DPU) error { +func (nm *NetworkManager) AddNetworkRequest(dpu *provisioningv1.DPU, vfCount *int) error { nm.Lock() defer nm.Unlock() if !nm.initialized { @@ -256,7 +263,15 @@ func (nm *NetworkManager) AddNetworkRequest(dpu *provisioningv1.DPU) error { return fmt.Errorf("DPU is nil") } - if _, ok := nm.reqs[string(dpu.UID)]; ok { + if existing, ok := nm.reqs[string(dpu.UID)]; ok { + if vfCount != nil && *vfCount != 0 && existing.NumOfVFs != *vfCount { + existing.NumOfVFs = *vfCount + if err := writeNetworkRequestFile(&existing); err != nil { + return fmt.Errorf("failed to update network request file: %w", err) + } + nm.reqs[existing.UID] = existing + klog.Infof("Updated VF count to %d for DPU %s/%s", *vfCount, existing.DPUNamespace, existing.DpuName) + } return nil } @@ -272,9 +287,15 @@ func (nm *NetworkManager) AddNetworkRequest(dpu *provisioningv1.DPU) error { } nr.PCIAddress = dev.Address - numOfVFs, err := nm.getNumOfVFs(dpu) - if err != nil { - return fmt.Errorf("failed to get number of VFs: %w", err) + var numOfVFs int + if vfCount != nil && *vfCount != 0 { + numOfVFs = *vfCount + } else { + var err error + numOfVFs, err = nm.getNumOfVFs(dpu) + if err != nil { + return fmt.Errorf("failed to get number of VFs: %w", err) + } } nr.NumOfVFs = numOfVFs diff --git a/internal/provisioning/hostagent/phase/network/hostnetwork.go b/internal/provisioning/hostagent/phase/network/hostnetwork.go index c0fc6e564..eb75d116b 100644 --- a/internal/provisioning/hostagent/phase/network/hostnetwork.go +++ b/internal/provisioning/hostagent/phase/network/hostnetwork.go @@ -31,10 +31,10 @@ const ( ) type Handler struct { - AddNetworkRequest func(dpu *provisioningv1.DPU) error + AddNetworkRequest func(dpu *provisioningv1.DPU, vfCount *int) error } -func NewHandler(addNetworkRequest func(dpu *provisioningv1.DPU) error) *Handler { +func NewHandler(addNetworkRequest func(dpu *provisioningv1.DPU, vfCount *int) error) *Handler { return &Handler{ AddNetworkRequest: addNetworkRequest, } @@ -42,7 +42,7 @@ func NewHandler(addNetworkRequest func(dpu *provisioningv1.DPU) error) *Handler func (h *Handler) Handle(ctx context.Context, dpu *provisioningv1.DPU) (provisioningv1.DPUStatus, ctrl.Result, error) { log := log.FromContext(ctx) - if err := h.AddNetworkRequest(dpu); err != nil { + if err := h.AddNetworkRequest(dpu, nil); err != nil { log.Error(err, "Failed to add network request") hostutil.NewCondition(condition).Failure(err, "FailedToSetupHostNetwork").Set(&dpu.Status.Conditions) return dpu.Status, ctrl.Result{}, err diff --git a/internal/provisioning/hostagent/phase/reboot/sync.go b/internal/provisioning/hostagent/phase/reboot/sync.go index 7f595122c..8170b4617 100644 --- a/internal/provisioning/hostagent/phase/reboot/sync.go +++ b/internal/provisioning/hostagent/phase/reboot/sync.go @@ -130,18 +130,18 @@ func (r *Handler) reboot(ctx context.Context, dpuNode *provisioningv1.DPUNode, d } } if runPowerCycle { - if err := r.runPowerCycle(dpuNode, rebootNow); err != nil { + if err := r.RunPowerCycle(dpuNode, rebootNow); err != nil { return rebootNow, err } return nil, nil } - if err := r.runSLR(ctx, rebootNow); err != nil { + if err := r.RunSLR(ctx, rebootNow); err != nil { return rebootNow, err } return nil, nil } -func (r *Handler) runPowerCycle(dpuNode *provisioningv1.DPUNode, dpus []provisioningv1.DPU) error { +func (r *Handler) RunPowerCycle(dpuNode *provisioningv1.DPUNode, dpus []provisioningv1.DPU) error { powerCycleCommand, err := reboot.PowerCycleCommand(dpuNode) if err != nil { return fmt.Errorf("failed to get power cycle command: %w", err) @@ -157,7 +157,7 @@ func (r *Handler) runPowerCycle(dpuNode *provisioningv1.DPUNode, dpus []provisio return nil } -func (r *Handler) runSLR(ctx context.Context, toBeRebooted []provisioningv1.DPU) error { +func (r *Handler) RunSLR(ctx context.Context, toBeRebooted []provisioningv1.DPU) error { devs := make([]hostutil.Device, len(toBeRebooted)) for i, dpu := range toBeRebooted { dev, ok := r.getDeviceBySerialNumberFunc(dpu.Spec.SerialNumber) diff --git a/internal/provisioning/hostagent/service/installation_service.go b/internal/provisioning/hostagent/service/installation_service.go index 44922759b..8d19476ba 100644 --- a/internal/provisioning/hostagent/service/installation_service.go +++ b/internal/provisioning/hostagent/service/installation_service.go @@ -27,6 +27,7 @@ import ( "time" provisioningv1 "github.com/nvidia/doca-platform/api/provisioning/v1alpha1" + "github.com/nvidia/doca-platform/internal/provisioning/hostagent/phase/reboot" "github.com/nvidia/doca-platform/internal/provisioning/hostagent/service/types" restful "github.com/emicklei/go-restful/v3" @@ -67,7 +68,7 @@ const ( // NetworkConfigurator is an interface for triggering host network configuration. // It is satisfied by networkmanager.NetworkManager. type NetworkConfigurator interface { - AddNetworkRequest(dpu *provisioningv1.DPU) error + AddNetworkRequest(dpu *provisioningv1.DPU, vfCount *int) error } type InstallationService struct { @@ -78,16 +79,18 @@ type InstallationService struct { // listeners maps interface names to their listeners listeners map[string]net.Listener networkManager NetworkConfigurator + rebootHandler *reboot.Handler // stopCh is closed by Stop() to terminate background goroutines stopCh chan struct{} stopOnce sync.Once } -func NewInstallationService(client client.Client, nm NetworkConfigurator) *InstallationService { +func NewInstallationService(client client.Client, nm NetworkConfigurator, rh *reboot.Handler) *InstallationService { s := &InstallationService{ Client: client, listeners: make(map[string]net.Listener), networkManager: nm, + rebootHandler: rh, stopCh: make(chan struct{}), } ws := new(restful.WebService).Path("/") @@ -110,6 +113,11 @@ func NewInstallationService(client client.Client, nm NetworkConfigurator) *Insta Consumes(restful.MIME_JSON). Produces(restful.MIME_JSON). To(s.ConfigureHostVFs)) + ws.Route( + ws.POST("/trigger-reboot"). + Consumes(restful.MIME_JSON). + Produces(restful.MIME_JSON). + To(s.TriggerReboot)) ws.Route(ws.GET("/healthz").To(s.HealthCheck)) // Package repositories: serve .deb and .rpm packages for DPU provisioning. ws.Route(ws.GET("/deb/{subpath:*}").To(serveRepoFile(debRepoDir))) @@ -329,7 +337,7 @@ func (s *InstallationService) ConfigureHostVFs(req *restful.Request, resp *restf return } - if err := s.networkManager.AddNetworkRequest(dpu); err != nil { + if err := s.networkManager.AddNetworkRequest(dpu, &request.VFCount); err != nil { klog.Errorf("failed to add network request for DPU %s/%s: %v", request.DPUNamespace, request.DPUName, err) _ = resp.WriteError(http.StatusInternalServerError, err) return @@ -339,6 +347,69 @@ func (s *InstallationService) ConfigureHostVFs(req *restful.Request, resp *restf resp.WriteHeader(http.StatusOK) } +func (s *InstallationService) TriggerReboot(req *restful.Request, resp *restful.Response) { + var request types.TriggerRebootRequest + if err := req.ReadEntity(&request); err != nil { + klog.Errorf("failed to read trigger reboot request: %v", err) + _ = resp.WriteError(http.StatusBadRequest, err) + return + } + klog.Infof("Received trigger reboot request: %#v", request) + + ctx := req.Request.Context() + + dpu := &provisioningv1.DPU{} + if err := s.Get(ctx, client.ObjectKey{Namespace: request.DPUNamespace, Name: request.DPUName}, dpu); err != nil { + klog.Errorf("failed to get DPU %s/%s: %v", request.DPUNamespace, request.DPUName, err) + if apierrors.IsNotFound(err) { + _ = resp.WriteError(http.StatusNotFound, err) + } else { + _ = resp.WriteError(http.StatusInternalServerError, err) + } + return + } + + if string(dpu.UID) != request.DPUUID { + klog.Warningf("Rejecting trigger reboot request for DPU %s/%s: request UID %q does not match current DPU UID %q", + request.DPUNamespace, request.DPUName, request.DPUUID, dpu.UID) + _ = resp.WriteError(http.StatusConflict, fmt.Errorf("stale DPU object: expected UID %q but got %q", request.DPUUID, dpu.UID)) + return + } + + // Detach from the HTTP request context: the request arrives over tmfifo, + // and shutting down the ARM severs that connection. + rebootCtx := context.WithoutCancel(ctx) + + switch request.RebootMethod { + case provisioningv1.RebootMethodSystemLevelReset, + provisioningv1.RebootMethodFirmwareReset, + provisioningv1.RebootMethodSystemReboot: + if err := s.rebootHandler.RunSLR(rebootCtx, []provisioningv1.DPU{*dpu}); err != nil { + klog.Errorf("SLR failed for DPU %s/%s: %v", request.DPUNamespace, request.DPUName, err) + _ = resp.WriteError(http.StatusInternalServerError, err) + return + } + case provisioningv1.RebootMethodPowerCycle: + dpuNode := &provisioningv1.DPUNode{} + if err := s.Get(rebootCtx, client.ObjectKey{Name: dpu.Spec.DPUNodeName}, dpuNode); err != nil { + klog.Errorf("failed to get DPUNode %s: %v", dpu.Spec.DPUNodeName, err) + _ = resp.WriteError(http.StatusInternalServerError, err) + return + } + if err := s.rebootHandler.RunPowerCycle(dpuNode, []provisioningv1.DPU{*dpu}); err != nil { + klog.Errorf("PowerCycle failed for DPU %s/%s: %v", request.DPUNamespace, request.DPUName, err) + _ = resp.WriteError(http.StatusInternalServerError, err) + return + } + default: + _ = resp.WriteError(http.StatusBadRequest, fmt.Errorf("unsupported reboot method: %q", request.RebootMethod)) + return + } + + klog.Infof("Successfully triggered reboot (%s) for DPU %s/%s", request.RebootMethod, request.DPUNamespace, request.DPUName) + resp.WriteHeader(http.StatusOK) +} + func (s *InstallationService) UpdateStatus(req *restful.Request, resp *restful.Response) { var request types.UpdateStatusRequest if err := req.ReadEntity(&request); err != nil { diff --git a/internal/provisioning/hostagent/service/installation_service_test.go b/internal/provisioning/hostagent/service/installation_service_test.go index 105b658d7..811bb80bf 100644 --- a/internal/provisioning/hostagent/service/installation_service_test.go +++ b/internal/provisioning/hostagent/service/installation_service_test.go @@ -85,7 +85,7 @@ var _ = Describe("InstallationService", func() { testNS = &corev1.Namespace{ObjectMeta: metav1.ObjectMeta{GenerateName: "installation-service-testns-"}} Expect(k8sClient.Create(ctx, testNS)).To(Succeed()) - installationService = NewInstallationService(k8sClient, nil) + installationService = NewInstallationService(k8sClient, nil, nil) Expect(installationService.Start(false)).To(Succeed()) // Start() runs the server in a goroutine; wait until it is listening to avoid connection refused. Eventually(func() error { diff --git a/internal/provisioning/hostagent/service/types/types.go b/internal/provisioning/hostagent/service/types/types.go index 9cf0e47de..3f7810ffe 100644 --- a/internal/provisioning/hostagent/service/types/types.go +++ b/internal/provisioning/hostagent/service/types/types.go @@ -30,4 +30,12 @@ type UpdateStatusRequest struct { type ConfigureHostVFsRequest struct { DPUName string `json:"dpuName"` DPUNamespace string `json:"dpuNamespace"` + VFCount int `json:"vfCount"` +} + +type TriggerRebootRequest struct { + DPUName string `json:"dpuName"` + DPUNamespace string `json:"dpuNamespace"` + DPUUID string `json:"dpuUID"` + RebootMethod provisioningv1.RebootMethodType `json:"rebootMethod"` } diff --git a/internal/provisioning/hostagent/util/netconfig/backend.go b/internal/provisioning/hostagent/util/netconfig/backend.go index 78745666b..70302d65d 100644 --- a/internal/provisioning/hostagent/util/netconfig/backend.go +++ b/internal/provisioning/hostagent/util/netconfig/backend.go @@ -47,6 +47,11 @@ type Backend interface { // IsDHCPConfigured checks if DHCP is enabled for an interface. IsDHCPConfigured(interfaceName string) (bool, error) + + // EnsureVFsUnmanaged ensures that VF interfaces will not be managed by the + // network configuration backend. For NetworkManager this writes a udev rule; + // other backends may no-op. + EnsureVFsUnmanaged() error } // ConfigureNetwork orchestrates PF interface and bridge MTU configuration diff --git a/internal/provisioning/hostagent/util/netconfig/nm_backend.go b/internal/provisioning/hostagent/util/netconfig/nm_backend.go index 4275bc0e4..a93359e26 100644 --- a/internal/provisioning/hostagent/util/netconfig/nm_backend.go +++ b/internal/provisioning/hostagent/util/netconfig/nm_backend.go @@ -53,6 +53,8 @@ var ( getInterfaceNameFunc = func(pciAddress string, portNumber int) (string, error) { return hostutil.NewPCIHelper(pciAddress).PF(portNumber).InterfaceName() } + isVFFunc = hostutil.IsVF + setLinkMTUFunc = hostutil.SetLinkMTU ) // NetworkManagerBackend implements Backend using NetworkManager via D-Bus. @@ -77,6 +79,10 @@ func (n *NetworkManagerBackend) ResetPendingChanges() { n.modifiedConnPaths = nil } +func (n *NetworkManagerBackend) EnsureVFsUnmanaged() error { + return ensureNMUnmanagedUdevRule() +} + // ConfigurePFInterfaces configures physical function network interfaces via NM D-Bus. func (n *NetworkManagerBackend) ConfigurePFInterfaces(pciAddress string, portConfigs []hostutil.PortConfig) (bool, error) { needsApply := false @@ -108,7 +114,7 @@ func (n *NetworkManagerBackend) configureSinglePF(pciAddress string, portConfig updateSettings := make(ConnectionSettings) - if err := n.collectMTUDiff(interfaceName, portConfig.MTU, updateSettings); err != nil { + if err := n.collectMTUDiff(connPath, interfaceName, portConfig.MTU, updateSettings); err != nil { return false, err } if err := n.collectDHCPDiff(interfaceName, portConfig.DHCP, updateSettings); err != nil { @@ -126,10 +132,19 @@ func (n *NetworkManagerBackend) configureSinglePF(pciAddress string, portConfig return true, nil } -func (n *NetworkManagerBackend) collectMTUDiff(interfaceName string, desiredMTU *int32, out ConnectionSettings) error { +func (n *NetworkManagerBackend) collectMTUDiff(connPath ConnectionPath, interfaceName string, desiredMTU *int32, out ConnectionSettings) error { if desiredMTU == nil { return nil } + + // Check the NM profile MTU first. If the profile already has the desired + // value, skip — re-activating the connection would bounce the interface + // and temporarily reset the link MTU to the driver default. + if profileMTU, err := n.getProfileMTU(connPath); err == nil && profileMTU == uint32(*desiredMTU) { + klog.V(3).Infof("%s NM profile MTU already %d, skipping", interfaceName, profileMTU) + return nil + } + currentMTU, err := getCurrentMTUFunc(interfaceName) if err != nil { return fmt.Errorf("failed to get current MTU for %s: %w", interfaceName, err) @@ -144,6 +159,26 @@ func (n *NetworkManagerBackend) collectMTUDiff(interfaceName string, desiredMTU return nil } +func (n *NetworkManagerBackend) getProfileMTU(connPath ConnectionPath) (uint32, error) { + settings, err := n.client.GetConnectionSettings(connPath) + if err != nil { + return 0, err + } + ethSection, ok := settings[nmSectionEthernet] + if !ok { + return 0, fmt.Errorf("no %s section", nmSectionEthernet) + } + mtuVariant, ok := ethSection["mtu"] + if !ok { + return 0, fmt.Errorf("no mtu property") + } + mtu, ok := mtuVariant.Value().(uint32) + if !ok { + return 0, fmt.Errorf("mtu is not uint32") + } + return mtu, nil +} + func (n *NetworkManagerBackend) collectDHCPDiff(interfaceName string, desiredDHCP *bool, out ConnectionSettings) error { if desiredDHCP == nil { return nil @@ -229,6 +264,14 @@ func (n *NetworkManagerBackend) configureBridgeMembersMTU(bridgeName string, mtu klog.Infof("Bridge member %s MTU mismatch (current=%d, desired=%d)", memberName, currentMTU, mtu) + if isVFFunc(memberName) { + klog.Infof("Member %s is a VF, setting MTU via netlink", memberName) + if err := setLinkMTUFunc(memberName, mtu); err != nil { + return false, fmt.Errorf("failed to set MTU for VF %s via netlink: %w", memberName, err) + } + continue + } + connPath, err := n.getOrCreateConnectionForInterface(memberName, nmConnTypeEthernet) if err != nil { return false, fmt.Errorf("failed to get/create connection for member %s: %w", memberName, err) @@ -239,8 +282,9 @@ func (n *NetworkManagerBackend) configureBridgeMembersMTU(bridgeName string, mtu "mtu": dbus.MakeVariant(uint32(mtu)), }, "connection": map[string]dbus.Variant{ - "master": dbus.MakeVariant(bridgeName), - "slave-type": dbus.MakeVariant("bridge"), + "master": dbus.MakeVariant(bridgeName), + "slave-type": dbus.MakeVariant("bridge"), + nmPropInterfaceName: dbus.MakeVariant(memberName), }, } if err := n.mergeAndUpdateConnection(connPath, settings); err != nil { diff --git a/internal/provisioning/hostagent/util/netconfig/nm_backend_test.go b/internal/provisioning/hostagent/util/netconfig/nm_backend_test.go index 11c601a31..c73550e40 100644 --- a/internal/provisioning/hostagent/util/netconfig/nm_backend_test.go +++ b/internal/provisioning/hostagent/util/netconfig/nm_backend_test.go @@ -35,6 +35,8 @@ var _ = Describe("NetworkManagerBackend", func() { origGetCurrentMTU func(string) (int, error) origGetBridgeMembers func(string) ([]string, error) origGetIfaceName func(string, int) (string, error) + origIsVF func(string) bool + origSetLinkMTU func(string, int) error ) BeforeEach(func() { @@ -44,12 +46,18 @@ var _ = Describe("NetworkManagerBackend", func() { origGetCurrentMTU = getCurrentMTUFunc origGetBridgeMembers = getBridgeMembersFunc origGetIfaceName = getInterfaceNameFunc + origIsVF = isVFFunc + origSetLinkMTU = setLinkMTUFunc + + isVFFunc = func(string) bool { return false } }) AfterEach(func() { getCurrentMTUFunc = origGetCurrentMTU getBridgeMembersFunc = origGetBridgeMembers getInterfaceNameFunc = origGetIfaceName + isVFFunc = origIsVF + setLinkMTUFunc = origSetLinkMTU }) It("should report its name", func() { @@ -240,6 +248,22 @@ var _ = Describe("NetworkManagerBackend", func() { Expect(backend.modifiedConnPaths).To(BeEmpty()) }) + It("should skip when NM profile MTU already matches even if kernel MTU differs", func() { + getCurrentMTUFunc = func(name string) (int, error) { return 1500, nil } + mock.addTestConnection("/conn/eth0", ConnectionSettings{ + "connection": {"id": dbus.MakeVariant("eth0"), "interface-name": dbus.MakeVariant("eth0")}, + "802-3-ethernet": {"mtu": dbus.MakeVariant(uint32(9000))}, + }) + + mtu := int32(9000) + needsApply, err := backend.ConfigurePFInterfaces("0000:4d:00", []hostutil.PortConfig{ + {PortNumber: 0, MTU: &mtu}, + }) + Expect(err).NotTo(HaveOccurred()) + Expect(needsApply).To(BeFalse(), "should not re-activate when NM profile already has correct MTU") + Expect(backend.modifiedConnPaths).To(BeEmpty()) + }) + It("should update DHCP when current differs", func() { getCurrentMTUFunc = func(name string) (int, error) { return 1500, nil } mock.addTestConnection("/conn/eth0", ConnectionSettings{ @@ -303,6 +327,7 @@ var _ = Describe("NetworkManagerBackend", func() { Expect(memberUpdated["802-3-ethernet"]["mtu"].Value()).To(Equal(uint32(9000))) Expect(memberUpdated["connection"]["master"].Value()).To(Equal("br-dpu")) Expect(memberUpdated["connection"]["slave-type"].Value()).To(Equal("bridge")) + Expect(memberUpdated["connection"]["interface-name"].Value()).To(Equal("enp1s0f0")) }) It("should skip when MTUs already match", func() { @@ -318,6 +343,37 @@ var _ = Describe("NetworkManagerBackend", func() { Expect(needsApply).To(BeFalse()) }) + It("should set VF member MTU via netlink instead of NM", func() { + getCurrentMTUFunc = func(name string) (int, error) { return 1500, nil } + getBridgeMembersFunc = func(name string) ([]string, error) { return []string{"ens7f0v0"}, nil } + isVFFunc = func(name string) bool { return name == "ens7f0v0" } + + var netlinkMTUCalls []struct { + name string + mtu int + } + setLinkMTUFunc = func(name string, mtu int) error { + netlinkMTUCalls = append(netlinkMTUCalls, struct { + name string + mtu int + }{name, mtu}) + return nil + } + + mock.addTestConnection("/conn/br", ConnectionSettings{ + "connection": {"id": dbus.MakeVariant("br-dpu"), "interface-name": dbus.MakeVariant("br-dpu")}, + }) + + needsApply, err := backend.ConfigureBridgeMTU("br-dpu", 9000) + Expect(err).NotTo(HaveOccurred()) + Expect(needsApply).To(BeTrue()) + Expect(netlinkMTUCalls).To(HaveLen(1)) + Expect(netlinkMTUCalls[0].name).To(Equal("ens7f0v0")) + Expect(netlinkMTUCalls[0].mtu).To(Equal(9000)) + Expect(mock.updatedMap).NotTo(HaveKey(ConnectionPath("/conn/vf")), + "VF should not be updated via NM") + }) + It("should propagate GetBridgeMembers errors", func() { getCurrentMTUFunc = func(name string) (int, error) { return 9000, nil } getBridgeMembersFunc = func(name string) ([]string, error) { diff --git a/internal/provisioning/hostagent/util/netconfig/nm_udev.go b/internal/provisioning/hostagent/util/netconfig/nm_udev.go new file mode 100644 index 000000000..293fbcee7 --- /dev/null +++ b/internal/provisioning/hostagent/util/netconfig/nm_udev.go @@ -0,0 +1,98 @@ +/* +Copyright 2026 NVIDIA + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package netconfig + +import ( + "fmt" + "os" + "os/exec" + "path/filepath" + + "k8s.io/klog/v2" +) + +const ( + nmUnmanagedRulesContent = `ACTION=="add|change|move", ATTRS{device}=="0x101e", ENV{NM_UNMANAGED}="1" +` +) + +// nmUnmanagedRulesPath is the file path for the udev rule. Variable for testability. +var nmUnmanagedRulesPath = "/etc/udev/rules.d/10-nm-unmanaged.rules" + +// udevRunner abstracts command execution for testability. +var udevRunner = func(name string, args ...string) ([]byte, error) { + return exec.Command(name, args...).CombinedOutput() +} + +// udevRulesApplied tracks whether udev reload/trigger has succeeded at least +// once since the process started. This prevents skipping the reload/trigger +// when the rule file already exists but a previous udevadm invocation failed. +var udevRulesApplied bool + +// ensureNMUnmanagedUdevRule writes a udev rule that prevents NetworkManager +// from managing VF interfaces (PCI device ID 0x101e) and reloads/triggers +// udev to apply the rule to both new and already-existing devices. +func ensureNMUnmanagedUdevRule() error { + written, err := writeUdevRuleFile() + if err != nil { + return fmt.Errorf("failed to write udev rule file: %w", err) + } + if !written && udevRulesApplied { + return nil + } + + if err := reloadAndTriggerUdev(); err != nil { + return fmt.Errorf("failed to reload/trigger udev rules: %w", err) + } + udevRulesApplied = true + + return nil +} + +func writeUdevRuleFile() (bool, error) { + dir := filepath.Dir(nmUnmanagedRulesPath) + if err := os.MkdirAll(dir, 0755); err != nil { + return false, fmt.Errorf("failed to create directory %s: %w", dir, err) + } + + existing, err := os.ReadFile(nmUnmanagedRulesPath) + if err == nil && string(existing) == nmUnmanagedRulesContent { + klog.V(3).Infof("Udev rule %s already up-to-date", nmUnmanagedRulesPath) + return false, nil + } + + if err := os.WriteFile(nmUnmanagedRulesPath, []byte(nmUnmanagedRulesContent), 0644); err != nil { + return false, fmt.Errorf("failed to write file %s: %w", nmUnmanagedRulesPath, err) + } + klog.Infof("Wrote udev rule to disable NM management of VFs: %s", nmUnmanagedRulesPath) + return true, nil +} + +func reloadAndTriggerUdev() error { + output, err := udevRunner("udevadm", "control", "--reload-rules") + if err != nil { + return fmt.Errorf("udevadm control --reload-rules failed: %w, output: %s", err, string(output)) + } + + output, err = udevRunner("udevadm", "trigger", "--subsystem-match=net") + if err != nil { + return fmt.Errorf("udevadm trigger --subsystem-match=net failed: %w, output: %s", err, string(output)) + } + + klog.V(3).Infof("Reloaded udev rules and triggered net subsystem") + return nil +} diff --git a/internal/provisioning/hostagent/util/netconfig/nm_udev_test.go b/internal/provisioning/hostagent/util/netconfig/nm_udev_test.go new file mode 100644 index 000000000..e15974a9b --- /dev/null +++ b/internal/provisioning/hostagent/util/netconfig/nm_udev_test.go @@ -0,0 +1,170 @@ +/* +Copyright 2026 NVIDIA + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package netconfig + +import ( + "fmt" + "os" + "path/filepath" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("ensureNMUnmanagedUdevRule", func() { + var ( + origPath string + origRunner func(string, ...string) ([]byte, error) + origRulesApplied bool + tempDir string + commands [][]string + ) + + BeforeEach(func() { + var err error + tempDir, err = os.MkdirTemp("", "udev-test-*") + Expect(err).NotTo(HaveOccurred()) + + origPath = nmUnmanagedRulesPath + origRunner = udevRunner + origRulesApplied = udevRulesApplied + udevRulesApplied = false + + commands = nil + udevRunner = func(name string, args ...string) ([]byte, error) { + commands = append(commands, append([]string{name}, args...)) + return nil, nil + } + }) + + AfterEach(func() { + nmUnmanagedRulesPath = origPath + udevRunner = origRunner + udevRulesApplied = origRulesApplied + os.RemoveAll(tempDir) + }) + + setRulesPath := func() string { + p := filepath.Join(tempDir, "10-nm-unmanaged.rules") + nmUnmanagedRulesPath = p + return p + } + + It("should write the udev rule file and reload rules", func() { + rulesFile := setRulesPath() + + err := ensureNMUnmanagedUdevRule() + Expect(err).NotTo(HaveOccurred()) + + content, err := os.ReadFile(rulesFile) + Expect(err).NotTo(HaveOccurred()) + Expect(string(content)).To(Equal(nmUnmanagedRulesContent)) + + Expect(commands).To(HaveLen(2)) + Expect(commands[0]).To(Equal([]string{"udevadm", "control", "--reload-rules"})) + Expect(commands[1]).To(Equal([]string{"udevadm", "trigger", "--subsystem-match=net"})) + }) + + It("should skip reload/trigger when file is up-to-date and rules were previously applied", func() { + rulesFile := setRulesPath() + udevRulesApplied = true + + err := os.MkdirAll(filepath.Dir(rulesFile), 0755) + Expect(err).NotTo(HaveOccurred()) + err = os.WriteFile(rulesFile, []byte(nmUnmanagedRulesContent), 0644) + Expect(err).NotTo(HaveOccurred()) + + err = ensureNMUnmanagedUdevRule() + Expect(err).NotTo(HaveOccurred()) + + Expect(commands).To(BeEmpty()) + }) + + It("should retry reload/trigger when file is up-to-date but previous udevadm failed", func() { + rulesFile := setRulesPath() + + err := os.MkdirAll(filepath.Dir(rulesFile), 0755) + Expect(err).NotTo(HaveOccurred()) + err = os.WriteFile(rulesFile, []byte(nmUnmanagedRulesContent), 0644) + Expect(err).NotTo(HaveOccurred()) + + err = ensureNMUnmanagedUdevRule() + Expect(err).NotTo(HaveOccurred()) + + Expect(commands).To(HaveLen(2)) + Expect(commands[0]).To(Equal([]string{"udevadm", "control", "--reload-rules"})) + Expect(commands[1]).To(Equal([]string{"udevadm", "trigger", "--subsystem-match=net"})) + Expect(udevRulesApplied).To(BeTrue()) + }) + + It("should overwrite if content differs", func() { + rulesFile := setRulesPath() + + err := os.MkdirAll(filepath.Dir(rulesFile), 0755) + Expect(err).NotTo(HaveOccurred()) + err = os.WriteFile(rulesFile, []byte("old content"), 0644) + Expect(err).NotTo(HaveOccurred()) + + err = ensureNMUnmanagedUdevRule() + Expect(err).NotTo(HaveOccurred()) + + content, err := os.ReadFile(rulesFile) + Expect(err).NotTo(HaveOccurred()) + Expect(string(content)).To(Equal(nmUnmanagedRulesContent)) + }) + + It("should create parent directories if they don't exist", func() { + nmUnmanagedRulesPath = filepath.Join(tempDir, "subdir", "rules.d", "10-nm-unmanaged.rules") + + err := ensureNMUnmanagedUdevRule() + Expect(err).NotTo(HaveOccurred()) + + content, err := os.ReadFile(nmUnmanagedRulesPath) + Expect(err).NotTo(HaveOccurred()) + Expect(string(content)).To(Equal(nmUnmanagedRulesContent)) + }) + + It("should return error if udevadm reload fails", func() { + setRulesPath() + + udevRunner = func(name string, args ...string) ([]byte, error) { + if len(args) > 0 && args[0] == "control" { + return []byte("reload failed"), fmt.Errorf("exit status 1") + } + return nil, nil + } + + err := ensureNMUnmanagedUdevRule() + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("udevadm control --reload-rules failed")) + }) + + It("should return error if udevadm trigger fails", func() { + setRulesPath() + + udevRunner = func(name string, args ...string) ([]byte, error) { + if len(args) > 0 && args[0] == "trigger" { + return []byte("trigger failed"), fmt.Errorf("exit status 1") + } + return nil, nil + } + + err := ensureNMUnmanagedUdevRule() + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("udevadm trigger --subsystem-match=net failed")) + }) +}) diff --git a/internal/provisioning/hostagent/util/netconfig/systemd_networkd.go b/internal/provisioning/hostagent/util/netconfig/systemd_networkd.go index c0d3bd01e..7f1c7f764 100644 --- a/internal/provisioning/hostagent/util/netconfig/systemd_networkd.go +++ b/internal/provisioning/hostagent/util/netconfig/systemd_networkd.go @@ -59,3 +59,7 @@ func (s *SystemdNetworkdBackend) ApplyConfiguration() error { func (s *SystemdNetworkdBackend) IsDHCPConfigured(interfaceName string) (bool, error) { return hostutil.IsDHCPConfigured(interfaceName) } + +func (s *SystemdNetworkdBackend) EnsureVFsUnmanaged() error { + return nil +} diff --git a/internal/provisioning/hostagent/util/network.go b/internal/provisioning/hostagent/util/network.go index 14eae60be..0d7bf45d8 100644 --- a/internal/provisioning/hostagent/util/network.go +++ b/internal/provisioning/hostagent/util/network.go @@ -498,3 +498,9 @@ func writeBridgeMTUConfig(controlPlaneMTU int) error { return writeNetplanFile(BridgeMTUNetplanFile, &config) } + +// IsVF returns true if the given network interface is a PCI Virtual Function. +func IsVF(interfaceName string) bool { + _, err := os.Stat(filepath.Join("/sys/class/net", interfaceName, "device", "physfn")) + return err == nil +}