diff --git a/cmd/device-plugin/main.go b/cmd/device-plugin/main.go index 3dc1f1b..9b8cab2 100644 --- a/cmd/device-plugin/main.go +++ b/cmd/device-plugin/main.go @@ -42,10 +42,13 @@ func main() { initNvidiaSmi() initPreloaders() - devicePlugin := deviceplugin.NewDevicePlugin(topology, kubeClient) - if err = devicePlugin.Serve(); err != nil { - log.Printf("Failed to serve device plugin: %s\n", err) - os.Exit(1) + devicePlugins := deviceplugin.NewDevicePlugins(topology, kubeClient) + for _, devicePlugin := range devicePlugins { + log.Printf("Starting device plugin for %s\n", devicePlugin.Name()) + if err = devicePlugin.Serve(); err != nil { + log.Printf("Failed to serve device plugin: %s\n", err) + os.Exit(1) + } } sig := make(chan os.Signal, 1) diff --git a/internal/common/topology/const.go b/internal/common/topology/const.go index dd7120d..cdc645f 100644 --- a/internal/common/topology/const.go +++ b/internal/common/topology/const.go @@ -1,5 +1,5 @@ package topology const ( - cmTopologyKey = "topology.yml" + CmTopologyKey = "topology.yml" ) diff --git a/internal/common/topology/kubernetes.go b/internal/common/topology/kubernetes.go index 1246fe5..39a0662 100644 --- a/internal/common/topology/kubernetes.go +++ b/internal/common/topology/kubernetes.go @@ -78,7 +78,7 @@ func GetClusterTopologyFromCM(kubeclient kubernetes.Interface) (*ClusterTopology func FromClusterTopologyCM(cm *corev1.ConfigMap) (*ClusterTopology, error) { var clusterTopology ClusterTopology - err := yaml.Unmarshal([]byte(cm.Data[cmTopologyKey]), &clusterTopology) + err := yaml.Unmarshal([]byte(cm.Data[CmTopologyKey]), &clusterTopology) if err != nil { return nil, err } @@ -88,7 +88,7 @@ func FromClusterTopologyCM(cm *corev1.ConfigMap) (*ClusterTopology, error) { func FromNodeTopologyCM(cm *corev1.ConfigMap) (*NodeTopology, error) { var nodeTopology NodeTopology - err := yaml.Unmarshal([]byte(cm.Data[cmTopologyKey]), &nodeTopology) + err := yaml.Unmarshal([]byte(cm.Data[CmTopologyKey]), &nodeTopology) if err != nil { return nil, err } @@ -110,7 +110,7 @@ func ToClusterTopologyCM(clusterTopology *ClusterTopology) (*corev1.ConfigMap, e return nil, err } - cm.Data[cmTopologyKey] = string(topologyData) + cm.Data[CmTopologyKey] = string(topologyData) return cm, nil } @@ -134,7 +134,7 @@ func ToNodeTopologyCM(nodeTopology *NodeTopology, nodeName string) (*corev1.Conf return nil, nil, err } - cm.Data[cmTopologyKey] = string(topologyData) + cm.Data[CmTopologyKey] = string(topologyData) cmApplyConfig = cmApplyConfig.WithData(cm.Data) diff --git a/internal/common/topology/types.go b/internal/common/topology/types.go index c965172..76b92e3 100644 --- a/internal/common/topology/types.go +++ b/internal/common/topology/types.go @@ -14,16 +14,18 @@ type ClusterTopology struct { } type NodePoolTopology struct { - GpuCount int `yaml:"gpuCount"` - GpuMemory int `yaml:"gpuMemory"` - GpuProduct string `yaml:"gpuProduct"` + GpuCount int `yaml:"gpuCount"` + GpuMemory int `yaml:"gpuMemory"` + GpuProduct string `yaml:"gpuProduct"` + OtherDevices []GenericDevice `yaml:"otherDevices"` } type NodeTopology struct { - GpuMemory int `yaml:"gpuMemory"` - GpuProduct string `yaml:"gpuProduct"` - Gpus []GpuDetails `yaml:"gpus"` - MigStrategy string `yaml:"migStrategy"` + GpuMemory int `yaml:"gpuMemory"` + GpuProduct string `yaml:"gpuProduct"` + Gpus []GpuDetails `yaml:"gpus"` + MigStrategy string `yaml:"migStrategy"` + OtherDevices []GenericDevice `yaml:"otherDevices,omitempty"` } type GpuDetails struct { @@ -56,6 +58,11 @@ type Range struct { Max int `yaml:"max"` } +type GenericDevice struct { + Name string `yaml:"name"` + Count int `yaml:"count"` +} + // Errors var ErrNoNodes = fmt.Errorf("no nodes found") var ErrNoNode = fmt.Errorf("node not found") diff --git a/internal/deviceplugin/device_plugin.go b/internal/deviceplugin/device_plugin.go index df71de5..3c3e19d 100644 --- a/internal/deviceplugin/device_plugin.go +++ b/internal/deviceplugin/device_plugin.go @@ -1,34 +1,65 @@ package deviceplugin import ( + "path" + "strings" + "github.com/run-ai/fake-gpu-operator/internal/common/constants" "github.com/run-ai/fake-gpu-operator/internal/common/topology" "github.com/spf13/viper" "k8s.io/client-go/kubernetes" + pluginapi "k8s.io/kubelet/pkg/apis/deviceplugin/v1beta1" ) const ( - resourceName = "nvidia.com/gpu" + nvidiaGPUResourceName = "nvidia.com/gpu" ) type Interface interface { Serve() error + Name() string } -func NewDevicePlugin(topology *topology.NodeTopology, kubeClient kubernetes.Interface) Interface { +func NewDevicePlugins(topology *topology.NodeTopology, kubeClient kubernetes.Interface) []Interface { if topology == nil { panic("topology is nil") } if viper.GetBool(constants.EnvFakeNode) { - return &FakeNodeDevicePlugin{ - kubeClient: kubeClient, - gpuCount: getGpuCount(topology), + otherDevices := make(map[string]int) + for _, genericDevice := range topology.OtherDevices { + otherDevices[genericDevice.Name] = genericDevice.Count } + + return []Interface{&FakeNodeDevicePlugin{ + kubeClient: kubeClient, + gpuCount: getGpuCount(topology), + otherDevices: otherDevices, + }} + } + + devicePlugins := []Interface{ + &RealNodeDevicePlugin{ + devs: createDevices(getGpuCount(topology)), + socket: serverSock, + resourceName: nvidiaGPUResourceName, + }, } - return &RealNodeDevicePlugin{ - devs: createDevices(getGpuCount(topology)), - socket: serverSock, + for _, genericDevice := range topology.OtherDevices { + devicePlugins = append(devicePlugins, &RealNodeDevicePlugin{ + devs: createDevices(genericDevice.Count), + socket: path.Join(pluginapi.DevicePluginPath, normalizeDeviceName(genericDevice.Name)+".sock"), + resourceName: genericDevice.Name, + }) } + + return devicePlugins +} + +func normalizeDeviceName(deviceName string) string { + normalized := strings.ReplaceAll(deviceName, "/", "_") + normalized = strings.ReplaceAll(normalized, ".", "_") + normalized = strings.ReplaceAll(normalized, "-", "_") + return normalized } diff --git a/internal/deviceplugin/device_plugin_test.go b/internal/deviceplugin/device_plugin_test.go new file mode 100644 index 0000000..0ae8142 --- /dev/null +++ b/internal/deviceplugin/device_plugin_test.go @@ -0,0 +1,69 @@ +package deviceplugin + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + "github.com/spf13/viper" + + "k8s.io/client-go/kubernetes/fake" + + "github.com/run-ai/fake-gpu-operator/internal/common/constants" + "github.com/run-ai/fake-gpu-operator/internal/common/topology" +) + +func TestDevicePlugin(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "DevicePlugin Suite") +} + +var _ = Describe("NewDevicePlugins", func() { + Context("When the topology is nil", func() { + It("should panic", func() { + Expect(func() { NewDevicePlugins(nil, nil) }).To(Panic()) + }) + }) + + Context("When the fake node is enabled", Ordered, func() { + BeforeAll(func() { + viper.Set(constants.EnvFakeNode, true) + }) + + AfterAll(func() { + viper.Set(constants.EnvFakeNode, false) + }) + + It("should return a fake node device plugin", func() { + topology := &topology.NodeTopology{} + kubeClient := &fake.Clientset{} + devicePlugins := NewDevicePlugins(topology, kubeClient) + Expect(devicePlugins).To(HaveLen(1)) + Expect(devicePlugins[0]).To(BeAssignableToTypeOf(&FakeNodeDevicePlugin{})) + }) + }) + + Context("With normal node", func() { + It("should return a real node device plugin", func() { + topology := &topology.NodeTopology{} + kubeClient := &fake.Clientset{} + devicePlugins := NewDevicePlugins(topology, kubeClient) + Expect(devicePlugins).To(HaveLen(1)) + Expect(devicePlugins[0]).To(BeAssignableToTypeOf(&RealNodeDevicePlugin{})) + }) + + It("should return a device plugin for each other device", func() { + topology := &topology.NodeTopology{ + OtherDevices: []topology.GenericDevice{ + {Name: "device1", Count: 1}, + {Name: "device2", Count: 2}, + }, + } + kubeClient := &fake.Clientset{} + devicePlugins := NewDevicePlugins(topology, kubeClient) + Expect(devicePlugins).To(HaveLen(3)) + Expect(devicePlugins[0]).To(BeAssignableToTypeOf(&RealNodeDevicePlugin{})) + }) + }) + +}) diff --git a/internal/deviceplugin/fake_node.go b/internal/deviceplugin/fake_node.go index b1c503e..d0f9c74 100644 --- a/internal/deviceplugin/fake_node.go +++ b/internal/deviceplugin/fake_node.go @@ -1,11 +1,14 @@ package deviceplugin import ( + "encoding/json" "fmt" "os" "github.com/run-ai/fake-gpu-operator/internal/common/constants" "golang.org/x/net/context" + v1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/resource" "k8s.io/apimachinery/pkg/types" "k8s.io/client-go/kubernetes" @@ -13,16 +16,41 @@ import ( ) type FakeNodeDevicePlugin struct { - kubeClient kubernetes.Interface - gpuCount int + kubeClient kubernetes.Interface + gpuCount int + otherDevices map[string]int } func (f *FakeNodeDevicePlugin) Serve() error { - patch := fmt.Sprintf(`{"status": {"capacity": {"%s": "%d"}, "allocatable": {"%s": "%d"}}}`, resourceName, f.gpuCount, resourceName, f.gpuCount) - _, err := f.kubeClient.CoreV1().Nodes().Patch(context.TODO(), os.Getenv(constants.EnvNodeName), types.MergePatchType, []byte(patch), metav1.PatchOptions{}, "status") + nodeStatus := v1.NodeStatus{ + Capacity: v1.ResourceList{ + v1.ResourceName(nvidiaGPUResourceName): *resource.NewQuantity(int64(f.gpuCount), resource.DecimalSI), + }, + Allocatable: v1.ResourceList{ + v1.ResourceName(nvidiaGPUResourceName): *resource.NewQuantity(int64(f.gpuCount), resource.DecimalSI), + }, + } + + for deviceName, count := range f.otherDevices { + nodeStatus.Capacity[v1.ResourceName(deviceName)] = *resource.NewQuantity(int64(count), resource.DecimalSI) + nodeStatus.Allocatable[v1.ResourceName(deviceName)] = *resource.NewQuantity(int64(count), resource.DecimalSI) + } + + // Convert the patch struct to JSON + patchBytes, err := json.Marshal(v1.Node{Status: nodeStatus}) + if err != nil { + return fmt.Errorf("failed to marshal patch: %v", err) + } + + // Apply the patch + _, err = f.kubeClient.CoreV1().Nodes().Patch(context.TODO(), os.Getenv(constants.EnvNodeName), types.MergePatchType, patchBytes, metav1.PatchOptions{}, "status") if err != nil { return fmt.Errorf("failed to update node capacity and allocatable: %v", err) } return nil } + +func (f *FakeNodeDevicePlugin) Name() string { + return "FakeNodeDevicePlugin" +} diff --git a/internal/deviceplugin/fake_node_test.go b/internal/deviceplugin/fake_node_test.go new file mode 100644 index 0000000..4ed96f3 --- /dev/null +++ b/internal/deviceplugin/fake_node_test.go @@ -0,0 +1,50 @@ +package deviceplugin + +import ( + "os" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + "golang.org/x/net/context" + + v1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/client-go/kubernetes/fake" +) + +var _ = Describe("FakeNodeDevicePlugin.Serve", func() { + It("should update the node capacity and allocatable", func() { + node := &v1.Node{ + ObjectMeta: metav1.ObjectMeta{ + Name: "node1", + }, + } + os.Setenv("NODE_NAME", "node1") + + fakeClient := fake.NewSimpleClientset(node) + + fakeNodeDevicePlugin := &FakeNodeDevicePlugin{ + kubeClient: fakeClient, + gpuCount: 1, + otherDevices: map[string]int{"device1": 2}, + } + + err := fakeNodeDevicePlugin.Serve() + Expect(err).ToNot(HaveOccurred()) + + updateNode, err := fakeClient.CoreV1().Nodes().Get(context.TODO(), "node1", metav1.GetOptions{}) + Expect(err).ToNot(HaveOccurred()) + Expect(testResourceListCondition(updateNode.Status.Capacity, v1.ResourceName(nvidiaGPUResourceName), 1)).To(BeTrue()) + Expect(testResourceListCondition(updateNode.Status.Allocatable, v1.ResourceName(nvidiaGPUResourceName), 1)).To(BeTrue()) + Expect(testResourceListCondition(updateNode.Status.Capacity, v1.ResourceName("device1"), 2)).To(BeTrue()) + Expect(testResourceListCondition(updateNode.Status.Allocatable, v1.ResourceName("device1"), 2)).To(BeTrue()) + }) +}) + +func testResourceListCondition(resourceList v1.ResourceList, resourceName v1.ResourceName, value int64) bool { + quantity, found := resourceList[resourceName] + if !found { + return false + } + return quantity.Value() == value +} diff --git a/internal/deviceplugin/real_node.go b/internal/deviceplugin/real_node.go index e78ab38..3fc325f 100644 --- a/internal/deviceplugin/real_node.go +++ b/internal/deviceplugin/real_node.go @@ -28,6 +28,8 @@ type RealNodeDevicePlugin struct { stop chan interface{} health chan *pluginapi.Device server *grpc.Server + + resourceName string } func getGpuCount(nodeTopology *topology.NodeTopology) int { @@ -115,7 +117,7 @@ func (m *RealNodeDevicePlugin) Stop() error { return m.cleanup() } -func (m *RealNodeDevicePlugin) Register(kubeletEndpoint, resourceName string) error { +func (m *RealNodeDevicePlugin) Register(kubeletEndpoint string) error { conn, err := dial(kubeletEndpoint, 5*time.Second) if err != nil { return err @@ -126,7 +128,7 @@ func (m *RealNodeDevicePlugin) Register(kubeletEndpoint, resourceName string) er reqt := &pluginapi.RegisterRequest{ Version: pluginapi.Version, Endpoint: path.Base(m.socket), - ResourceName: resourceName, + ResourceName: m.resourceName, } _, err = client.Register(context.Background(), reqt) @@ -202,7 +204,7 @@ func (m *RealNodeDevicePlugin) Serve() error { } log.Println("Starting to serve on", m.socket) - err = m.Register(pluginapi.KubeletSocket, resourceName) + err = m.Register(pluginapi.KubeletSocket) if err != nil { log.Printf("Could not register device plugin: %s", err) stopErr := m.Stop() @@ -215,3 +217,7 @@ func (m *RealNodeDevicePlugin) Serve() error { return nil } + +func (m *RealNodeDevicePlugin) Name() string { + return "RealNodeDevicePlugin-" + m.resourceName +} diff --git a/internal/kwok-gpu-device-plugin/handlers/configmap/handler.go b/internal/kwok-gpu-device-plugin/handlers/configmap/handler.go index 3c342f2..7e922ce 100644 --- a/internal/kwok-gpu-device-plugin/handlers/configmap/handler.go +++ b/internal/kwok-gpu-device-plugin/handlers/configmap/handler.go @@ -2,12 +2,15 @@ package configmap import ( "context" + "encoding/json" "fmt" "log" "github.com/run-ai/fake-gpu-operator/internal/common/constants" "github.com/run-ai/fake-gpu-operator/internal/common/topology" + v1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/resource" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/types" "k8s.io/client-go/kubernetes" @@ -41,16 +44,33 @@ func (p *ConfigMapHandler) HandleAdd(cm *v1.ConfigMap) error { } nodeName := cm.Labels[constants.LabelTopologyCMNodeName] - return p.applyFakeDevicePlugin(len(nodeTopology.Gpus), nodeName) + return p.applyFakeDevicePlugin(nodeTopology, nodeName) } -func (p *ConfigMapHandler) applyFakeDevicePlugin(gpuCount int, nodeName string) error { - patch := fmt.Sprintf( - `{"status": {"capacity": {"%s": "%d"}, "allocatable": {"%s": "%d"}}}`, - constants.GpuResourceName, gpuCount, constants.GpuResourceName, gpuCount, - ) - _, err := p.kubeClient.CoreV1().Nodes().Patch( - context.TODO(), nodeName, types.MergePatchType, []byte(patch), metav1.PatchOptions{}, "status", +func (p *ConfigMapHandler) applyFakeDevicePlugin(nodeTopology *topology.NodeTopology, nodeName string) error { + nodePatch := &v1.Node{ + Status: v1.NodeStatus{ + Capacity: v1.ResourceList{ + v1.ResourceName(constants.GpuResourceName): *resource.NewQuantity(int64(len(nodeTopology.Gpus)), resource.DecimalSI), + }, + Allocatable: v1.ResourceList{ + v1.ResourceName(constants.GpuResourceName): *resource.NewQuantity(int64(len(nodeTopology.Gpus)), resource.DecimalSI), + }, + }, + } + + for _, otherDevice := range nodeTopology.OtherDevices { + nodePatch.Status.Capacity[v1.ResourceName(otherDevice.Name)] = *resource.NewQuantity(int64(otherDevice.Count), resource.DecimalSI) + nodePatch.Status.Allocatable[v1.ResourceName(otherDevice.Name)] = *resource.NewQuantity(int64(otherDevice.Count), resource.DecimalSI) + } + + patchBytes, err := json.Marshal(nodePatch) + if err != nil { + return fmt.Errorf("failed to update node: failed to marshal patch: %v", err) + } + + _, err = p.kubeClient.CoreV1().Nodes().Patch( + context.TODO(), nodeName, types.MergePatchType, patchBytes, metav1.PatchOptions{}, "status", ) if err != nil { return fmt.Errorf("failed to update node capacity and allocatable: %v", err) diff --git a/internal/kwok-gpu-device-plugin/handlers/configmap/handler_test.go b/internal/kwok-gpu-device-plugin/handlers/configmap/handler_test.go new file mode 100644 index 0000000..6c91b75 --- /dev/null +++ b/internal/kwok-gpu-device-plugin/handlers/configmap/handler_test.go @@ -0,0 +1,80 @@ +package configmap + +import ( + "testing" + + "gopkg.in/yaml.v3" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + "github.com/run-ai/fake-gpu-operator/internal/common/constants" + "github.com/run-ai/fake-gpu-operator/internal/common/topology" + "golang.org/x/net/context" + + v1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/client-go/kubernetes/fake" +) + +func TestKWOKGPUDevicePlugin(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "KWOK GPU DevicePlugin Suite") +} + +var _ = Describe("HandleAdd", func() { + It("should update the node capacity and allocatable", func() { + nodeName := "node1" + nodeTopology := &topology.NodeTopology{ + Gpus: []topology.GpuDetails{ + {ID: "0"}, + }, + OtherDevices: []topology.GenericDevice{ + {Name: "device1", Count: 2}, + }, + } + + topologyData, err := yaml.Marshal(nodeTopology) + if err != nil { + Fail("Failed to marshal topology data") + } + + configMap := &v1.ConfigMap{ + ObjectMeta: metav1.ObjectMeta{ + Name: nodeName, + Labels: map[string]string{ + constants.LabelTopologyCMNodeName: nodeName, + }, + }, + Data: map[string]string{ + topology.CmTopologyKey: string(topologyData), + }, + } + + node := &v1.Node{ + ObjectMeta: metav1.ObjectMeta{ + Name: nodeName, + }, + } + + fakeClient := fake.NewSimpleClientset(node, configMap) + + fakeNodeCMHandler := NewConfigMapHandler(fakeClient, nil) + err = fakeNodeCMHandler.HandleAdd(configMap) + Expect(err).ToNot(HaveOccurred()) + + updateNode, err := fakeClient.CoreV1().Nodes().Get(context.TODO(), nodeName, metav1.GetOptions{}) + Expect(err).ToNot(HaveOccurred()) + Expect(testResourceListCondition(updateNode.Status.Capacity, v1.ResourceName(constants.GpuResourceName), 1)).To(BeTrue()) + Expect(testResourceListCondition(updateNode.Status.Allocatable, v1.ResourceName(constants.GpuResourceName), 1)).To(BeTrue()) + Expect(testResourceListCondition(updateNode.Status.Capacity, v1.ResourceName("device1"), 2)).To(BeTrue()) + Expect(testResourceListCondition(updateNode.Status.Allocatable, v1.ResourceName("device1"), 2)).To(BeTrue()) + }) +}) + +func testResourceListCondition(resourceList v1.ResourceList, resourceName v1.ResourceName, value int64) bool { + quantity, found := resourceList[resourceName] + if !found { + return false + } + return quantity.Value() == value +} diff --git a/internal/status-updater/handlers/node/topology_cm.go b/internal/status-updater/handlers/node/topology_cm.go index 44c707c..9f9fe18 100644 --- a/internal/status-updater/handlers/node/topology_cm.go +++ b/internal/status-updater/handlers/node/topology_cm.go @@ -25,10 +25,11 @@ func (p *NodeHandler) createNodeTopologyCM(node *v1.Node) error { } nodeTopology = &topology.NodeTopology{ - GpuMemory: nodePoolTopology.GpuMemory, - GpuProduct: nodePoolTopology.GpuProduct, - Gpus: generateGpuDetails(nodePoolTopology.GpuCount, node.Name), - MigStrategy: p.clusterTopology.MigStrategy, + GpuMemory: nodePoolTopology.GpuMemory, + GpuProduct: nodePoolTopology.GpuProduct, + Gpus: generateGpuDetails(nodePoolTopology.GpuCount, node.Name), + MigStrategy: p.clusterTopology.MigStrategy, + OtherDevices: nodePoolTopology.OtherDevices, } err := topology.CreateNodeTopologyCM(p.kubeClient, nodeTopology, node)