Skip to content

Commit

Permalink
Merge pull request #100 from run-ai/erez/mig-devices-RUN-23506
Browse files Browse the repository at this point in the history
adding multiple other devices fake device plugins
  • Loading branch information
enoodle authored Nov 27, 2024
2 parents 83b94a7 + 6b261a7 commit eddc6cf
Show file tree
Hide file tree
Showing 12 changed files with 338 additions and 43 deletions.
11 changes: 7 additions & 4 deletions cmd/device-plugin/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,13 @@ func main() {
initNvidiaSmi()
initPreloaders()

devicePlugin := deviceplugin.NewDevicePlugin(topology, kubeClient)
if err = devicePlugin.Serve(); err != nil {
log.Printf("Failed to serve device plugin: %s\n", err)
os.Exit(1)
devicePlugins := deviceplugin.NewDevicePlugins(topology, kubeClient)
for _, devicePlugin := range devicePlugins {
log.Printf("Starting device plugin for %s\n", devicePlugin.Name())
if err = devicePlugin.Serve(); err != nil {
log.Printf("Failed to serve device plugin: %s\n", err)
os.Exit(1)
}
}

sig := make(chan os.Signal, 1)
Expand Down
2 changes: 1 addition & 1 deletion internal/common/topology/const.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
package topology

const (
cmTopologyKey = "topology.yml"
CmTopologyKey = "topology.yml"
)
8 changes: 4 additions & 4 deletions internal/common/topology/kubernetes.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ func GetClusterTopologyFromCM(kubeclient kubernetes.Interface) (*ClusterTopology

func FromClusterTopologyCM(cm *corev1.ConfigMap) (*ClusterTopology, error) {
var clusterTopology ClusterTopology
err := yaml.Unmarshal([]byte(cm.Data[cmTopologyKey]), &clusterTopology)
err := yaml.Unmarshal([]byte(cm.Data[CmTopologyKey]), &clusterTopology)
if err != nil {
return nil, err
}
Expand All @@ -88,7 +88,7 @@ func FromClusterTopologyCM(cm *corev1.ConfigMap) (*ClusterTopology, error) {

func FromNodeTopologyCM(cm *corev1.ConfigMap) (*NodeTopology, error) {
var nodeTopology NodeTopology
err := yaml.Unmarshal([]byte(cm.Data[cmTopologyKey]), &nodeTopology)
err := yaml.Unmarshal([]byte(cm.Data[CmTopologyKey]), &nodeTopology)
if err != nil {
return nil, err
}
Expand All @@ -110,7 +110,7 @@ func ToClusterTopologyCM(clusterTopology *ClusterTopology) (*corev1.ConfigMap, e
return nil, err
}

cm.Data[cmTopologyKey] = string(topologyData)
cm.Data[CmTopologyKey] = string(topologyData)

return cm, nil
}
Expand All @@ -134,7 +134,7 @@ func ToNodeTopologyCM(nodeTopology *NodeTopology, nodeName string) (*corev1.Conf
return nil, nil, err
}

cm.Data[cmTopologyKey] = string(topologyData)
cm.Data[CmTopologyKey] = string(topologyData)

cmApplyConfig = cmApplyConfig.WithData(cm.Data)

Expand Down
21 changes: 14 additions & 7 deletions internal/common/topology/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,18 @@ type ClusterTopology struct {
}

type NodePoolTopology struct {
GpuCount int `yaml:"gpuCount"`
GpuMemory int `yaml:"gpuMemory"`
GpuProduct string `yaml:"gpuProduct"`
GpuCount int `yaml:"gpuCount"`
GpuMemory int `yaml:"gpuMemory"`
GpuProduct string `yaml:"gpuProduct"`
OtherDevices []GenericDevice `yaml:"otherDevices"`
}

type NodeTopology struct {
GpuMemory int `yaml:"gpuMemory"`
GpuProduct string `yaml:"gpuProduct"`
Gpus []GpuDetails `yaml:"gpus"`
MigStrategy string `yaml:"migStrategy"`
GpuMemory int `yaml:"gpuMemory"`
GpuProduct string `yaml:"gpuProduct"`
Gpus []GpuDetails `yaml:"gpus"`
MigStrategy string `yaml:"migStrategy"`
OtherDevices []GenericDevice `yaml:"otherDevices,omitempty"`
}

type GpuDetails struct {
Expand Down Expand Up @@ -56,6 +58,11 @@ type Range struct {
Max int `yaml:"max"`
}

// GenericDevice is a single entry of a topology's "otherDevices" list: a
// named device together with how many instances of it should exist.
type GenericDevice struct {
	// Name identifies the device; it is stored under the "name" YAML key.
	Name string `yaml:"name"`

	// Count is the number of instances; it is stored under the "count" YAML key.
	Count int `yaml:"count"`
}

// Errors
var ErrNoNodes = fmt.Errorf("no nodes found")
var ErrNoNode = fmt.Errorf("node not found")
47 changes: 39 additions & 8 deletions internal/deviceplugin/device_plugin.go
Original file line number Diff line number Diff line change
@@ -1,34 +1,65 @@
package deviceplugin

import (
"path"
"strings"

"github.com/run-ai/fake-gpu-operator/internal/common/constants"
"github.com/run-ai/fake-gpu-operator/internal/common/topology"
"github.com/spf13/viper"
"k8s.io/client-go/kubernetes"
pluginapi "k8s.io/kubelet/pkg/apis/deviceplugin/v1beta1"
)

const (
resourceName = "nvidia.com/gpu"
nvidiaGPUResourceName = "nvidia.com/gpu"
)

type Interface interface {
Serve() error
Name() string
}

func NewDevicePlugin(topology *topology.NodeTopology, kubeClient kubernetes.Interface) Interface {
func NewDevicePlugins(topology *topology.NodeTopology, kubeClient kubernetes.Interface) []Interface {
if topology == nil {
panic("topology is nil")
}

if viper.GetBool(constants.EnvFakeNode) {
return &FakeNodeDevicePlugin{
kubeClient: kubeClient,
gpuCount: getGpuCount(topology),
otherDevices := make(map[string]int)
for _, genericDevice := range topology.OtherDevices {
otherDevices[genericDevice.Name] = genericDevice.Count
}

return []Interface{&FakeNodeDevicePlugin{
kubeClient: kubeClient,
gpuCount: getGpuCount(topology),
otherDevices: otherDevices,
}}
}

devicePlugins := []Interface{
&RealNodeDevicePlugin{
devs: createDevices(getGpuCount(topology)),
socket: serverSock,
resourceName: nvidiaGPUResourceName,
},
}

return &RealNodeDevicePlugin{
devs: createDevices(getGpuCount(topology)),
socket: serverSock,
for _, genericDevice := range topology.OtherDevices {
devicePlugins = append(devicePlugins, &RealNodeDevicePlugin{
devs: createDevices(genericDevice.Count),
socket: path.Join(pluginapi.DevicePluginPath, normalizeDeviceName(genericDevice.Name)+".sock"),
resourceName: genericDevice.Name,
})
}

return devicePlugins
}

// normalizeDeviceName returns deviceName with every '/', '.' and '-'
// replaced by '_'. All other bytes are copied through unchanged.
func normalizeDeviceName(deviceName string) string {
	// All three separators map to the same replacement, so a single
	// replacer pass yields the same result as replacing them one by one.
	separators := strings.NewReplacer(
		"/", "_",
		".", "_",
		"-", "_",
	)
	return separators.Replace(deviceName)
}
69 changes: 69 additions & 0 deletions internal/deviceplugin/device_plugin_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package deviceplugin

import (
"testing"

. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/spf13/viper"

"k8s.io/client-go/kubernetes/fake"

"github.com/run-ai/fake-gpu-operator/internal/common/constants"
"github.com/run-ai/fake-gpu-operator/internal/common/topology"
)

// TestDevicePlugin is the `go test` entry point for this package's Ginkgo
// specs: it connects Gomega assertion failures to Ginkgo and then runs the
// suite.
func TestDevicePlugin(t *testing.T) {
	// The fail handler must be registered before any spec runs.
	RegisterFailHandler(Fail)

	RunSpecs(t, "DevicePlugin Suite")
}

// Specs for NewDevicePlugins: the nil-topology panic, the single fake-node
// plugin returned when the fake-node flag is set, and the real-node plugins
// (one for the GPU resource plus one per configured other device).
var _ = Describe("NewDevicePlugins", func() {
	Context("When the topology is nil", func() {
		It("should panic", func() {
			Expect(func() { NewDevicePlugins(nil, nil) }).To(Panic())
		})
	})

	Context("When the fake node is enabled", Ordered, func() {
		// The flag lives in viper's global state, so it is switched on for
		// this container only and switched back off afterwards.
		BeforeAll(func() {
			viper.Set(constants.EnvFakeNode, true)
		})

		AfterAll(func() {
			viper.Set(constants.EnvFakeNode, false)
		})

		It("should return a fake node device plugin", func() {
			// Named nodeTopology so the imported topology package is not shadowed.
			nodeTopology := &topology.NodeTopology{}
			kubeClient := &fake.Clientset{}

			devicePlugins := NewDevicePlugins(nodeTopology, kubeClient)

			Expect(devicePlugins).To(HaveLen(1))
			Expect(devicePlugins[0]).To(BeAssignableToTypeOf(&FakeNodeDevicePlugin{}))
		})
	})

	Context("With normal node", func() {
		It("should return a real node device plugin", func() {
			nodeTopology := &topology.NodeTopology{}
			kubeClient := &fake.Clientset{}

			devicePlugins := NewDevicePlugins(nodeTopology, kubeClient)

			Expect(devicePlugins).To(HaveLen(1))
			Expect(devicePlugins[0]).To(BeAssignableToTypeOf(&RealNodeDevicePlugin{}))
		})

		It("should return a device plugin for each other device", func() {
			nodeTopology := &topology.NodeTopology{
				OtherDevices: []topology.GenericDevice{
					{Name: "device1", Count: 1},
					{Name: "device2", Count: 2},
				},
			}
			kubeClient := &fake.Clientset{}

			devicePlugins := NewDevicePlugins(nodeTopology, kubeClient)

			Expect(devicePlugins).To(HaveLen(3))

			// Check every returned plugin, not only the first one.
			for _, devicePlugin := range devicePlugins {
				Expect(devicePlugin).To(BeAssignableToTypeOf(&RealNodeDevicePlugin{}))
			}

			// The GPU plugin comes first, followed by the other devices in
			// the order they are listed in the topology.
			Expect(devicePlugins[0].Name()).To(Equal("RealNodeDevicePlugin-" + nvidiaGPUResourceName))
			Expect(devicePlugins[1].Name()).To(Equal("RealNodeDevicePlugin-device1"))
			Expect(devicePlugins[2].Name()).To(Equal("RealNodeDevicePlugin-device2"))
		})
	})
})
36 changes: 32 additions & 4 deletions internal/deviceplugin/fake_node.go
Original file line number Diff line number Diff line change
@@ -1,28 +1,56 @@
package deviceplugin

import (
"encoding/json"
"fmt"
"os"

"github.com/run-ai/fake-gpu-operator/internal/common/constants"
"golang.org/x/net/context"
v1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/resource"
"k8s.io/apimachinery/pkg/types"
"k8s.io/client-go/kubernetes"

metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
)

type FakeNodeDevicePlugin struct {
kubeClient kubernetes.Interface
gpuCount int
kubeClient kubernetes.Interface
gpuCount int
otherDevices map[string]int
}

func (f *FakeNodeDevicePlugin) Serve() error {
patch := fmt.Sprintf(`{"status": {"capacity": {"%s": "%d"}, "allocatable": {"%s": "%d"}}}`, resourceName, f.gpuCount, resourceName, f.gpuCount)
_, err := f.kubeClient.CoreV1().Nodes().Patch(context.TODO(), os.Getenv(constants.EnvNodeName), types.MergePatchType, []byte(patch), metav1.PatchOptions{}, "status")
nodeStatus := v1.NodeStatus{
Capacity: v1.ResourceList{
v1.ResourceName(nvidiaGPUResourceName): *resource.NewQuantity(int64(f.gpuCount), resource.DecimalSI),
},
Allocatable: v1.ResourceList{
v1.ResourceName(nvidiaGPUResourceName): *resource.NewQuantity(int64(f.gpuCount), resource.DecimalSI),
},
}

for deviceName, count := range f.otherDevices {
nodeStatus.Capacity[v1.ResourceName(deviceName)] = *resource.NewQuantity(int64(count), resource.DecimalSI)
nodeStatus.Allocatable[v1.ResourceName(deviceName)] = *resource.NewQuantity(int64(count), resource.DecimalSI)
}

// Convert the patch struct to JSON
patchBytes, err := json.Marshal(v1.Node{Status: nodeStatus})
if err != nil {
return fmt.Errorf("failed to marshal patch: %v", err)
}

// Apply the patch
_, err = f.kubeClient.CoreV1().Nodes().Patch(context.TODO(), os.Getenv(constants.EnvNodeName), types.MergePatchType, patchBytes, metav1.PatchOptions{}, "status")
if err != nil {
return fmt.Errorf("failed to update node capacity and allocatable: %v", err)
}

return nil
}

// Name returns the fixed identifier of the fake-node plugin. It does not
// depend on the plugin's configuration.
func (f *FakeNodeDevicePlugin) Name() string {
	const pluginName = "FakeNodeDevicePlugin"
	return pluginName
}
50 changes: 50 additions & 0 deletions internal/deviceplugin/fake_node_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package deviceplugin

import (
"os"

. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"golang.org/x/net/context"

v1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/client-go/kubernetes/fake"
)

// Spec for FakeNodeDevicePlugin.Serve: after Serve runs against a fake
// clientset, the node's capacity and allocatable lists must contain the GPU
// resource and every configured other device with the expected counts.
var _ = Describe("FakeNodeDevicePlugin.Serve", func() {
	It("should update the node capacity and allocatable", func() {
		node := &v1.Node{
			ObjectMeta: metav1.ObjectMeta{
				Name: "node1",
			},
		}

		// NODE_NAME is presumably the variable Serve reads to pick the node to
		// patch (verify against constants.EnvNodeName). Remember its previous
		// state and restore it afterwards so the value does not leak into
		// other specs running in the same process.
		previousNodeName, hadNodeName := os.LookupEnv("NODE_NAME")
		Expect(os.Setenv("NODE_NAME", "node1")).To(Succeed())
		DeferCleanup(func() {
			if hadNodeName {
				Expect(os.Setenv("NODE_NAME", previousNodeName)).To(Succeed())
				return
			}
			Expect(os.Unsetenv("NODE_NAME")).To(Succeed())
		})

		fakeClient := fake.NewSimpleClientset(node)

		fakeNodeDevicePlugin := &FakeNodeDevicePlugin{
			kubeClient:   fakeClient,
			gpuCount:     1,
			otherDevices: map[string]int{"device1": 2},
		}

		err := fakeNodeDevicePlugin.Serve()
		Expect(err).ToNot(HaveOccurred())

		updateNode, err := fakeClient.CoreV1().Nodes().Get(context.TODO(), "node1", metav1.GetOptions{})
		Expect(err).ToNot(HaveOccurred())
		Expect(testResourceListCondition(updateNode.Status.Capacity, v1.ResourceName(nvidiaGPUResourceName), 1)).To(BeTrue())
		Expect(testResourceListCondition(updateNode.Status.Allocatable, v1.ResourceName(nvidiaGPUResourceName), 1)).To(BeTrue())
		Expect(testResourceListCondition(updateNode.Status.Capacity, v1.ResourceName("device1"), 2)).To(BeTrue())
		Expect(testResourceListCondition(updateNode.Status.Allocatable, v1.ResourceName("device1"), 2)).To(BeTrue())
	})
})

// testResourceListCondition reports whether resourceList has an entry for
// resourceName whose quantity equals value. A missing entry yields false.
func testResourceListCondition(resourceList v1.ResourceList, resourceName v1.ResourceName, value int64) bool {
	if quantity, found := resourceList[resourceName]; found {
		return quantity.Value() == value
	}
	return false
}
12 changes: 9 additions & 3 deletions internal/deviceplugin/real_node.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ type RealNodeDevicePlugin struct {
stop chan interface{}
health chan *pluginapi.Device
server *grpc.Server

resourceName string
}

func getGpuCount(nodeTopology *topology.NodeTopology) int {
Expand Down Expand Up @@ -115,7 +117,7 @@ func (m *RealNodeDevicePlugin) Stop() error {
return m.cleanup()
}

func (m *RealNodeDevicePlugin) Register(kubeletEndpoint, resourceName string) error {
func (m *RealNodeDevicePlugin) Register(kubeletEndpoint string) error {
conn, err := dial(kubeletEndpoint, 5*time.Second)
if err != nil {
return err
Expand All @@ -126,7 +128,7 @@ func (m *RealNodeDevicePlugin) Register(kubeletEndpoint, resourceName string) er
reqt := &pluginapi.RegisterRequest{
Version: pluginapi.Version,
Endpoint: path.Base(m.socket),
ResourceName: resourceName,
ResourceName: m.resourceName,
}

_, err = client.Register(context.Background(), reqt)
Expand Down Expand Up @@ -202,7 +204,7 @@ func (m *RealNodeDevicePlugin) Serve() error {
}
log.Println("Starting to serve on", m.socket)

err = m.Register(pluginapi.KubeletSocket, resourceName)
err = m.Register(pluginapi.KubeletSocket)
if err != nil {
log.Printf("Could not register device plugin: %s", err)
stopErr := m.Stop()
Expand All @@ -215,3 +217,7 @@ func (m *RealNodeDevicePlugin) Serve() error {

return nil
}

// Name returns the plugin's identifier: the fixed "RealNodeDevicePlugin-"
// prefix followed by the plugin's resource name.
func (m *RealNodeDevicePlugin) Name() string {
	const namePrefix = "RealNodeDevicePlugin-"
	return namePrefix + m.resourceName
}
Loading

0 comments on commit eddc6cf

Please sign in to comment.