Skip to content

Commit

Permalink
adding multiple other devices fake device plugins
Browse files Browse the repository at this point in the history
  • Loading branch information
enoodle committed Nov 25, 2024
1 parent 83b94a7 commit 22dc131
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 22 deletions.
11 changes: 7 additions & 4 deletions cmd/device-plugin/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,13 @@ func main() {
initNvidiaSmi()
initPreloaders()

devicePlugin := deviceplugin.NewDevicePlugin(topology, kubeClient)
if err = devicePlugin.Serve(); err != nil {
log.Printf("Failed to serve device plugin: %s\n", err)
os.Exit(1)
devicePlugins := deviceplugin.NewDevicePlugins(topology, kubeClient)
for _, devicePlugin := range devicePlugins {
log.Printf("Starting device plugin for %s\n", devicePlugin.Name())
if err = devicePlugin.Serve(); err != nil {
log.Printf("Failed to serve device plugin: %s\n", err)
os.Exit(1)
}
}

sig := make(chan os.Signal, 1)
Expand Down
21 changes: 14 additions & 7 deletions internal/common/topology/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,18 @@ type ClusterTopology struct {
}

type NodePoolTopology struct {
GpuCount int `yaml:"gpuCount"`
GpuMemory int `yaml:"gpuMemory"`
GpuProduct string `yaml:"gpuProduct"`
GpuCount int `yaml:"gpuCount"`
GpuMemory int `yaml:"gpuMemory"`
GpuProduct string `yaml:"gpuProduct"`
OtherDevices []GenericDevice `yaml:"otherDevices,omitempty"`
}

type NodeTopology struct {
GpuMemory int `yaml:"gpuMemory"`
GpuProduct string `yaml:"gpuProduct"`
Gpus []GpuDetails `yaml:"gpus"`
MigStrategy string `yaml:"migStrategy"`
GpuMemory int `yaml:"gpuMemory"`
GpuProduct string `yaml:"gpuProduct"`
Gpus []GpuDetails `yaml:"gpus"`
MigStrategy string `yaml:"migStrategy"`
OtherDevices []GenericDevice `yaml:"otherDevices,omitempty"`
}

type GpuDetails struct {
Expand Down Expand Up @@ -56,6 +58,11 @@ type Range struct {
Max int `yaml:"max"`
}

type GenericDevice struct {
Name string `yaml:"name"`
Count int `yaml:"count"`
}

// Errors
var ErrNoNodes = fmt.Errorf("no nodes found")
var ErrNoNode = fmt.Errorf("node not found")
39 changes: 32 additions & 7 deletions internal/deviceplugin/device_plugin.go
Original file line number Diff line number Diff line change
@@ -1,34 +1,59 @@
package deviceplugin

import (
"path"
"strings"

"github.com/run-ai/fake-gpu-operator/internal/common/constants"
"github.com/run-ai/fake-gpu-operator/internal/common/topology"
"github.com/spf13/viper"
"k8s.io/client-go/kubernetes"
pluginapi "k8s.io/kubelet/pkg/apis/deviceplugin/v1beta1"
)

const (
resourceName = "nvidia.com/gpu"
nvidiaGPUResourceName = "nvidia.com/gpu"
)

type Interface interface {
Serve() error
Name() string
}

func NewDevicePlugin(topology *topology.NodeTopology, kubeClient kubernetes.Interface) Interface {
func NewDevicePlugins(topology *topology.NodeTopology, kubeClient kubernetes.Interface) []Interface {
if topology == nil {
panic("topology is nil")
}

if viper.GetBool(constants.EnvFakeNode) {
return &FakeNodeDevicePlugin{
return []Interface{&FakeNodeDevicePlugin{
kubeClient: kubeClient,
gpuCount: getGpuCount(topology),
}
}}
}

devicePlugins := []Interface{
&RealNodeDevicePlugin{
devs: createDevices(getGpuCount(topology)),
socket: serverSock,
resourceName: nvidiaGPUResourceName,
},
}

return &RealNodeDevicePlugin{
devs: createDevices(getGpuCount(topology)),
socket: serverSock,
for _, genericDevice := range topology.OtherDevices {
devicePlugins = append(devicePlugins, &RealNodeDevicePlugin{
devs: createDevices(genericDevice.Count),
socket: path.Join(pluginapi.DevicePluginPath, normalizeDeviceName(genericDevice.Name)+".sock"),
resourceName: genericDevice.Name,
})
}

return devicePlugins
}

func normalizeDeviceName(deviceName string) string {
normalized := strings.ReplaceAll(deviceName, "/", "_")
normalized = strings.ReplaceAll(normalized, ".", "_")
normalized = strings.ReplaceAll(normalized, "-", "_")
return normalized
}
6 changes: 5 additions & 1 deletion internal/deviceplugin/fake_node.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,15 @@ type FakeNodeDevicePlugin struct {
}

func (f *FakeNodeDevicePlugin) Serve() error {
patch := fmt.Sprintf(`{"status": {"capacity": {"%s": "%d"}, "allocatable": {"%s": "%d"}}}`, resourceName, f.gpuCount, resourceName, f.gpuCount)
patch := fmt.Sprintf(`{"status": {"capacity": {"%s": "%d"}, "allocatable": {"%s": "%d"}}}`, nvidiaGPUResourceName, f.gpuCount, nvidiaGPUResourceName, f.gpuCount)
_, err := f.kubeClient.CoreV1().Nodes().Patch(context.TODO(), os.Getenv(constants.EnvNodeName), types.MergePatchType, []byte(patch), metav1.PatchOptions{}, "status")
if err != nil {
return fmt.Errorf("failed to update node capacity and allocatable: %v", err)
}

return nil
}

func (f *FakeNodeDevicePlugin) Name() string {
return "FakeNodeDevicePlugin"
}
12 changes: 9 additions & 3 deletions internal/deviceplugin/real_node.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ type RealNodeDevicePlugin struct {
stop chan interface{}
health chan *pluginapi.Device
server *grpc.Server

resourceName string
}

func getGpuCount(nodeTopology *topology.NodeTopology) int {
Expand Down Expand Up @@ -115,7 +117,7 @@ func (m *RealNodeDevicePlugin) Stop() error {
return m.cleanup()
}

func (m *RealNodeDevicePlugin) Register(kubeletEndpoint, resourceName string) error {
func (m *RealNodeDevicePlugin) Register(kubeletEndpoint string) error {
conn, err := dial(kubeletEndpoint, 5*time.Second)
if err != nil {
return err
Expand All @@ -126,7 +128,7 @@ func (m *RealNodeDevicePlugin) Register(kubeletEndpoint, resourceName string) er
reqt := &pluginapi.RegisterRequest{
Version: pluginapi.Version,
Endpoint: path.Base(m.socket),
ResourceName: resourceName,
ResourceName: m.resourceName,
}

_, err = client.Register(context.Background(), reqt)
Expand Down Expand Up @@ -202,7 +204,7 @@ func (m *RealNodeDevicePlugin) Serve() error {
}
log.Println("Starting to serve on", m.socket)

err = m.Register(pluginapi.KubeletSocket, resourceName)
err = m.Register(pluginapi.KubeletSocket)
if err != nil {
log.Printf("Could not register device plugin: %s", err)
stopErr := m.Stop()
Expand All @@ -215,3 +217,7 @@ func (m *RealNodeDevicePlugin) Serve() error {

return nil
}

func (m *RealNodeDevicePlugin) Name() string {
return "RealNodeDevicePlugin-" + m.resourceName
}

0 comments on commit 22dc131

Please sign in to comment.