Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adding multiple other devices fake device plugins #100

Merged
merged 5 commits into from
Nov 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions cmd/device-plugin/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,13 @@ func main() {
initNvidiaSmi()
initPreloaders()

devicePlugin := deviceplugin.NewDevicePlugin(topology, kubeClient)
if err = devicePlugin.Serve(); err != nil {
log.Printf("Failed to serve device plugin: %s\n", err)
os.Exit(1)
devicePlugins := deviceplugin.NewDevicePlugins(topology, kubeClient)
for _, devicePlugin := range devicePlugins {
log.Printf("Starting device plugin for %s\n", devicePlugin.Name())
if err = devicePlugin.Serve(); err != nil {
log.Printf("Failed to serve device plugin: %s\n", err)
os.Exit(1)
}
}

sig := make(chan os.Signal, 1)
Expand Down
2 changes: 1 addition & 1 deletion internal/common/topology/const.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
package topology

const (
cmTopologyKey = "topology.yml"
CmTopologyKey = "topology.yml"
)
8 changes: 4 additions & 4 deletions internal/common/topology/kubernetes.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ func GetClusterTopologyFromCM(kubeclient kubernetes.Interface) (*ClusterTopology

func FromClusterTopologyCM(cm *corev1.ConfigMap) (*ClusterTopology, error) {
var clusterTopology ClusterTopology
err := yaml.Unmarshal([]byte(cm.Data[cmTopologyKey]), &clusterTopology)
err := yaml.Unmarshal([]byte(cm.Data[CmTopologyKey]), &clusterTopology)
if err != nil {
return nil, err
}
Expand All @@ -88,7 +88,7 @@ func FromClusterTopologyCM(cm *corev1.ConfigMap) (*ClusterTopology, error) {

func FromNodeTopologyCM(cm *corev1.ConfigMap) (*NodeTopology, error) {
var nodeTopology NodeTopology
err := yaml.Unmarshal([]byte(cm.Data[cmTopologyKey]), &nodeTopology)
err := yaml.Unmarshal([]byte(cm.Data[CmTopologyKey]), &nodeTopology)
if err != nil {
return nil, err
}
Expand All @@ -110,7 +110,7 @@ func ToClusterTopologyCM(clusterTopology *ClusterTopology) (*corev1.ConfigMap, e
return nil, err
}

cm.Data[cmTopologyKey] = string(topologyData)
cm.Data[CmTopologyKey] = string(topologyData)

return cm, nil
}
Expand All @@ -134,7 +134,7 @@ func ToNodeTopologyCM(nodeTopology *NodeTopology, nodeName string) (*corev1.Conf
return nil, nil, err
}

cm.Data[cmTopologyKey] = string(topologyData)
cm.Data[CmTopologyKey] = string(topologyData)

cmApplyConfig = cmApplyConfig.WithData(cm.Data)

Expand Down
21 changes: 14 additions & 7 deletions internal/common/topology/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,18 @@ type ClusterTopology struct {
}

type NodePoolTopology struct {
GpuCount int `yaml:"gpuCount"`
GpuMemory int `yaml:"gpuMemory"`
GpuProduct string `yaml:"gpuProduct"`
GpuCount int `yaml:"gpuCount"`
GpuMemory int `yaml:"gpuMemory"`
GpuProduct string `yaml:"gpuProduct"`
OtherDevices []GenericDevice `yaml:"otherDevices"`
}

type NodeTopology struct {
GpuMemory int `yaml:"gpuMemory"`
GpuProduct string `yaml:"gpuProduct"`
Gpus []GpuDetails `yaml:"gpus"`
MigStrategy string `yaml:"migStrategy"`
GpuMemory int `yaml:"gpuMemory"`
GpuProduct string `yaml:"gpuProduct"`
Gpus []GpuDetails `yaml:"gpus"`
MigStrategy string `yaml:"migStrategy"`
OtherDevices []GenericDevice `yaml:"otherDevices,omitempty"`
}

type GpuDetails struct {
Expand Down Expand Up @@ -56,6 +58,11 @@ type Range struct {
Max int `yaml:"max"`
}

// GenericDevice describes one non-GPU device type listed under a topology's
// "otherDevices" YAML key.
type GenericDevice struct {
// Name is the device's resource name; the device plugin registers it as-is
// and also derives the plugin socket file name from it.
Name string `yaml:"name"`
// Count is how many devices of this type are advertised.
Count int `yaml:"count"`
}

// Errors
var ErrNoNodes = fmt.Errorf("no nodes found")
var ErrNoNode = fmt.Errorf("node not found")
47 changes: 39 additions & 8 deletions internal/deviceplugin/device_plugin.go
Original file line number Diff line number Diff line change
@@ -1,34 +1,65 @@
package deviceplugin

import (
"path"
"strings"

"github.com/run-ai/fake-gpu-operator/internal/common/constants"
"github.com/run-ai/fake-gpu-operator/internal/common/topology"
"github.com/spf13/viper"
"k8s.io/client-go/kubernetes"
pluginapi "k8s.io/kubelet/pkg/apis/deviceplugin/v1beta1"
)

const (
resourceName = "nvidia.com/gpu"
nvidiaGPUResourceName = "nvidia.com/gpu"
)

type Interface interface {
Serve() error
Name() string
}

func NewDevicePlugin(topology *topology.NodeTopology, kubeClient kubernetes.Interface) Interface {
func NewDevicePlugins(topology *topology.NodeTopology, kubeClient kubernetes.Interface) []Interface {
if topology == nil {
panic("topology is nil")
}

if viper.GetBool(constants.EnvFakeNode) {
return &FakeNodeDevicePlugin{
kubeClient: kubeClient,
gpuCount: getGpuCount(topology),
otherDevices := make(map[string]int)
for _, genericDevice := range topology.OtherDevices {
otherDevices[genericDevice.Name] = genericDevice.Count
}

return []Interface{&FakeNodeDevicePlugin{
kubeClient: kubeClient,
gpuCount: getGpuCount(topology),
otherDevices: otherDevices,
}}
}

devicePlugins := []Interface{
&RealNodeDevicePlugin{
devs: createDevices(getGpuCount(topology)),
socket: serverSock,
resourceName: nvidiaGPUResourceName,
},
}

return &RealNodeDevicePlugin{
devs: createDevices(getGpuCount(topology)),
socket: serverSock,
for _, genericDevice := range topology.OtherDevices {
devicePlugins = append(devicePlugins, &RealNodeDevicePlugin{
devs: createDevices(genericDevice.Count),
socket: path.Join(pluginapi.DevicePluginPath, normalizeDeviceName(genericDevice.Name)+".sock"),
resourceName: genericDevice.Name,
})
}

return devicePlugins
}

// normalizeDeviceName turns a resource name such as "vendor.com/my-device"
// into a string usable as a socket file name by replacing every '/', '.'
// and '-' with '_'. All other bytes are passed through unchanged.
func normalizeDeviceName(deviceName string) string {
	// A single-pass replacer yields the same output as three successive
	// ReplaceAll calls because no replacement introduces a matched character.
	separators := strings.NewReplacer(
		"/", "_",
		".", "_",
		"-", "_",
	)
	return separators.Replace(deviceName)
}
69 changes: 69 additions & 0 deletions internal/deviceplugin/device_plugin_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package deviceplugin

import (
"testing"

. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/spf13/viper"

"k8s.io/client-go/kubernetes/fake"

"github.com/run-ai/fake-gpu-operator/internal/common/constants"
"github.com/run-ai/fake-gpu-operator/internal/common/topology"
)

// TestDevicePlugin is the `go test` entry point for this package's Ginkgo
// specs: it wires Gomega assertion failures to Ginkgo's Fail handler and then
// runs every registered spec under the suite name "DevicePlugin Suite".
func TestDevicePlugin(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "DevicePlugin Suite")
}

// Specs for NewDevicePlugins: nil-topology handling, fake-node mode, and the
// one-plugin-per-device behavior on regular nodes.
var _ = Describe("NewDevicePlugins", func() {
	Context("When the topology is nil", func() {
		It("should panic", func() {
			construct := func() { NewDevicePlugins(nil, nil) }
			Expect(construct).To(Panic())
		})
	})

	// Ordered so the viper flag is switched on once for the whole container
	// and switched back off after its last spec.
	Context("When the fake node is enabled", Ordered, func() {
		BeforeAll(func() {
			viper.Set(constants.EnvFakeNode, true)
		})

		AfterAll(func() {
			viper.Set(constants.EnvFakeNode, false)
		})

		It("should return a fake node device plugin", func() {
			plugins := NewDevicePlugins(&topology.NodeTopology{}, &fake.Clientset{})

			Expect(plugins).To(HaveLen(1))
			Expect(plugins[0]).To(BeAssignableToTypeOf(&FakeNodeDevicePlugin{}))
		})
	})

	Context("With normal node", func() {
		It("should return a real node device plugin", func() {
			plugins := NewDevicePlugins(&topology.NodeTopology{}, &fake.Clientset{})

			Expect(plugins).To(HaveLen(1))
			Expect(plugins[0]).To(BeAssignableToTypeOf(&RealNodeDevicePlugin{}))
		})

		It("should return a device plugin for each other device", func() {
			// Two extra device types on top of the always-present GPU plugin.
			nodeTopology := &topology.NodeTopology{
				OtherDevices: []topology.GenericDevice{
					{Name: "device1", Count: 1},
					{Name: "device2", Count: 2},
				},
			}

			plugins := NewDevicePlugins(nodeTopology, &fake.Clientset{})

			Expect(plugins).To(HaveLen(3))
			Expect(plugins[0]).To(BeAssignableToTypeOf(&RealNodeDevicePlugin{}))
		})
	})
})
36 changes: 32 additions & 4 deletions internal/deviceplugin/fake_node.go
Original file line number Diff line number Diff line change
@@ -1,28 +1,56 @@
package deviceplugin

import (
"encoding/json"
"fmt"
"os"

"github.com/run-ai/fake-gpu-operator/internal/common/constants"
"golang.org/x/net/context"
v1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/resource"
"k8s.io/apimachinery/pkg/types"
"k8s.io/client-go/kubernetes"

metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
)

type FakeNodeDevicePlugin struct {
kubeClient kubernetes.Interface
gpuCount int
kubeClient kubernetes.Interface
gpuCount int
otherDevices map[string]int
}

func (f *FakeNodeDevicePlugin) Serve() error {
patch := fmt.Sprintf(`{"status": {"capacity": {"%s": "%d"}, "allocatable": {"%s": "%d"}}}`, resourceName, f.gpuCount, resourceName, f.gpuCount)
_, err := f.kubeClient.CoreV1().Nodes().Patch(context.TODO(), os.Getenv(constants.EnvNodeName), types.MergePatchType, []byte(patch), metav1.PatchOptions{}, "status")
nodeStatus := v1.NodeStatus{
Capacity: v1.ResourceList{
v1.ResourceName(nvidiaGPUResourceName): *resource.NewQuantity(int64(f.gpuCount), resource.DecimalSI),
},
Allocatable: v1.ResourceList{
v1.ResourceName(nvidiaGPUResourceName): *resource.NewQuantity(int64(f.gpuCount), resource.DecimalSI),
},
}

for deviceName, count := range f.otherDevices {
nodeStatus.Capacity[v1.ResourceName(deviceName)] = *resource.NewQuantity(int64(count), resource.DecimalSI)
nodeStatus.Allocatable[v1.ResourceName(deviceName)] = *resource.NewQuantity(int64(count), resource.DecimalSI)
}

// Convert the patch struct to JSON
patchBytes, err := json.Marshal(v1.Node{Status: nodeStatus})
if err != nil {
return fmt.Errorf("failed to marshal patch: %v", err)
}

// Apply the patch
_, err = f.kubeClient.CoreV1().Nodes().Patch(context.TODO(), os.Getenv(constants.EnvNodeName), types.MergePatchType, patchBytes, metav1.PatchOptions{}, "status")
if err != nil {
return fmt.Errorf("failed to update node capacity and allocatable: %v", err)
}

return nil
}

// Name returns the fixed identifier of the fake-node plugin. A single
// FakeNodeDevicePlugin covers every resource, so the name carries no
// per-resource suffix.
func (f *FakeNodeDevicePlugin) Name() string {
	const pluginName = "FakeNodeDevicePlugin"
	return pluginName
}
50 changes: 50 additions & 0 deletions internal/deviceplugin/fake_node_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package deviceplugin

import (
"os"

. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"golang.org/x/net/context"

v1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/client-go/kubernetes/fake"
)

// Spec for FakeNodeDevicePlugin.Serve: after serving, the node's capacity and
// allocatable must list the GPU count and every extra device count.
var _ = Describe("FakeNodeDevicePlugin.Serve", func() {
	It("should update the node capacity and allocatable", func() {
		const nodeName = "node1"

		// Serve looks up the node to patch through an environment variable
		// (presumably constants.EnvNodeName == "NODE_NAME" — confirm). The
		// variable is process-global, so check the Setenv error and put the
		// previous state back once the spec ends to avoid leaking into
		// other specs.
		previous, wasSet := os.LookupEnv("NODE_NAME")
		Expect(os.Setenv("NODE_NAME", nodeName)).To(Succeed())
		DeferCleanup(func() {
			if wasSet {
				Expect(os.Setenv("NODE_NAME", previous)).To(Succeed())
				return
			}
			Expect(os.Unsetenv("NODE_NAME")).To(Succeed())
		})

		node := &v1.Node{
			ObjectMeta: metav1.ObjectMeta{
				Name: nodeName,
			},
		}
		fakeClient := fake.NewSimpleClientset(node)

		fakeNodeDevicePlugin := &FakeNodeDevicePlugin{
			kubeClient:   fakeClient,
			gpuCount:     1,
			otherDevices: map[string]int{"device1": 2},
		}

		err := fakeNodeDevicePlugin.Serve()
		Expect(err).ToNot(HaveOccurred())

		updateNode, err := fakeClient.CoreV1().Nodes().Get(context.TODO(), nodeName, metav1.GetOptions{})
		Expect(err).ToNot(HaveOccurred())
		Expect(testResourceListCondition(updateNode.Status.Capacity, v1.ResourceName(nvidiaGPUResourceName), 1)).To(BeTrue())
		Expect(testResourceListCondition(updateNode.Status.Allocatable, v1.ResourceName(nvidiaGPUResourceName), 1)).To(BeTrue())
		Expect(testResourceListCondition(updateNode.Status.Capacity, v1.ResourceName("device1"), 2)).To(BeTrue())
		Expect(testResourceListCondition(updateNode.Status.Allocatable, v1.ResourceName("device1"), 2)).To(BeTrue())
	})
})

// testResourceListCondition reports whether resourceList has an entry for
// resourceName whose quantity equals value. A missing entry yields false.
func testResourceListCondition(resourceList v1.ResourceList, resourceName v1.ResourceName, value int64) bool {
	if quantity, found := resourceList[resourceName]; found {
		return quantity.Value() == value
	}
	return false
}
12 changes: 9 additions & 3 deletions internal/deviceplugin/real_node.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ type RealNodeDevicePlugin struct {
stop chan interface{}
health chan *pluginapi.Device
server *grpc.Server

resourceName string
}

func getGpuCount(nodeTopology *topology.NodeTopology) int {
Expand Down Expand Up @@ -115,7 +117,7 @@ func (m *RealNodeDevicePlugin) Stop() error {
return m.cleanup()
}

func (m *RealNodeDevicePlugin) Register(kubeletEndpoint, resourceName string) error {
func (m *RealNodeDevicePlugin) Register(kubeletEndpoint string) error {
conn, err := dial(kubeletEndpoint, 5*time.Second)
if err != nil {
return err
Expand All @@ -126,7 +128,7 @@ func (m *RealNodeDevicePlugin) Register(kubeletEndpoint, resourceName string) er
reqt := &pluginapi.RegisterRequest{
Version: pluginapi.Version,
Endpoint: path.Base(m.socket),
ResourceName: resourceName,
ResourceName: m.resourceName,
}

_, err = client.Register(context.Background(), reqt)
Expand Down Expand Up @@ -202,7 +204,7 @@ func (m *RealNodeDevicePlugin) Serve() error {
}
log.Println("Starting to serve on", m.socket)

err = m.Register(pluginapi.KubeletSocket, resourceName)
err = m.Register(pluginapi.KubeletSocket)
if err != nil {
log.Printf("Could not register device plugin: %s", err)
stopErr := m.Stop()
Expand All @@ -215,3 +217,7 @@ func (m *RealNodeDevicePlugin) Serve() error {

return nil
}

// Name returns "RealNodeDevicePlugin-" followed by the resource name this
// plugin instance registers, so that plugins for different resources can be
// told apart.
func (m *RealNodeDevicePlugin) Name() string {
	const prefix = "RealNodeDevicePlugin-"
	return prefix + m.resourceName
}
Loading
Loading