diff --git a/pkg/devices/rdma.go b/pkg/devices/rdma.go index 60d64b429..f60036d22 100644 --- a/pkg/devices/rdma.go +++ b/pkg/devices/rdma.go @@ -18,6 +18,7 @@ package devices import ( + "github.com/golang/glog" pluginapi "k8s.io/kubelet/pkg/apis/deviceplugin/v1beta1" "github.com/k8snetworkplumbingwg/sriov-network-device-plugin/pkg/types" @@ -25,15 +26,47 @@ import ( ) type rdmaSpec struct { - isSupportRdma bool - deviceSpec []*pluginapi.DeviceSpec + deviceID string + deviceType types.DeviceType } -func newRdmaSpec(rdmaResources []string) types.RdmaSpec { +// NewRdmaSpec returns the RdmaSpec +func NewRdmaSpec(dt types.DeviceType, id string) types.RdmaSpec { + if dt == types.AcceleratorType { + return nil + } + return &rdmaSpec{deviceID: id, deviceType: dt} +} + +func (r *rdmaSpec) IsRdma() bool { + if len(r.getRdmaResources()) > 0 { + return true + } + // Checking for netlink param for exclusive RDMA use case + rdma, err := utils.HasRdmaParam(r.deviceID) + if err != nil { + glog.Infof("HasRdmaParam(): unable to get Netlink RDMA param for device %s : %q", r.deviceID, err) + return false + } + return rdma +} + +func (r *rdmaSpec) getRdmaResources() []string { + //nolint: exhaustive + switch r.deviceType { + case types.NetDeviceType: + return utils.GetRdmaProvider().GetRdmaDevicesForPcidev(r.deviceID) + case types.AuxNetDeviceType: + return utils.GetRdmaProvider().GetRdmaDevicesForAuxdev(r.deviceID) + default: + return make([]string, 0) + } +} + +func (r *rdmaSpec) GetRdmaDeviceSpec() []*pluginapi.DeviceSpec { + rdmaResources := r.getRdmaResources() deviceSpec := make([]*pluginapi.DeviceSpec, 0) - isSupportRdma := false if len(rdmaResources) > 0 { - isSupportRdma = true for _, res := range rdmaResources { resRdmaDevices := utils.GetRdmaProvider().GetRdmaCharDevices(res) for _, rdmaDevice := range resRdmaDevices { @@ -45,26 +78,5 @@ func newRdmaSpec(rdmaResources []string) types.RdmaSpec { } } } - - return &rdmaSpec{isSupportRdma: isSupportRdma, deviceSpec: deviceSpec} -} - -// NewRdmaSpec returns the RdmaSpec for PCI address -func NewRdmaSpec(pciAddr string) types.RdmaSpec { - rdmaResources := utils.GetRdmaProvider().GetRdmaDevicesForPcidev(pciAddr) - return newRdmaSpec(rdmaResources) -} - -// NewAuxRdmaSpec returns the RdmaSpec for auxiliary device ID -func NewAuxRdmaSpec(deviceID string) types.RdmaSpec { - rdmaResources := utils.GetRdmaProvider().GetRdmaDevicesForAuxdev(deviceID) - return newRdmaSpec(rdmaResources) -} - -func (r *rdmaSpec) IsRdma() bool { - return r.isSupportRdma -} - -func (r *rdmaSpec) GetRdmaDeviceSpec() []*pluginapi.DeviceSpec { - return r.deviceSpec + return deviceSpec } diff --git a/pkg/devices/rdma_test.go b/pkg/devices/rdma_test.go index cd321b6d9..07c6ed122 100644 --- a/pkg/devices/rdma_test.go +++ b/pkg/devices/rdma_test.go @@ -23,6 +23,7 @@ import ( pluginapi "k8s.io/kubelet/pkg/apis/deviceplugin/v1beta1" "github.com/k8snetworkplumbingwg/sriov-network-device-plugin/pkg/devices" + "github.com/k8snetworkplumbingwg/sriov-network-device-plugin/pkg/types" "github.com/k8snetworkplumbingwg/sriov-network-device-plugin/pkg/utils" "github.com/k8snetworkplumbingwg/sriov-network-device-plugin/pkg/utils/mocks" ) @@ -31,16 +32,32 @@ var _ = Describe("RdmaSpec", func() { Describe("creating new RdmaSpec", func() { t := GinkgoT() Context("successfully", func() { - It("without device specs", func() { + It("without device specs, without netlik enable_rdma param", func() { + mockProvider := &mocks.NetlinkProvider{} + mockProvider.On("HasRdmaParam", "0000:00:00.0").Return(false, nil) + utils.SetNetlinkProviderInst(mockProvider) fakeRdmaProvider := mocks.RdmaProvider{} fakeRdmaProvider.On("GetRdmaDevicesForPcidev", "0000:00:00.0").Return([]string{}) utils.SetRdmaProviderInst(&fakeRdmaProvider) - spec := devices.NewRdmaSpec("0000:00:00.0") + spec := devices.NewRdmaSpec(types.NetDeviceType, "0000:00:00.0") Expect(spec.IsRdma()).To(BeFalse()) Expect(spec.GetRdmaDeviceSpec()).To(HaveLen(0)) fakeRdmaProvider.AssertExpectations(t) }) + It("without device specs, with netlik enable_rdma param", func() { + mockProvider := &mocks.NetlinkProvider{} + mockProvider.On("HasRdmaParam", "0000:00:00.0").Return(true, nil) + utils.SetNetlinkProviderInst(mockProvider) + fakeRdmaProvider := mocks.RdmaProvider{} + fakeRdmaProvider.On("GetRdmaDevicesForPcidev", "0000:00:00.0").Return([]string{}) + utils.SetRdmaProviderInst(&fakeRdmaProvider) + spec := devices.NewRdmaSpec(types.NetDeviceType, "0000:00:00.0") + + Expect(spec.IsRdma()).To(BeTrue()) + Expect(spec.GetRdmaDeviceSpec()).To(HaveLen(0)) + fakeRdmaProvider.AssertExpectations(t) + }) It("with device specs", func() { fakeRdmaProvider := mocks.RdmaProvider{} fakeRdmaProvider.On("GetRdmaDevicesForPcidev", "0000:00:00.0"). @@ -50,7 +67,7 @@ var _ = Describe("RdmaSpec", func() { "/dev/infiniband/uverbs0", "/dev/infiniband/rdma_cm", }).On("GetRdmaCharDevices", "fake_1").Return([]string{"/dev/infiniband/rdma_cm"}) utils.SetRdmaProviderInst(&fakeRdmaProvider) - spec := devices.NewRdmaSpec("0000:00:00.0") + spec := devices.NewRdmaSpec(types.NetDeviceType, "0000:00:00.0") Expect(spec.IsRdma()).To(BeTrue()) Expect(spec.GetRdmaDeviceSpec()).To(Equal([]*pluginapi.DeviceSpec{ diff --git a/pkg/factory/factory.go b/pkg/factory/factory.go index 15a287658..3b3ad938b 100644 --- a/pkg/factory/factory.go +++ b/pkg/factory/factory.go @@ -163,15 +163,7 @@ func (rf *resourceFactory) GetResourcePool(rc *types.ResourceConfig, filteredDev } func (rf *resourceFactory) GetRdmaSpec(dt types.DeviceType, deviceID string) types.RdmaSpec { - //nolint: exhaustive - switch dt { - case types.NetDeviceType: - return devices.NewRdmaSpec(deviceID) - case types.AuxNetDeviceType: - return devices.NewAuxRdmaSpec(deviceID) - default: - return nil - } + return devices.NewRdmaSpec(dt, deviceID) } func (rf *resourceFactory) GetVdpaDevice(pciAddr string) types.VdpaDevice { diff --git a/pkg/factory/factory_test.go b/pkg/factory/factory_test.go index 8d5251627..15b06fcdc 100644 --- a/pkg/factory/factory_test.go +++ b/pkg/factory/factory_test.go @@ -25,10 +25,12 @@ import ( "github.com/k8snetworkplumbingwg/sriov-network-device-plugin/pkg/types" "github.com/k8snetworkplumbingwg/sriov-network-device-plugin/pkg/types/mocks" "github.com/k8snetworkplumbingwg/sriov-network-device-plugin/pkg/utils" + utilmocks "github.com/k8snetworkplumbingwg/sriov-network-device-plugin/pkg/utils/mocks" . "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo/extensions/table" . "github.com/onsi/gomega" + "github.com/stretchr/testify/mock" pluginapi "k8s.io/kubelet/pkg/apis/deviceplugin/v1beta1" ) @@ -606,6 +608,9 @@ var _ = Describe("Factory", func() { ) Describe("getting rdma spec", func() { Context("check c rdma spec", func() { + mockProvider := &utilmocks.NetlinkProvider{} + mockProvider.On("HasRdmaParam", mock.AnythingOfType("string")).Return(false, nil) + utils.SetNetlinkProviderInst(mockProvider) f := factory.NewResourceFactory("fake", "fake", true, false) rs1 := f.GetRdmaSpec(types.NetDeviceType, "0000:00:00.1") rs2 := f.GetRdmaSpec(types.AcceleratorType, "0000:00:00.2") diff --git a/pkg/utils/mocks/NetlinkProvider.go b/pkg/utils/mocks/NetlinkProvider.go index d9f4c2caf..55e63e157 100644 --- a/pkg/utils/mocks/NetlinkProvider.go +++ b/pkg/utils/mocks/NetlinkProvider.go @@ -132,6 +132,34 @@ func (_m *NetlinkProvider) GetLinkAttrs(ifName string) (*netlink.LinkAttrs, erro return r0, r1 } +// HasRdmaParam provides a mock function with given fields: pciAddr +func (_m *NetlinkProvider) HasRdmaParam(pciAddr string) (bool, error) { + ret := _m.Called(pciAddr) + + if len(ret) == 0 { + panic("no return value specified for HasRdmaParam") + } + + var r0 bool + var r1 error + if rf, ok := ret.Get(0).(func(string) (bool, error)); ok { + return rf(pciAddr) + } + if rf, ok := ret.Get(0).(func(string) bool); ok { + r0 = rf(pciAddr) + } else { + r0 = ret.Get(0).(bool) + } + + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(pciAddr) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // NewNetlinkProvider creates a new instance of NetlinkProvider. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. func NewNetlinkProvider(t interface { diff --git a/pkg/utils/netlink_provider.go b/pkg/utils/netlink_provider.go index a87074d7b..a6561dcb3 100644 --- a/pkg/utils/netlink_provider.go +++ b/pkg/utils/netlink_provider.go @@ -31,6 +31,8 @@ type NetlinkProvider interface { GetIPv4RouteList(ifName string) ([]nl.Route, error) // DevlinkGetDeviceInfoByNameAsMap returns devlink info for selected device as a map GetDevlinkGetDeviceInfoByNameAsMap(bus, device string) (map[string]string, error) + // HasRdmaParam returns true if PCI device has "enable_rdma" param + HasRdmaParam(pciAddr string) (bool, error) } type defaultNetlinkProvider struct { @@ -48,6 +50,19 @@ func GetNetlinkProvider() NetlinkProvider { return netlinkProvider } +// HasRdmaParam returns true if PCI device has "enable_rdma" param +// equivalent to "devlink dev param show pci/0000:d8:01.1 name enable_rdma" +func (defaultNetlinkProvider) HasRdmaParam(pciAddr string) (bool, error) { + param, err := nl.DevlinkGetDeviceParamByName("pci", pciAddr, "enable_rdma") + if err != nil { + return false, fmt.Errorf("error getting enable_rdma attribute for pci device %s %v", pciAddr, err) + } + if len(param.Values) == 0 || param.Values[0].Data == nil { + return false, nil + } + return true, nil +} + // GetLinkAttrs returns a net device's link attributes. func (defaultNetlinkProvider) GetLinkAttrs(ifName string) (*nl.LinkAttrs, error) { link, err := nl.LinkByName(ifName) diff --git a/pkg/utils/utils.go b/pkg/utils/utils.go index 262525035..db836f43b 100644 --- a/pkg/utils/utils.go +++ b/pkg/utils/utils.go @@ -474,6 +474,15 @@ func GetPfEswitchMode(pciAddr string) (string, error) { return devLinkDeviceAttrs.Mode, nil } +// HasRdmaParam returns true if PCI device has "enable_rdma" param +func HasRdmaParam(pciAddr string) (bool, error) { + rdma, err := GetNetlinkProvider().HasRdmaParam(pciAddr) + if err != nil { + return false, err + } + return rdma, nil +} + // HasDefaultRoute returns true if PCI network device is default route interface func HasDefaultRoute(pciAddr string) (bool, error) { // Get net interface name