diff --git a/cmd/device-plugin/main.go b/cmd/device-plugin/main.go index 3dc1f1b..9b8cab2 100644 --- a/cmd/device-plugin/main.go +++ b/cmd/device-plugin/main.go @@ -42,10 +42,13 @@ func main() { initNvidiaSmi() initPreloaders() - devicePlugin := deviceplugin.NewDevicePlugin(topology, kubeClient) - if err = devicePlugin.Serve(); err != nil { - log.Printf("Failed to serve device plugin: %s\n", err) - os.Exit(1) + devicePlugins := deviceplugin.NewDevicePlugins(topology, kubeClient) + for _, devicePlugin := range devicePlugins { + log.Printf("Starting device plugin for %s\n", devicePlugin.Name()) + if err = devicePlugin.Serve(); err != nil { + log.Printf("Failed to serve device plugin: %s\n", err) + os.Exit(1) + } } sig := make(chan os.Signal, 1) diff --git a/internal/common/topology/types.go b/internal/common/topology/types.go index c965172..1dbcf7c 100644 --- a/internal/common/topology/types.go +++ b/internal/common/topology/types.go @@ -14,16 +14,18 @@ type ClusterTopology struct { } type NodePoolTopology struct { - GpuCount int `yaml:"gpuCount"` - GpuMemory int `yaml:"gpuMemory"` - GpuProduct string `yaml:"gpuProduct"` + GpuCount int `yaml:"gpuCount"` + GpuMemory int `yaml:"gpuMemory"` + GpuProduct string `yaml:"gpuProduct"` + OtherDevices []GenericDevice `yaml:"otherDevices,omitempty"` } type NodeTopology struct { - GpuMemory int `yaml:"gpuMemory"` - GpuProduct string `yaml:"gpuProduct"` - Gpus []GpuDetails `yaml:"gpus"` - MigStrategy string `yaml:"migStrategy"` + GpuMemory int `yaml:"gpuMemory"` + GpuProduct string `yaml:"gpuProduct"` + Gpus []GpuDetails `yaml:"gpus"` + MigStrategy string `yaml:"migStrategy"` + OtherDevices []GenericDevice `yaml:"otherDevices,omitempty"` } type GpuDetails struct { @@ -56,6 +58,11 @@ type Range struct { Max int `yaml:"max"` } +type GenericDevice struct { + Name string `yaml:"name"` + Count int `yaml:"count"` +} + // Errors var ErrNoNodes = fmt.Errorf("no nodes found") var ErrNoNode = fmt.Errorf("node not found") diff --git a/internal/deviceplugin/device_plugin.go b/internal/deviceplugin/device_plugin.go index df71de5..e64c95c 100644 --- a/internal/deviceplugin/device_plugin.go +++ b/internal/deviceplugin/device_plugin.go @@ -1,34 +1,59 @@ package deviceplugin import ( + "path" + "strings" + "github.com/run-ai/fake-gpu-operator/internal/common/constants" "github.com/run-ai/fake-gpu-operator/internal/common/topology" "github.com/spf13/viper" "k8s.io/client-go/kubernetes" + pluginapi "k8s.io/kubelet/pkg/apis/deviceplugin/v1beta1" ) const ( - resourceName = "nvidia.com/gpu" + nvidiaGPUResourceName = "nvidia.com/gpu" ) type Interface interface { Serve() error + Name() string } -func NewDevicePlugin(topology *topology.NodeTopology, kubeClient kubernetes.Interface) Interface { +func NewDevicePlugins(topology *topology.NodeTopology, kubeClient kubernetes.Interface) []Interface { if topology == nil { panic("topology is nil") } if viper.GetBool(constants.EnvFakeNode) { - return &FakeNodeDevicePlugin{ + return []Interface{&FakeNodeDevicePlugin{ kubeClient: kubeClient, gpuCount: getGpuCount(topology), - } + }} + } + + devicePlugins := []Interface{ + &RealNodeDevicePlugin{ + devs: createDevices(getGpuCount(topology)), + socket: serverSock, + resourceName: nvidiaGPUResourceName, + }, } - return &RealNodeDevicePlugin{ - devs: createDevices(getGpuCount(topology)), - socket: serverSock, + for _, genericDevice := range topology.OtherDevices { + devicePlugins = append(devicePlugins, &RealNodeDevicePlugin{ + devs: createDevices(genericDevice.Count), + socket: path.Join(pluginapi.DevicePluginPath, normalizeDeviceName(genericDevice.Name)+".sock"), + resourceName: genericDevice.Name, + }) } + + return devicePlugins +} + +func normalizeDeviceName(deviceName string) string { + normalized := strings.ReplaceAll(deviceName, "/", "_") + normalized = strings.ReplaceAll(normalized, ".", "_") + normalized = strings.ReplaceAll(normalized, "-", "_") + return normalized } diff --git a/internal/deviceplugin/fake_node.go b/internal/deviceplugin/fake_node.go index b1c503e..ad2f168 100644 --- a/internal/deviceplugin/fake_node.go +++ b/internal/deviceplugin/fake_node.go @@ -18,7 +18,7 @@ type FakeNodeDevicePlugin struct { } func (f *FakeNodeDevicePlugin) Serve() error { - patch := fmt.Sprintf(`{"status": {"capacity": {"%s": "%d"}, "allocatable": {"%s": "%d"}}}`, resourceName, f.gpuCount, resourceName, f.gpuCount) + patch := fmt.Sprintf(`{"status": {"capacity": {"%s": "%d"}, "allocatable": {"%s": "%d"}}}`, nvidiaGPUResourceName, f.gpuCount, nvidiaGPUResourceName, f.gpuCount) _, err := f.kubeClient.CoreV1().Nodes().Patch(context.TODO(), os.Getenv(constants.EnvNodeName), types.MergePatchType, []byte(patch), metav1.PatchOptions{}, "status") if err != nil { return fmt.Errorf("failed to update node capacity and allocatable: %v", err) @@ -26,3 +26,7 @@ func (f *FakeNodeDevicePlugin) Serve() error { return nil } + +func (f *FakeNodeDevicePlugin) Name() string { + return "FakeNodeDevicePlugin" +} diff --git a/internal/deviceplugin/real_node.go b/internal/deviceplugin/real_node.go index e78ab38..3fc325f 100644 --- a/internal/deviceplugin/real_node.go +++ b/internal/deviceplugin/real_node.go @@ -28,6 +28,8 @@ type RealNodeDevicePlugin struct { stop chan interface{} health chan *pluginapi.Device server *grpc.Server + + resourceName string } func getGpuCount(nodeTopology *topology.NodeTopology) int { @@ -115,7 +117,7 @@ func (m *RealNodeDevicePlugin) Stop() error { return m.cleanup() } -func (m *RealNodeDevicePlugin) Register(kubeletEndpoint, resourceName string) error { +func (m *RealNodeDevicePlugin) Register(kubeletEndpoint string) error { conn, err := dial(kubeletEndpoint, 5*time.Second) if err != nil { return err @@ -126,7 +128,7 @@ func (m *RealNodeDevicePlugin) Register(kubeletEndpoint, resourceName string) er reqt := &pluginapi.RegisterRequest{ Version: pluginapi.Version, Endpoint: path.Base(m.socket), - ResourceName: resourceName, + ResourceName: m.resourceName, } _, err = client.Register(context.Background(), reqt) @@ -202,7 +204,7 @@ func (m *RealNodeDevicePlugin) Serve() error { } log.Println("Starting to serve on", m.socket) - err = m.Register(pluginapi.KubeletSocket, resourceName) + err = m.Register(pluginapi.KubeletSocket) if err != nil { log.Printf("Could not register device plugin: %s", err) stopErr := m.Stop() @@ -215,3 +217,7 @@ func (m *RealNodeDevicePlugin) Serve() error { return nil } + +func (m *RealNodeDevicePlugin) Name() string { + return "RealNodeDevicePlugin-" + m.resourceName +}