diff --git a/.github/workflows/pr-check.yaml b/.github/workflows/pr-check.yaml index 1701b2a..aabe535 100644 --- a/.github/workflows/pr-check.yaml +++ b/.github/workflows/pr-check.yaml @@ -22,7 +22,7 @@ jobs: build: name: Test & Build - runs-on: ubuntu-20.04 + runs-on: ubuntu-24.04 steps: - name: Setup up Go 1.22 uses: actions/setup-go@v5 diff --git a/.gitignore b/.gitignore index ba2f9ed..4c724fd 100644 --- a/.gitignore +++ b/.gitignore @@ -3,4 +3,4 @@ /test/e2e/ginkgo cloud-config /vendor - +/hack/tools/bin diff --git a/Makefile b/Makefile index fadfa9d..f804adc 100644 --- a/Makefile +++ b/Makefile @@ -18,6 +18,20 @@ IMPORTPATH_LDFLAGS = -X ${PKG}/pkg/driver.driverVersion=$(REV) -X ${PKG}/pkg/dri LDFLAGS = -s -w FULL_LDFLAGS = $(LDFLAGS) $(IMPORTPATH_LDFLAGS) +export REPO_ROOT := $(shell git rev-parse --show-toplevel) + +# Directories +TOOLS_DIR := $(REPO_ROOT)/hack/tools +TOOLS_BIN_DIR := $(TOOLS_DIR)/bin +BIN_DIR ?= bin + +GO_INSTALL := ./hack/go_install.sh + +MOCKGEN_BIN := mockgen +MOCKGEN_VER := v0.5.0 +MOCKGEN := $(abspath $(TOOLS_BIN_DIR)/$(MOCKGEN_BIN)-$(MOCKGEN_VER)) +MOCKGEN_PKG := go.uber.org/mock/mockgen + .PHONY: all all: build @@ -41,6 +55,11 @@ $(CMDS:%=container-%): container-%: build-% $(DOCKER) build -f ./cmd/$*/Dockerfile -t $*:latest \ --label org.opencontainers.image.revision=$(REV) . +.PHONY: generate-mocks +generate-mocks: $(MOCKGEN) pkg/cloud/mock_cloud.go ## Generate mocks needed for testing. Primarily mocks of the cloud package. +pkg/cloud/mock%.go: $(shell find ./pkg/cloud -type f -name "*test*" -prune -o -print) + go generate ./... + .PHONY: test test: go test ./... @@ -59,3 +78,11 @@ test/e2e/e2e.test test/e2e/ginkgo: .PHONY: test-e2e test-e2e: setup-external-e2e bash ./test/e2e/run.sh + +##@ hack/tools: + +.PHONY: $(MOCKGEN_BIN) +$(MOCKGEN_BIN): $(MOCKGEN) ## Build a local copy of mockgen. + +$(MOCKGEN): # Build mockgen from tools folder. 
+ GOBIN=$(TOOLS_BIN_DIR) $(GO_INSTALL) $(MOCKGEN_PKG) $(MOCKGEN_BIN) $(MOCKGEN_VER) diff --git a/cmd/cloudstack-csi-driver/main.go b/cmd/cloudstack-csi-driver/main.go index 39793cc..6bc2cda 100644 --- a/cmd/cloudstack-csi-driver/main.go +++ b/cmd/cloudstack-csi-driver/main.go @@ -95,7 +95,7 @@ func main() { ctx := klog.NewContext(context.Background(), logger) csConnector := cloud.New(config) - d, err := driver.New(ctx, csConnector, &options, nil) + d, err := driver.NewDriver(ctx, csConnector, &options, nil) if err != nil { logger.Error(err, "Failed to initialize driver") klog.FlushAndExit(klog.ExitFlushTimeout, 1) diff --git a/go.mod b/go.mod index 6222171..9d6754a 100644 --- a/go.mod +++ b/go.mod @@ -7,10 +7,13 @@ toolchain go1.22.8 require ( github.com/apache/cloudstack-go/v2 v2.16.1 github.com/container-storage-interface/spec v1.9.0 + github.com/golang/mock v1.6.0 github.com/hashicorp/go-uuid v1.0.3 github.com/kubernetes-csi/csi-lib-utils v0.18.1 github.com/kubernetes-csi/csi-test/v5 v5.2.0 github.com/spf13/pflag v1.0.5 + github.com/stretchr/testify v1.9.0 + go.uber.org/mock v0.5.0 golang.org/x/sys v0.26.0 golang.org/x/text v0.19.0 google.golang.org/grpc v1.65.0 @@ -37,7 +40,6 @@ require ( github.com/go-openapi/swag v0.22.3 // indirect github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect github.com/gogo/protobuf v1.3.2 // indirect - github.com/golang/mock v1.6.0 // indirect github.com/golang/protobuf v1.5.4 // indirect github.com/google/gnostic-models v0.6.8 // indirect github.com/google/go-cmp v0.6.0 // indirect @@ -56,6 +58,7 @@ require ( github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/onsi/ginkgo/v2 v2.15.0 // indirect github.com/onsi/gomega v1.31.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/prometheus/client_golang v1.16.0 // indirect github.com/prometheus/client_model v0.4.0 // indirect github.com/prometheus/common v0.44.0 // indirect @@ -63,11 +66,11 @@ require ( github.com/spf13/cobra v1.7.0 // indirect go.uber.org/multierr v1.11.0 // indirect go.uber.org/zap v1.27.0 // indirect - golang.org/x/net v0.25.0 // indirect + golang.org/x/net v0.26.0 // indirect golang.org/x/oauth2 v0.20.0 // indirect - golang.org/x/term v0.20.0 // indirect + golang.org/x/term v0.21.0 // indirect golang.org/x/time v0.5.0 // indirect - golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect + golang.org/x/tools v0.22.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240528184218-531527333157 // indirect google.golang.org/protobuf v1.34.1 // indirect gopkg.in/inf.v0 v0.9.1 // indirect diff --git a/go.sum b/go.sum index 4088810..2608bfc 100644 --- a/go.sum +++ b/go.sum @@ -115,13 +115,15 @@ github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= -github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= -github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= 
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= +go.uber.org/mock v0.5.0 h1:KAMbZvZPyBPWgD14IrIQ38QCyjwpvVVV6K/bHl1IwQU= +go.uber.org/mock v0.5.0/go.mod h1:ge71pBPLYDk7QIi1LupWxdAykm7KIEFchiOqd6z7qMM= go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8= @@ -137,8 +139,8 @@ golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLL golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= -golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac= -golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= +golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ= +golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE= golang.org/x/oauth2 v0.20.0 h1:4mQdhULixXKP1rwYBW0vAijoXnkTG0BLCDRzfe1idMo= golang.org/x/oauth2 v0.20.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -157,8 +159,8 @@ golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo= golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/term v0.20.0 h1:VnkxpohqXaOBYJtBmEppKUG6mXpi+4O6purfc2+sMhw= -golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY= +golang.org/x/term v0.21.0 h1:WVXCp+/EBEHOj53Rvu+7KiT/iElMrO8ACK16SMZ3jaA= +golang.org/x/term v0.21.0/go.mod h1:ooXLefLobQVslOqselCNF4SxFAaoS6KujMbsGzSDmX0= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM= @@ -170,8 +172,8 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtn golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= -golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d h1:vU5i/LfpvrRCpgM/VPfJLg5KjxD3E+hfT1SH+d9zLwg= -golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= +golang.org/x/tools v0.22.0 h1:gqSGLZqv+AI9lIQzniJ0nZDRG5GBPsSi+DRNHWNz6yA= +golang.org/x/tools v0.22.0/go.mod h1:aCwcsjqvq7Yqt6TNyX7QMU2enbQ/Gt0bo6krSeEri+c= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors 
v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/hack/go_install.sh b/hack/go_install.sh new file mode 100755 index 0000000..0663b0b --- /dev/null +++ b/hack/go_install.sh @@ -0,0 +1,45 @@ +#!/usr/bin/env bash +# Copyright 2021 The Kubernetes Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set -o errexit +set -o nounset +set -o pipefail + +if [ -z "${1}" ]; then + echo "must provide module as first parameter" + exit 1 +fi + +if [ -z "${2}" ]; then + echo "must provide binary name as second parameter" + exit 1 +fi + +if [ -z "${3}" ]; then + echo "must provide version as third parameter" + exit 1 +fi + +if [ -z "${GOBIN}" ]; then + echo "GOBIN is not set. Must set GOBIN to install the bin in a specified directory." + exit 1 +fi + +rm -f "${GOBIN}/${2}"* || true + +# install the golang module specified as the first argument +go install "${1}@${3}" +mv "${GOBIN}/${2}" "${GOBIN}/${2}-${3}" +ln -sf "${GOBIN}/${2}-${3}" "${GOBIN}/${2}" \ No newline at end of file diff --git a/pkg/cloud/cloud.go b/pkg/cloud/cloud.go index dc45a05..9561a8f 100644 --- a/pkg/cloud/cloud.go +++ b/pkg/cloud/cloud.go @@ -9,8 +9,10 @@ import ( "github.com/apache/cloudstack-go/v2/cloudstack" ) -// Interface is the CloudStack client interface. -type Interface interface { +//go:generate ../../hack/tools/bin/mockgen -destination=./mock_cloud.go -package=cloud -source ./cloud.go + +// Cloud is the CloudStack client interface. +type Cloud interface { GetNodeInfo(ctx context.Context, vmName string) (*VM, error) GetVMByID(ctx context.Context, vmID string) (*VM, error) @@ -52,14 +54,14 @@ var ( ErrTooManyResults = errors.New("too many results") ) -// client is the implementation of Interface. +// client is the implementation of Cloud. type client struct { *cloudstack.CloudStackClient projectID string } // New creates a new cloud connector, given its configuration. -func New(config *Config) Interface { +func New(config *Config) Cloud { csClient := &client{ projectID: config.ProjectID, } diff --git a/pkg/cloud/fake/fake.go b/pkg/cloud/fake/fake.go index a9db626..2e29957 100644 --- a/pkg/cloud/fake/fake.go +++ b/pkg/cloud/fake/fake.go @@ -4,6 +4,7 @@ package fake import ( "context" + "errors" "github.com/hashicorp/go-uuid" @@ -21,7 +22,7 @@ type fakeConnector struct { // New returns a new fake implementation of the // CloudStack connector. 
-func New() cloud.Interface { +func New() cloud.Cloud { volume := cloud.Volume{ ID: "ace9f28b-3081-40c1-8353-4cc3e3014072", Name: "vol-1", @@ -102,12 +103,32 @@ func (f *fakeConnector) DeleteVolume(_ context.Context, id string) error { return nil } -func (f *fakeConnector) AttachVolume(_ context.Context, _, _ string) (string, error) { - return "1", nil +func (f *fakeConnector) AttachVolume(_ context.Context, volumeID, nodeID string) (string, error) { + if f.getVolumesPerServer(nodeID) >= 16 { + return "", errors.New("The specified VM already has the maximum number of data disks (16) attached. Please specify another VM.") //nolint:revive,stylecheck + } + + if vol, ok := f.volumesByID[volumeID]; ok { + vol.VirtualMachineID = nodeID + f.volumesByID[volumeID] = vol + f.volumesByName[vol.Name] = vol + + return "1", nil + } + + return "", cloud.ErrNotFound } -func (f *fakeConnector) DetachVolume(_ context.Context, _ string) error { - return nil +func (f *fakeConnector) DetachVolume(_ context.Context, volumeID string) error { + if vol, ok := f.volumesByID[volumeID]; ok { + vol.VirtualMachineID = "" + f.volumesByID[volumeID] = vol + f.volumesByName[vol.Name] = vol + + return nil + } + + return cloud.ErrNotFound } func (f *fakeConnector) ExpandVolume(_ context.Context, volumeID string, newSizeInGB int64) error { @@ -124,3 +145,14 @@ func (f *fakeConnector) ExpandVolume(_ context.Context, volumeID string, newSize return cloud.ErrNotFound } + +func (f *fakeConnector) getVolumesPerServer(nodeID string) int { + volumesCount := 0 + for _, v := range f.volumesByID { + if v.VirtualMachineID == nodeID { + volumesCount++ + } + } + + return volumesCount +} diff --git a/pkg/cloud/mock_cloud.go b/pkg/cloud/mock_cloud.go new file mode 100644 index 0000000..b2de673 --- /dev/null +++ b/pkg/cloud/mock_cloud.go @@ -0,0 +1,188 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: ./cloud.go +// +// Generated by this command: +// +// mockgen -destination=./mock_cloud.go -package=cloud -source ./cloud.go +// + +// Package cloud is a generated GoMock package. +package cloud + +import ( + context "context" + reflect "reflect" + + gomock "go.uber.org/mock/gomock" +) + +// MockCloud is a mock of Cloud interface. +type MockCloud struct { + ctrl *gomock.Controller + recorder *MockCloudMockRecorder + isgomock struct{} +} + +// MockCloudMockRecorder is the mock recorder for MockCloud. +type MockCloudMockRecorder struct { + mock *MockCloud +} + +// NewMockCloud creates a new mock instance. +func NewMockCloud(ctrl *gomock.Controller) *MockCloud { + mock := &MockCloud{ctrl: ctrl} + mock.recorder = &MockCloudMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockCloud) EXPECT() *MockCloudMockRecorder { + return m.recorder +} + +// AttachVolume mocks base method. +func (m *MockCloud) AttachVolume(ctx context.Context, volumeID, vmID string) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AttachVolume", ctx, volumeID, vmID) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// AttachVolume indicates an expected call of AttachVolume. +func (mr *MockCloudMockRecorder) AttachVolume(ctx, volumeID, vmID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AttachVolume", reflect.TypeOf((*MockCloud)(nil).AttachVolume), ctx, volumeID, vmID) +} + +// CreateVolume mocks base method. 
+func (m *MockCloud) CreateVolume(ctx context.Context, diskOfferingID, zoneID, name string, sizeInGB int64) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateVolume", ctx, diskOfferingID, zoneID, name, sizeInGB) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreateVolume indicates an expected call of CreateVolume. +func (mr *MockCloudMockRecorder) CreateVolume(ctx, diskOfferingID, zoneID, name, sizeInGB any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateVolume", reflect.TypeOf((*MockCloud)(nil).CreateVolume), ctx, diskOfferingID, zoneID, name, sizeInGB) +} + +// DeleteVolume mocks base method. +func (m *MockCloud) DeleteVolume(ctx context.Context, id string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteVolume", ctx, id) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteVolume indicates an expected call of DeleteVolume. +func (mr *MockCloudMockRecorder) DeleteVolume(ctx, id any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteVolume", reflect.TypeOf((*MockCloud)(nil).DeleteVolume), ctx, id) +} + +// DetachVolume mocks base method. +func (m *MockCloud) DetachVolume(ctx context.Context, volumeID string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DetachVolume", ctx, volumeID) + ret0, _ := ret[0].(error) + return ret0 +} + +// DetachVolume indicates an expected call of DetachVolume. +func (mr *MockCloudMockRecorder) DetachVolume(ctx, volumeID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DetachVolume", reflect.TypeOf((*MockCloud)(nil).DetachVolume), ctx, volumeID) +} + +// ExpandVolume mocks base method. +func (m *MockCloud) ExpandVolume(ctx context.Context, volumeID string, newSizeInGB int64) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ExpandVolume", ctx, volumeID, newSizeInGB) + ret0, _ := ret[0].(error) + return ret0 +} + +// ExpandVolume indicates an expected call of ExpandVolume. +func (mr *MockCloudMockRecorder) ExpandVolume(ctx, volumeID, newSizeInGB any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExpandVolume", reflect.TypeOf((*MockCloud)(nil).ExpandVolume), ctx, volumeID, newSizeInGB) +} + +// GetNodeInfo mocks base method. +func (m *MockCloud) GetNodeInfo(ctx context.Context, vmName string) (*VM, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetNodeInfo", ctx, vmName) + ret0, _ := ret[0].(*VM) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetNodeInfo indicates an expected call of GetNodeInfo. +func (mr *MockCloudMockRecorder) GetNodeInfo(ctx, vmName any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetNodeInfo", reflect.TypeOf((*MockCloud)(nil).GetNodeInfo), ctx, vmName) +} + +// GetVMByID mocks base method. +func (m *MockCloud) GetVMByID(ctx context.Context, vmID string) (*VM, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetVMByID", ctx, vmID) + ret0, _ := ret[0].(*VM) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetVMByID indicates an expected call of GetVMByID. +func (mr *MockCloudMockRecorder) GetVMByID(ctx, vmID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetVMByID", reflect.TypeOf((*MockCloud)(nil).GetVMByID), ctx, vmID) +} + +// GetVolumeByID mocks base method. 
+func (m *MockCloud) GetVolumeByID(ctx context.Context, volumeID string) (*Volume, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetVolumeByID", ctx, volumeID) + ret0, _ := ret[0].(*Volume) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetVolumeByID indicates an expected call of GetVolumeByID. +func (mr *MockCloudMockRecorder) GetVolumeByID(ctx, volumeID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetVolumeByID", reflect.TypeOf((*MockCloud)(nil).GetVolumeByID), ctx, volumeID) +} + +// GetVolumeByName mocks base method. +func (m *MockCloud) GetVolumeByName(ctx context.Context, name string) (*Volume, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetVolumeByName", ctx, name) + ret0, _ := ret[0].(*Volume) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetVolumeByName indicates an expected call of GetVolumeByName. +func (mr *MockCloudMockRecorder) GetVolumeByName(ctx, name any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetVolumeByName", reflect.TypeOf((*MockCloud)(nil).GetVolumeByName), ctx, name) +} + +// ListZonesID mocks base method. +func (m *MockCloud) ListZonesID(ctx context.Context) ([]string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ListZonesID", ctx) + ret0, _ := ret[0].([]string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ListZonesID indicates an expected call of ListZonesID. +func (mr *MockCloudMockRecorder) ListZonesID(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListZonesID", reflect.TypeOf((*MockCloud)(nil).ListZonesID), ctx) +} diff --git a/pkg/driver/controller.go b/pkg/driver/controller.go index 8cf8765..364e7d4 100644 --- a/pkg/driver/controller.go +++ b/pkg/driver/controller.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "math/rand" + "regexp" "github.com/container-storage-interface/spec/lib/go/csi" "github.com/kubernetes-csi/csi-lib-utils/protosanitizer" @@ -16,18 +17,25 @@ import ( "github.com/leaseweb/cloudstack-csi-driver/pkg/util" ) -// onlyVolumeCapAccessMode is the only volume capability access -// mode possible for CloudStack: SINGLE_NODE_WRITER, since a -// CloudStack volume can only be attached to a single node at -// any given time. -var onlyVolumeCapAccessMode = csi.VolumeCapability_AccessMode{ - Mode: csi.VolumeCapability_AccessMode_SINGLE_NODE_WRITER, -} +var ( + // onlyVolumeCapAccessMode is the only volume capability access + // mode possible for CloudStack: SINGLE_NODE_WRITER, since a + // CloudStack volume can only be attached to a single node at + // any given time. + onlyVolumeCapAccessMode = csi.VolumeCapability_AccessMode{ + Mode: csi.VolumeCapability_AccessMode_SINGLE_NODE_WRITER, + } -type controllerServer struct { + // maxVolumesPerVMErrorMessageRe matches the error message returned by the CloudStack + // API when the per-server volume limit would be exceeded. + maxVolumesPerVMErrorMessageRe = regexp.MustCompile(`The specified VM already has the maximum number of data disks \(\d+\) attached\. Please specify another VM\.`) +) + +// ControllerService represents the controller service of CSI driver.
+type ControllerService struct { csi.UnimplementedControllerServer // connector is the CloudStack client interface - connector cloud.Interface + connector cloud.Cloud // A map storing all volumes with ongoing operations so that additional operations // for that same volume (as defined by VolumeID/volume name) return an Aborted error @@ -37,18 +45,18 @@ type controllerServer struct { operationLocks *util.OperationLock } -// NewControllerServer creates a new Controller gRPC server. -func NewControllerServer(connector cloud.Interface) csi.ControllerServer { - return &controllerServer{ +// NewControllerService creates a new controller service. +func NewControllerService(connector cloud.Cloud) *ControllerService { + return &ControllerService{ connector: connector, volumeLocks: util.NewVolumeLocks(), operationLocks: util.NewOperationLock(), } } -func (cs *controllerServer) CreateVolume(ctx context.Context, req *csi.CreateVolumeRequest) (*csi.CreateVolumeResponse, error) { +func (cs *ControllerService) CreateVolume(ctx context.Context, req *csi.CreateVolumeRequest) (*csi.CreateVolumeResponse, error) { logger := klog.FromContext(ctx) - logger.V(6).Info("CreateVolume: called", "args", *req) + logger.V(4).Info("CreateVolume: called", "args", *req) // Check arguments. @@ -74,9 +82,7 @@ func (cs *controllerServer) CreateVolume(ctx context.Context, req *csi.CreateVol } if acquired := cs.volumeLocks.TryAcquire(name); !acquired { - logger.Error(errors.New(util.ErrVolumeOperationAlreadyExistsVolumeName), "failed to acquire volume lock", "volumeName", name) - - return nil, status.Errorf(codes.Aborted, util.VolumeOperationAlreadyExistsFmt, name) + return nil, status.Errorf(codes.Aborted, util.VolumeOperationAlreadyExistsVolumeNameFmt, name) } defer cs.volumeLocks.Release(name) @@ -228,9 +234,9 @@ func determineSize(req *csi.CreateVolumeRequest) (int64, error) { return sizeInGB, nil } -func (cs *controllerServer) DeleteVolume(ctx context.Context, req *csi.DeleteVolumeRequest) (*csi.DeleteVolumeResponse, error) { +func (cs *ControllerService) DeleteVolume(ctx context.Context, req *csi.DeleteVolumeRequest) (*csi.DeleteVolumeResponse, error) { logger := klog.FromContext(ctx) - logger.V(6).Info("DeleteVolume: called", "args", *req) + logger.V(4).Info("DeleteVolume: called", "args", *req) if req.GetVolumeId() == "" { return nil, status.Error(codes.InvalidArgument, "Volume ID missing in request") @@ -239,9 +245,7 @@ func (cs *controllerServer) DeleteVolume(ctx context.Context, req *csi.DeleteVol volumeID := req.GetVolumeId() if acquired := cs.volumeLocks.TryAcquire(volumeID); !acquired { - logger.Error(errors.New(util.ErrVolumeOperationAlreadyExistsVolumeID), "failed to acquire volume lock", "volumeID", volumeID) - - return nil, status.Errorf(codes.Aborted, util.VolumeOperationAlreadyExistsFmt, volumeID) + return nil, status.Errorf(codes.Aborted, util.VolumeOperationAlreadyExistsVolumeIDFmt, volumeID) } defer cs.volumeLocks.Release(volumeID) @@ -265,9 +269,9 @@ func (cs *controllerServer) DeleteVolume(ctx context.Context, req *csi.DeleteVol return &csi.DeleteVolumeResponse{}, nil } -func (cs *controllerServer) ControllerPublishVolume(ctx context.Context, req *csi.ControllerPublishVolumeRequest) (*csi.ControllerPublishVolumeResponse, error) { +func (cs *ControllerService) ControllerPublishVolume(ctx context.Context, req *csi.ControllerPublishVolumeRequest) (*csi.ControllerPublishVolumeResponse, error) { logger := klog.FromContext(ctx) - logger.V(6).Info("ControllerPublishVolume: called", "args", *req) + 
logger.V(4).Info("ControllerPublishVolume: called", "args", *req) // Check arguments. @@ -292,6 +296,11 @@ func (cs *controllerServer) ControllerPublishVolume(ctx context.Context, req *cs return nil, status.Error(codes.InvalidArgument, "Access mode not accepted") } + if acquired := cs.volumeLocks.TryAcquire(volumeID); !acquired { + return nil, status.Errorf(codes.Aborted, util.VolumeOperationAlreadyExistsVolumeIDFmt, volumeID) + } + defer cs.volumeLocks.Release(volumeID) + logger.Info("Initiating attaching volume", "volumeID", volumeID, "nodeID", nodeID, @@ -344,6 +353,10 @@ func (cs *controllerServer) ControllerPublishVolume(ctx context.Context, req *cs deviceID, err := cs.connector.AttachVolume(ctx, volumeID, nodeID) if err != nil { + if maxVolumesPerVMErrorMessageRe.MatchString(err.Error()) { + return nil, status.Errorf(codes.ResourceExhausted, "Cannot attach volume %s: %s", volumeID, err.Error()) + } + return nil, status.Errorf(codes.Internal, "Cannot attach volume %s: %s", volumeID, err.Error()) } @@ -359,9 +372,9 @@ func (cs *controllerServer) ControllerPublishVolume(ctx context.Context, req *cs return &csi.ControllerPublishVolumeResponse{PublishContext: publishContext}, nil } -func (cs *controllerServer) ControllerUnpublishVolume(ctx context.Context, req *csi.ControllerUnpublishVolumeRequest) (*csi.ControllerUnpublishVolumeResponse, error) { +func (cs *ControllerService) ControllerUnpublishVolume(ctx context.Context, req *csi.ControllerUnpublishVolumeRequest) (*csi.ControllerUnpublishVolumeResponse, error) { logger := klog.FromContext(ctx) - logger.V(6).Info("ControllerUnpublishVolume: called", "args", *req) + logger.V(4).Info("ControllerUnpublishVolume: called", "args", *req) // Check arguments. @@ -371,6 +384,11 @@ func (cs *controllerServer) ControllerUnpublishVolume(ctx context.Context, req * volumeID := req.GetVolumeId() nodeID := req.GetNodeId() + if acquired := cs.volumeLocks.TryAcquire(volumeID); !acquired { + return nil, status.Errorf(codes.Aborted, util.VolumeOperationAlreadyExistsVolumeIDFmt, volumeID) + } + defer cs.volumeLocks.Release(volumeID) + // Check volume. if vol, err := cs.connector.GetVolumeByID(ctx, volumeID); errors.Is(err, cloud.ErrNotFound) { // Volume does not exist in CloudStack. 
We can safely assume this volume is no longer attached @@ -416,9 +434,9 @@ func (cs *controllerServer) ControllerUnpublishVolume(ctx context.Context, req * return &csi.ControllerUnpublishVolumeResponse{}, nil } -func (cs *controllerServer) ValidateVolumeCapabilities(ctx context.Context, req *csi.ValidateVolumeCapabilitiesRequest) (*csi.ValidateVolumeCapabilitiesResponse, error) { +func (cs *ControllerService) ValidateVolumeCapabilities(ctx context.Context, req *csi.ValidateVolumeCapabilitiesRequest) (*csi.ValidateVolumeCapabilitiesResponse, error) { logger := klog.FromContext(ctx) - logger.V(6).Info("ValidateVolumeCapabilities: called", "args", *req) + logger.V(4).Info("ValidateVolumeCapabilities: called", "args", *req) volumeID := req.GetVolumeId() if len(volumeID) == 0 { @@ -460,9 +478,9 @@ func isValidVolumeCapabilities(volCaps []*csi.VolumeCapability) bool { return true } -func (cs *controllerServer) ControllerExpandVolume(ctx context.Context, req *csi.ControllerExpandVolumeRequest) (*csi.ControllerExpandVolumeResponse, error) { +func (cs *ControllerService) ControllerExpandVolume(ctx context.Context, req *csi.ControllerExpandVolumeRequest) (*csi.ControllerExpandVolumeResponse, error) { logger := klog.FromContext(ctx) - logger.V(6).Info("ControllerExpandVolume: called", "args", protosanitizer.StripSecrets(*req)) + logger.V(4).Info("ControllerExpandVolume: called", "args", protosanitizer.StripSecrets(*req)) volumeID := req.GetVolumeId() if len(volumeID) == 0 { @@ -476,9 +494,7 @@ func (cs *controllerServer) ControllerExpandVolume(ctx context.Context, req *csi // lock out parallel requests against the same volume ID if acquired := cs.volumeLocks.TryAcquire(volumeID); !acquired { - logger.Error(errors.New(util.ErrVolumeOperationAlreadyExistsVolumeID), "failed to acquire volume lock", "volumeID", volumeID) - - return nil, status.Errorf(codes.Aborted, util.VolumeOperationAlreadyExistsFmt, volumeID) + return nil, status.Errorf(codes.Aborted, util.VolumeOperationAlreadyExistsVolumeIDFmt, volumeID) } defer cs.volumeLocks.Release(volumeID) @@ -530,9 +546,9 @@ func (cs *controllerServer) ControllerExpandVolume(ctx context.Context, req *csi }, nil } -func (cs *controllerServer) ControllerGetCapabilities(ctx context.Context, req *csi.ControllerGetCapabilitiesRequest) (*csi.ControllerGetCapabilitiesResponse, error) { +func (cs *ControllerService) ControllerGetCapabilities(ctx context.Context, req *csi.ControllerGetCapabilitiesRequest) (*csi.ControllerGetCapabilitiesResponse, error) { logger := klog.FromContext(ctx) - logger.V(6).Info("ControllerGetCapabilities: called", "args", protosanitizer.StripSecrets(*req)) + logger.V(4).Info("ControllerGetCapabilities: called", "args", protosanitizer.StripSecrets(*req)) resp := &csi.ControllerGetCapabilitiesResponse{ Capabilities: []*csi.ControllerServiceCapability{ diff --git a/pkg/driver/controller_test.go b/pkg/driver/controller_test.go index bf9f3d1..a157ed7 100644 --- a/pkg/driver/controller_test.go +++ b/pkg/driver/controller_test.go @@ -1,9 +1,28 @@ package driver import ( + "context" "testing" "github.com/container-storage-interface/spec/lib/go/csi" + "github.com/stretchr/testify/assert" + "go.uber.org/mock/gomock" + + "github.com/leaseweb/cloudstack-csi-driver/pkg/cloud" +) + +var ( + FakeCapacityGiB = 1 + FakeVolName = "CSIVolumeName" + FakeVolID = "CSIVolumeID" + FakeAvailability = "nova" + FakeDiskOfferingID = "9743fd77-0f5d-4ef9-b2f8-f194235c769c" + FakeVol = cloud.Volume{ + ID: FakeVolID, + Name: FakeVolName, + Size: int64(FakeCapacityGiB), + 
ZoneID: FakeAvailability, + } ) func TestDetermineSize(t *testing.T) { @@ -40,3 +59,47 @@ func TestDetermineSize(t *testing.T) { }) } } + +func TestCreateVolume(t *testing.T) { + ctx := context.Background() + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + mockCloud := cloud.NewMockCloud(mockCtl) + mockCloud.EXPECT().CreateVolume(gomock.Eq(ctx), FakeDiskOfferingID, FakeAvailability, FakeVolName, gomock.Any()).Return(FakeVolID, nil) + mockCloud.EXPECT().GetVolumeByName(gomock.Eq(ctx), FakeVolName).Return(nil, cloud.ErrNotFound) + fakeCs := NewControllerService(mockCloud) + // Init assert + assert := assert.New(t) + // Fake request + fakeReq := &csi.CreateVolumeRequest{ + Name: FakeVolName, + VolumeCapabilities: []*csi.VolumeCapability{ + { + AccessMode: &csi.VolumeCapability_AccessMode{ + Mode: csi.VolumeCapability_AccessMode_SINGLE_NODE_WRITER, + }, + }, + }, + Parameters: map[string]string{ + DiskOfferingKey: FakeDiskOfferingID, + }, + AccessibilityRequirements: &csi.TopologyRequirement{ + Requisite: []*csi.Topology{ + { + Segments: map[string]string{"topology.csi.cloudstack.apache.org/zone": FakeAvailability}, + }, + }, + }, + } + // Invoke CreateVolume + actualRes, err := fakeCs.CreateVolume(ctx, fakeReq) + if err != nil { + t.Errorf("failed to CreateVolume: %v", err) + } + // Assert + assert.NotNil(actualRes.GetVolume()) + assert.NotNil(actualRes.GetVolume().GetCapacityBytes()) + assert.NotEmpty(actualRes.GetVolume().GetVolumeId(), "Volume Id is empty") + assert.NotNil(actualRes.GetVolume().GetAccessibleTopology()) + assert.Equal(FakeAvailability, actualRes.GetVolume().GetAccessibleTopology()[0].GetSegments()[ZoneKey]) +} diff --git a/pkg/driver/driver.go b/pkg/driver/driver.go index 839574e..bdc10f2 100644 --- a/pkg/driver/driver.go +++ b/pkg/driver/driver.go @@ -23,14 +23,16 @@ type Interface interface { Run(ctx context.Context) error } -type cloudstackDriver struct { - controller csi.ControllerServer - node csi.NodeServer +type Driver struct { + controller *ControllerService + node *NodeService + srv *grpc.Server options *Options + csi.UnimplementedIdentityServer } -// New instantiates a new CloudStack CSI driver. -func New(ctx context.Context, csConnector cloud.Interface, options *Options, mounter mount.Interface) (Interface, error) { +// NewDriver instantiates a new CloudStack CSI driver.
+func NewDriver(ctx context.Context, csConnector cloud.Cloud, options *Options, mounter mount.Mounter) (*Driver, error) { logger := klog.FromContext(ctx) logger.Info("Driver starting", "Driver", DriverName, "Version", driverVersion) @@ -38,18 +40,18 @@ func New(ctx context.Context, csConnector cloud.Interface, options *Options, mou return nil, fmt.Errorf("invalid driver options: %w", err) } - driver := &cloudstackDriver{ + driver := &Driver{ options: options, } switch options.Mode { case ControllerMode: - driver.controller = NewControllerServer(csConnector) + driver.controller = NewControllerService(csConnector) case NodeMode: - driver.node = NewNodeServer(csConnector, mounter, options) + driver.node = NewNodeService(csConnector, mounter, options) case AllMode: - driver.controller = NewControllerServer(csConnector) - driver.node = NewNodeServer(csConnector, mounter, options) + driver.controller = NewControllerService(csConnector) + driver.node = NewNodeService(csConnector, mounter, options) default: return nil, fmt.Errorf("unknown mode: %s", options.Mode) } @@ -57,9 +59,9 @@ func New(ctx context.Context, csConnector cloud.Interface, options *Options, mou return driver, nil } -func (cs *cloudstackDriver) Run(ctx context.Context) error { +func (d *Driver) Run(ctx context.Context) error { logger := klog.FromContext(ctx) - scheme, addr, err := util.ParseEndpoint(cs.options.Endpoint) + scheme, addr, err := util.ParseEndpoint(d.options.Endpoint) if err != nil { return err } @@ -80,24 +82,24 @@ func (cs *cloudstackDriver) Run(ctx context.Context) error { return resp, err }), } - grpcServer := grpc.NewServer(opts...) + d.srv = grpc.NewServer(opts...) + csi.RegisterIdentityServer(d.srv, d) - csi.RegisterIdentityServer(grpcServer, cs) - switch cs.options.Mode { + switch d.options.Mode { case ControllerMode: - csi.RegisterControllerServer(grpcServer, cs.controller) + csi.RegisterControllerServer(d.srv, d.controller) case NodeMode: - csi.RegisterNodeServer(grpcServer, cs.node) + csi.RegisterNodeServer(d.srv, d.node) case AllMode: - csi.RegisterControllerServer(grpcServer, cs.controller) - csi.RegisterNodeServer(grpcServer, cs.node) + csi.RegisterControllerServer(d.srv, d.controller) + csi.RegisterNodeServer(d.srv, d.node) default: - return fmt.Errorf("unknown mode: %s", cs.options.Mode) + return fmt.Errorf("unknown mode: %s", d.options.Mode) } logger.Info("Listening for connections", "address", listener.Addr()) - return grpcServer.Serve(listener) + return d.srv.Serve(listener) } func validateMode(mode Mode) error { diff --git a/pkg/driver/driver_test.go b/pkg/driver/driver_test.go new file mode 100644 index 0000000..0fc5abe --- /dev/null +++ b/pkg/driver/driver_test.go @@ -0,0 +1,82 @@ +package driver + +import ( + "context" + "testing" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/require" + + "github.com/leaseweb/cloudstack-csi-driver/pkg/cloud/fake" + "github.com/leaseweb/cloudstack-csi-driver/pkg/mount" +) + +func TestNewDriver(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + ctx := context.Background() + fakeCloud := fake.New() + fakeMounter := mount.NewFake() + testCases := []struct { + name string + o *Options + expectError bool + hasController bool + hasNode bool + }{ + { + name: "Valid driver controllerMode", + o: &Options{ + Mode: ControllerMode, + VolumeAttachLimit: 16, + }, + expectError: false, + hasController: true, + hasNode: false, + }, + { + name: "Valid driver nodeMode", + o: &Options{ + Mode: NodeMode, + }, + expectError: false, + 
hasController: false, + hasNode: true, + }, + { + name: "Valid driver allMode", + o: &Options{ + Mode: AllMode, + VolumeAttachLimit: 16, + }, + expectError: false, + hasController: true, + hasNode: true, + }, + { + name: "Invalid driver options", + o: &Options{ + Mode: "InvalidMode", + }, + expectError: true, + hasController: false, + hasNode: false, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + driver, err := NewDriver(ctx, fakeCloud, tc.o, fakeMounter) + if tc.hasNode && driver.node == nil { + t.Fatalf("Expected driver to have node but driver does not have node") + } + if tc.hasController && driver.controller == nil { + t.Fatalf("Expected driver to have controller but driver does not have controller") + } + if tc.expectError { + require.Error(t, err) + } else { + require.NoError(t, err) + } + }) + } +} diff --git a/pkg/driver/identity.go b/pkg/driver/identity.go index e265275..e0ebdbc 100644 --- a/pkg/driver/identity.go +++ b/pkg/driver/identity.go @@ -7,7 +7,7 @@ import ( "k8s.io/klog/v2" ) -func (cs *cloudstackDriver) GetPluginInfo(ctx context.Context, req *csi.GetPluginInfoRequest) (*csi.GetPluginInfoResponse, error) { +func (d *Driver) GetPluginInfo(ctx context.Context, req *csi.GetPluginInfoRequest) (*csi.GetPluginInfoResponse, error) { logger := klog.FromContext(ctx) logger.V(6).Info("GetPluginInfo: called", "args", *req) resp := &csi.GetPluginInfoResponse{ @@ -18,14 +18,14 @@ func (cs *cloudstackDriver) GetPluginInfo(ctx context.Context, req *csi.GetPlugi return resp, nil } -func (cs *cloudstackDriver) Probe(ctx context.Context, req *csi.ProbeRequest) (*csi.ProbeResponse, error) { +func (d *Driver) Probe(ctx context.Context, req *csi.ProbeRequest) (*csi.ProbeResponse, error) { logger := klog.FromContext(ctx) logger.V(6).Info("Probe: called", "args", *req) return &csi.ProbeResponse{}, nil } -func (cs *cloudstackDriver) GetPluginCapabilities(ctx context.Context, req *csi.GetPluginCapabilitiesRequest) (*csi.GetPluginCapabilitiesResponse, error) { +func (d *Driver) GetPluginCapabilities(ctx context.Context, req *csi.GetPluginCapabilitiesRequest) (*csi.GetPluginCapabilitiesResponse, error) { logger := klog.FromContext(ctx) logger.V(6).Info("Probe: called", "args", *req) diff --git a/pkg/driver/node.go b/pkg/driver/node.go index 6050af8..5704b53 100644 --- a/pkg/driver/node.go +++ b/pkg/driver/node.go @@ -30,22 +30,23 @@ var ValidFSTypes = map[string]struct{}{ FSTypeXfs: {}, } -type nodeServer struct { +// NodeService represents the node service of CSI driver. +type NodeService struct { csi.UnimplementedNodeServer - connector cloud.Interface - mounter mount.Interface + connector cloud.Cloud + mounter mount.Mounter maxVolumesPerNode int64 nodeName string volumeLocks *util.VolumeLocks } -// NewNodeServer creates a new Node gRPC server. -func NewNodeServer(connector cloud.Interface, mounter mount.Interface, options *Options) csi.NodeServer { +// NewNodeService creates a new node service. 
+func NewNodeService(connector cloud.Cloud, mounter mount.Mounter, options *Options) *NodeService { if mounter == nil { mounter = mount.New() } - return &nodeServer{ + return &NodeService{ connector: connector, mounter: mounter, maxVolumesPerNode: options.VolumeAttachLimit, @@ -54,9 +55,9 @@ func NewNodeServer(connector cloud.Interface, mounter mount.Interface, options * } } -func (ns *nodeServer) NodeStageVolume(ctx context.Context, req *csi.NodeStageVolumeRequest) (*csi.NodeStageVolumeResponse, error) { +func (ns *NodeService) NodeStageVolume(ctx context.Context, req *csi.NodeStageVolumeRequest) (*csi.NodeStageVolumeResponse, error) { logger := klog.FromContext(ctx) - logger.V(6).Info("NodeStageVolume: called", "args", *req) + logger.V(4).Info("NodeStageVolume: called", "args", *req) // Check parameters @@ -106,11 +107,12 @@ func (ns *nodeServer) NodeStageVolume(ctx context.Context, req *csi.NodeStageVol } if acquired := ns.volumeLocks.TryAcquire(volumeID); !acquired { - logger.Error(errors.New(util.ErrVolumeOperationAlreadyExistsVolumeID), "failed to acquire volume lock", "volumeID", volumeID) - - return nil, status.Errorf(codes.Aborted, util.VolumeOperationAlreadyExistsFmt, volumeID) + return nil, status.Errorf(codes.Aborted, util.VolumeOperationAlreadyExistsVolumeIDFmt, volumeID) } - defer ns.volumeLocks.Release(volumeID) + defer func() { + logger.V(4).Info("NodeStageVolume: volume operation finished", "volumeId", volumeID) + ns.volumeLocks.Release(volumeID) + }() // Now, find the device path source, err := ns.mounter.GetDevicePath(ctx, volumeID) @@ -139,7 +141,7 @@ func (ns *nodeServer) NodeStageVolume(ctx context.Context, req *csi.NodeStageVol } // Check if a device is mounted in target directory - device, _, err := ns.mounter.GetDeviceName(target) + device, _, err := ns.mounter.GetDeviceNameFromMount(target) if err != nil { msg := fmt.Sprintf("failed to check if volume is already mounted: %v", err) @@ -180,22 +182,9 @@ func (ns *nodeServer) NodeStageVolume(ctx context.Context, req *csi.NodeStageVol return &csi.NodeStageVolumeResponse{}, nil } -// hasMountOption returns a boolean indicating whether the given -// slice already contains a mount option. This is used to prevent -// passing duplicate option to the mount command. 
-func hasMountOption(options []string, opt string) bool { - for _, o := range options { - if o == opt { - return true - } - } - - return false -} - -func (ns *nodeServer) NodeUnstageVolume(ctx context.Context, req *csi.NodeUnstageVolumeRequest) (*csi.NodeUnstageVolumeResponse, error) { +func (ns *NodeService) NodeUnstageVolume(ctx context.Context, req *csi.NodeUnstageVolumeRequest) (*csi.NodeUnstageVolumeResponse, error) { logger := klog.FromContext(ctx) - logger.V(6).Info("NodeUnstageVolume: called", "args", *req) + logger.V(4).Info("NodeUnstageVolume: called", "args", *req) // Check parameters volumeID := req.GetVolumeId() @@ -209,16 +198,17 @@ func (ns *nodeServer) NodeUnstageVolume(ctx context.Context, req *csi.NodeUnstag } if acquired := ns.volumeLocks.TryAcquire(volumeID); !acquired { - logger.Error(errors.New(util.ErrVolumeOperationAlreadyExistsVolumeID), "failed to acquire volume lock", "volumeID", volumeID) - - return nil, status.Errorf(codes.Aborted, util.VolumeOperationAlreadyExistsFmt, volumeID) + return nil, status.Errorf(codes.Aborted, util.VolumeOperationAlreadyExistsVolumeIDFmt, volumeID) } - defer ns.volumeLocks.Release(volumeID) + defer func() { + logger.V(4).Info("NodeUnstageVolume: volume operation finished", "volumeId", volumeID) + ns.volumeLocks.Release(volumeID) + }() // Check if target directory is a mount point. GetDeviceNameFromMount // given a mnt point, finds the device from /proc/mounts // returns the device name, reference count, and error code - dev, refCount, err := ns.mounter.GetDeviceName(target) + dev, refCount, err := ns.mounter.GetDeviceNameFromMount(target) if err != nil { msg := fmt.Sprintf("failed to check if target %q is a mount point: %v", target, err) @@ -253,36 +243,9 @@ func (ns *nodeServer) NodeUnstageVolume(ctx context.Context, req *csi.NodeUnstag return &csi.NodeUnstageVolumeResponse{}, nil } -func (ns *nodeServer) isMounted(ctx context.Context, target string) (bool, error) { +func (ns *NodeService) NodePublishVolume(ctx context.Context, req *csi.NodePublishVolumeRequest) (*csi.NodePublishVolumeResponse, error) { logger := klog.FromContext(ctx) - - notMnt, err := ns.mounter.IsLikelyNotMountPoint(target) - if err != nil { - if os.IsNotExist(err) { - return false, err - } - - // Checking if the path exists and error is related to Corrupted Mount, in that case, the system could unmount and mount. - _, pathErr := ns.mounter.PathExists(target) - if pathErr != nil && ns.mounter.IsCorruptedMnt(pathErr) { - logger.V(4).Info("NodePublishVolume: Target path is a corrupted mount. Trying to unmount.", "target", target) - if mntErr := ns.mounter.Unpublish(target); mntErr != nil { - return false, fmt.Errorf("unable to unmount the target %q : %w", target, mntErr) - } - - // After successful unmount, the device is ready to be mounted. 
- return false, nil - } - - return false, fmt.Errorf("could not check if %q is a mount point: %w, %w", target, err, pathErr) - } - - return !notMnt, nil -} - -func (ns *nodeServer) NodePublishVolume(ctx context.Context, req *csi.NodePublishVolumeRequest) (*csi.NodePublishVolumeResponse, error) { //nolint:gocyclo,gocognit - logger := klog.FromContext(ctx) - logger.V(6).Info("NodePublishVolume: called", "args", *req) + logger.V(4).Info("NodePublishVolume: called", "args", *req) // Check arguments volumeID := req.GetVolumeId() @@ -309,134 +272,36 @@ func (ns *nodeServer) NodePublishVolume(ctx context.Context, req *csi.NodePublis return nil, status.Error(codes.InvalidArgument, "Volume capability not supported") } + if acquired := ns.volumeLocks.TryAcquire(target); !acquired { + return nil, status.Errorf(codes.Aborted, util.TargetPathOperationAlreadyExistsFmt, target) + } + defer func() { + logger.V(4).Info("NodePublishVolume: volume operation finished", "volumeId", volumeID) + ns.volumeLocks.Release(target) + }() + mountOptions := []string{"bind"} if req.GetReadonly() { mountOptions = append(mountOptions, "ro") } - // Considering kubelet ensures the stage and publish operations - // are serialized, we don't need any extra locking in NodePublishVolume. - - switch req.GetVolumeCapability().GetAccessType().(type) { + switch mode := req.GetVolumeCapability().GetAccessType().(type) { case *csi.VolumeCapability_Mount: - mounted, err := ns.isMounted(ctx, target) - if err != nil { - if os.IsNotExist(err) { - if err := ns.mounter.MakeDir(target); err != nil { - return nil, status.Errorf(codes.Internal, "Could not create dir %q: %v", target, err) - } - } else { - return nil, status.Errorf(codes.Internal, "Could not check if %q is mounted: %v", target, err) - } - } - - if mounted { - logger.Info("NodePublishVolume: volume is already mounted", - "source", source, - "target", target, - ) - - return &csi.NodePublishVolumeResponse{}, nil - } - - mnt := volCap.GetMount() - if mnt == nil { - return nil, status.Error(codes.InvalidArgument, "NodePublishVolume: mount volume capability not found") - } - if mnt := volCap.GetMount(); mnt != nil { - for _, f := range mnt.GetMountFlags() { - if !hasMountOption(mountOptions, f) { - mountOptions = append(mountOptions, f) - } - } - } - - fsType := mnt.GetFsType() - if fsType == "" { - fsType = defaultFsType - } - - _, ok := ValidFSTypes[strings.ToLower(fsType)] - if !ok { - return nil, status.Errorf(codes.InvalidArgument, "NodePublishVolume: invalid fstype %s", fsType) - } - - logger.V(4).Info("NodePublishVolume: mounting source", - "source", source, - "target", target, - "fsType", fsType, - "mountOptions", mountOptions, - "volumeID", volumeID, - ) - - if err := ns.mounter.Mount(source, target, fsType, mountOptions); err != nil { - return nil, status.Errorf(codes.Internal, "failed to mount %q at %q: %v", source, target, err) + if err := ns.nodePublishVolumeForMount(ctx, req, mountOptions, mode); err != nil { + return nil, err } case *csi.VolumeCapability_Block: - source, err := ns.mounter.GetDevicePath(ctx, volumeID) - if err != nil { - return nil, status.Errorf(codes.Internal, "Cannot find device path for volume %s: %v", volumeID, err) - } - - globalMountPath := filepath.Dir(target) - exists, err := ns.mounter.PathExists(globalMountPath) - if err != nil { - return nil, status.Errorf(codes.Internal, "Could not check if path exists %q: %v", globalMountPath, err) - } - if !exists { - if err = ns.mounter.MakeDir(globalMountPath); err != nil { - return nil, 
status.Errorf(codes.Internal, "Could not create dir %q: %v", globalMountPath, err) - } - } - - mounted, err := ns.isMounted(ctx, target) - if err != nil { //nolint:nestif - if os.IsNotExist(err) { - // Create the mount point as a file since bind mount device node requires it to be a file - logger.V(4).Info("NodePublishVolume: making target file", "target", target) - err = ns.mounter.MakeFile(target) - if err != nil { - if removeErr := os.Remove(target); removeErr != nil { - return nil, status.Errorf(codes.Internal, "Could not remove mount target %q: %v", target, removeErr) - } - - return nil, status.Errorf(codes.Internal, "Could not create file %q: %v", target, err) - } - } else { - return nil, status.Errorf(codes.Internal, "Could not check if %q is mounted: %v", target, err) - } - } - - if mounted { - logger.Info("NodePublishVolume: volume is already mounted", - "source", source, - "target", target, - ) - - return &csi.NodePublishVolumeResponse{}, nil - } - - logger.Info("NodePublishVolume: mounting device", - "source", source, - "target", target, - "volumeID", volumeID, - ) - - if err := ns.mounter.Mount(source, target, "", mountOptions); err != nil { - if removeErr := os.Remove(target); removeErr != nil { - return nil, status.Errorf(codes.Internal, "Could not remove mount target %q: %v", target, removeErr) - } - - return nil, status.Errorf(codes.Internal, "failed to mount %q at %q: %v", source, target, err) + if err := ns.nodePublishVolumeForBlock(ctx, req, mountOptions); err != nil { + return nil, err } } return &csi.NodePublishVolumeResponse{}, nil } -func (ns *nodeServer) NodeUnpublishVolume(ctx context.Context, req *csi.NodeUnpublishVolumeRequest) (*csi.NodeUnpublishVolumeResponse, error) { +func (ns *NodeService) NodeUnpublishVolume(ctx context.Context, req *csi.NodeUnpublishVolumeRequest) (*csi.NodeUnpublishVolumeResponse, error) { logger := klog.FromContext(ctx) - logger.V(6).Info("NodeUnpublishVolume: called", "args", *req) + logger.V(4).Info("NodeUnpublishVolume: called", "args", *req) volumeID := req.GetVolumeId() if volumeID == "" { @@ -447,8 +312,13 @@ func (ns *nodeServer) NodeUnpublishVolume(ctx context.Context, req *csi.NodeUnpu return nil, status.Error(codes.InvalidArgument, "Target path missing in request") } - // Considering that kubelet ensures the stage and publish operations - // are serialized, we don't need any extra locking in NodeUnpublishVolume. 
+ if acquired := ns.volumeLocks.TryAcquire(target); !acquired { + return nil, status.Errorf(codes.Aborted, util.TargetPathOperationAlreadyExistsFmt, target) + } + defer func() { + logger.V(4).Info("NodeUnpublishVolume: volume operation finished", "volumeId", volumeID) + ns.volumeLocks.Release(target) + }() logger.V(4).Info("NodeUnpublishVolume: unmounting volume", "target", target, @@ -463,9 +333,9 @@ func (ns *nodeServer) NodeUnpublishVolume(ctx context.Context, req *csi.NodeUnpu return &csi.NodeUnpublishVolumeResponse{}, nil } -func (ns *nodeServer) NodeGetInfo(ctx context.Context, req *csi.NodeGetInfoRequest) (*csi.NodeGetInfoResponse, error) { +func (ns *NodeService) NodeGetInfo(ctx context.Context, req *csi.NodeGetInfoRequest) (*csi.NodeGetInfoResponse, error) { logger := klog.FromContext(ctx) - logger.V(6).Info("NodeGetInfo: called", "args", *req) + logger.V(4).Info("NodeGetInfo: called", "args", *req) if ns.nodeName == "" { return nil, status.Error(codes.Internal, "Missing node name") @@ -491,9 +361,9 @@ func (ns *nodeServer) NodeGetInfo(ctx context.Context, req *csi.NodeGetInfoReque }, nil } -func (ns *nodeServer) NodeExpandVolume(ctx context.Context, req *csi.NodeExpandVolumeRequest) (*csi.NodeExpandVolumeResponse, error) { +func (ns *NodeService) NodeExpandVolume(ctx context.Context, req *csi.NodeExpandVolumeRequest) (*csi.NodeExpandVolumeResponse, error) { logger := klog.FromContext(ctx) - logger.V(6).Info("NodeExpandVolume: called", "args", *req) + logger.V(4).Info("NodeExpandVolume: called", "args", *req) volumeID := req.GetVolumeId() if len(volumeID) == 0 { @@ -543,11 +413,12 @@ func (ns *nodeServer) NodeExpandVolume(ctx context.Context, req *csi.NodeExpandV } if acquired := ns.volumeLocks.TryAcquire(volumeID); !acquired { - logger.Error(errors.New(util.ErrVolumeOperationAlreadyExistsVolumeID), "failed to acquire volume lock", "volumeID", volumeID) - - return nil, status.Errorf(codes.Aborted, util.VolumeOperationAlreadyExistsFmt, volumeID) + return nil, status.Errorf(codes.Aborted, util.VolumeOperationAlreadyExistsVolumeIDFmt, volumeID) } - defer ns.volumeLocks.Release(volumeID) + defer func() { + logger.V(4).Info("NodeExpandVolume: volume operation finished", "volumeId", volumeID) + ns.volumeLocks.Release(volumeID) + }() _, err := ns.connector.GetVolumeByID(ctx, volumeID) if err != nil { @@ -581,36 +452,29 @@ func (ns *nodeServer) NodeExpandVolume(ctx context.Context, req *csi.NodeExpandV return &csi.NodeExpandVolumeResponse{CapacityBytes: bcap}, nil } -func (ns *nodeServer) NodeGetVolumeStats(ctx context.Context, req *csi.NodeGetVolumeStatsRequest) (*csi.NodeGetVolumeStatsResponse, error) { +func (ns *NodeService) NodeGetVolumeStats(ctx context.Context, req *csi.NodeGetVolumeStatsRequest) (*csi.NodeGetVolumeStatsResponse, error) { logger := klog.FromContext(ctx) - logger.V(6).Info("NodeGetVolumeStats: called", "args", *req) + logger.V(4).Info("NodeGetVolumeStats: called", "args", *req) if req.GetVolumeId() == "" { return nil, status.Error(codes.InvalidArgument, "Volume ID missing in request") } - // Get volume path - // This should work for Kubernetes >= 1.26, see https://github.com/kubernetes/kubernetes/issues/115343 - volumePath := req.GetStagingTargetPath() - if volumePath == "" { - // Except that it doesn't work in the sanity test, so we need a fallback to volumePath. 
- volumePath = req.GetVolumePath() - } - if len(volumePath) == 0 { - return nil, status.Error(codes.InvalidArgument, "Volume path not provided") + if req.GetVolumePath() == "" { + return nil, status.Error(codes.InvalidArgument, "Volume Path missing in request") } - exists, err := ns.mounter.PathExists(volumePath) + exists, err := ns.mounter.PathExists(req.GetVolumePath()) if err != nil { - return nil, status.Errorf(codes.Internal, "unknown error when stat on %s: %v", volumePath, err) + return nil, status.Errorf(codes.Internal, "unknown error when stat on %s: %v", req.GetVolumePath(), err) } if !exists { - return nil, status.Errorf(codes.NotFound, "path %s does not exist", volumePath) + return nil, status.Errorf(codes.NotFound, "path %s does not exist", req.GetVolumePath()) } - isBlock, err := ns.mounter.IsBlockDevice(volumePath) + isBlock, err := ns.mounter.IsBlockDevice(req.GetVolumePath()) if err != nil { - return nil, status.Errorf(codes.Internal, "failed to determine if %q is block device: %s", volumePath, err) + return nil, status.Errorf(codes.Internal, "failed to determine if %q is block device: %s", req.GetVolumePath(), err) } if isBlock { @@ -629,9 +493,9 @@ func (ns *nodeServer) NodeGetVolumeStats(ctx context.Context, req *csi.NodeGetVo }, nil } - stats, err := ns.mounter.GetStatistics(volumePath) + stats, err := ns.mounter.GetStatistics(req.GetVolumePath()) if err != nil { - return nil, status.Errorf(codes.Internal, "failed to retrieve capacity statistics for volume path %q: %s", volumePath, err) + return nil, status.Errorf(codes.Internal, "failed to retrieve capacity statistics for volume path %q: %s", req.GetVolumePath(), err) } return &csi.NodeGetVolumeStatsResponse{ @@ -652,7 +516,7 @@ func (ns *nodeServer) NodeGetVolumeStats(ctx context.Context, req *csi.NodeGetVo }, nil } -func (ns *nodeServer) NodeGetCapabilities(_ context.Context, _ *csi.NodeGetCapabilitiesRequest) (*csi.NodeGetCapabilitiesResponse, error) { +func (ns *NodeService) NodeGetCapabilities(_ context.Context, _ *csi.NodeGetCapabilitiesRequest) (*csi.NodeGetCapabilitiesResponse, error) { resp := &csi.NodeGetCapabilitiesResponse{ Capabilities: []*csi.NodeServiceCapability{ { @@ -681,3 +545,182 @@ func (ns *nodeServer) NodeGetCapabilities(_ context.Context, _ *csi.NodeGetCapab return resp, nil } + +func (ns *NodeService) nodePublishVolumeForBlock(ctx context.Context, req *csi.NodePublishVolumeRequest, mountOptions []string) error { + logger := klog.FromContext(ctx) + target := req.GetTargetPath() + volumeID := req.GetVolumeId() + source, err := ns.mounter.GetDevicePath(ctx, volumeID) + if err != nil { + return status.Errorf(codes.Internal, "Cannot find device path for volume %s: %v", volumeID, err) + } + + globalMountPath := filepath.Dir(target) + exists, err := ns.mounter.PathExists(globalMountPath) + if err != nil { + return status.Errorf(codes.Internal, "Could not check if path exists %q: %v", globalMountPath, err) + } + if !exists { + if err = ns.mounter.MakeDir(globalMountPath); err != nil { + return status.Errorf(codes.Internal, "Could not create dir %q: %v", globalMountPath, err) + } + } + + // Create the mount point as a file since bind mount device node requires it to be a file + logger.V(4).Info("NodePublishVolume: making target file", "target", target) + err = ns.mounter.MakeFile(target) + if err != nil { + if removeErr := os.Remove(target); removeErr != nil { + return status.Errorf(codes.Internal, "Could not remove mount target %q: %v", target, removeErr) + } + + return status.Errorf(codes.Internal, 
"Could not create file %q: %v", target, err) + } + + mounted, err := ns.isMounted(ctx, target) + if err != nil { + return status.Errorf(codes.Internal, "Could not check if %q is mounted: %v", target, err) + } + + if !mounted { + logger.V(4).Info("NodePublishVolume: mounting block device", + "source", source, + "target", target, + ) + + if err := ns.mounter.Mount(source, target, "", mountOptions); err != nil { + if removeErr := os.Remove(target); removeErr != nil { + return status.Errorf(codes.Internal, "Could not remove mount target %q: %v", target, removeErr) + } + + return status.Errorf(codes.Internal, "failed to mount %q at %q: %v", source, target, err) + } + } else { + logger.V(4).Info("NodePublishVolume: target path already mounted", "target", target) + } + + return nil +} + +func (ns *NodeService) nodePublishVolumeForMount(ctx context.Context, req *csi.NodePublishVolumeRequest, mountOptions []string, mode *csi.VolumeCapability_Mount) error { + logger := klog.FromContext(ctx) + target := req.GetTargetPath() + source := req.GetStagingTargetPath() + + if m := mode.Mount; m != nil { + for _, f := range m.GetMountFlags() { + if !hasMountOption(mountOptions, f) { + mountOptions = append(mountOptions, f) + } + } + } + + // Prepare the publish target + logger.V(4).Info("NodePublishVolume: creating dir", "target", target) + if err := ns.mounter.MakeDir(target); err != nil { + return status.Errorf(codes.Internal, "Could not create dir %q: %v", target, err) + } + + mounted, err := ns.isMounted(ctx, target) + if err != nil { + return status.Errorf(codes.Internal, "Could not check if %q is mounted: %v", target, err) + } + + if !mounted { + fsType := mode.Mount.GetFsType() + if fsType == "" { + fsType = defaultFsType + } + + _, ok := ValidFSTypes[strings.ToLower(fsType)] + if !ok { + return status.Errorf(codes.InvalidArgument, "NodePublishVolume: invalid fstype %s", fsType) + } + + mountOptions = collectMountOptions(fsType, mountOptions) + + logger.V(4).Info("NodePublishVolume: mounting source", + "source", source, + "target", target, + "fsType", fsType, + "mountOptions", mountOptions, + ) + + if err := ns.mounter.Mount(source, target, fsType, mountOptions); err != nil { + return status.Errorf(codes.Internal, "Failed to mount %q at %q: %v", source, target, err) + } + } + + return nil +} + +func (ns *NodeService) isMounted(ctx context.Context, target string) (bool, error) { + logger := klog.FromContext(ctx) + + notMnt, err := ns.mounter.IsLikelyNotMountPoint(target) + if err != nil && !os.IsNotExist(err) { + // Checking if the path exists and error is related to Corrupted Mount, in that case, the system could unmount and mount. + _, pathErr := ns.mounter.PathExists(target) + if pathErr != nil && ns.mounter.IsCorruptedMnt(pathErr) { + logger.V(4).Info("NodePublishVolume: Target path is a corrupted mount. Trying to unmount.", "target", target) + if mntErr := ns.mounter.Unpublish(target); mntErr != nil { + return false, fmt.Errorf("unable to unmount the target %q : %w", target, mntErr) + } + + // After successful unmount, the device is ready to be mounted. + return false, nil + } + + return false, fmt.Errorf("could not check if %q is a mount point: %w, %w", target, err, pathErr) + } + + // Do not return os.IsNotExist error. Other errors were handled above. The + // existence of the target should be checked by the caller explicitly and + // independently because sometimes prior to mount it is expected not to exist. 
+ if err != nil && os.IsNotExist(err) {
+ logger.V(5).Info("[Debug] NodePublishVolume: Target path does not exist", "target", target)
+
+ return false, nil
+ }
+
+ if !notMnt {
+ logger.V(4).Info("NodePublishVolume: Target path is already mounted", "target", target)
+ }
+
+ return !notMnt, nil
+}
+
+// hasMountOption returns a boolean indicating whether the given
+// slice already contains a mount option. This is used to prevent
+// passing duplicate options to the mount command.
+func hasMountOption(options []string, opt string) bool {
+ for _, o := range options {
+ if o == opt {
+ return true
+ }
+ }
+
+ return false
+}
+
+// collectMountOptions returns an array of mount options from
+// VolumeCapability_MountVolume and special mount options for
+// the given filesystem.
+func collectMountOptions(fsType string, mntFlags []string) []string {
+ var options []string
+ for _, opt := range mntFlags {
+ if !hasMountOption(options, opt) {
+ options = append(options, opt)
+ }
+ }
+
+ // By default, xfs does not allow mounting of two volumes with the same filesystem uuid.
+ // Force ignore this uuid to be able to mount volume + its clone / restored snapshot on the same node.
+ if fsType == FSTypeXfs {
+ if !hasMountOption(options, "nouuid") {
+ options = append(options, "nouuid")
+ }
+ }
+
+ return options
+}
diff --git a/pkg/driver/node_test.go b/pkg/driver/node_test.go
new file mode 100644
index 0000000..9ec603c
--- /dev/null
+++ b/pkg/driver/node_test.go
@@ -0,0 +1,89 @@
+package driver
+
+import (
+ "context"
+ "os"
+ "path/filepath"
+ "testing"
+
+ "github.com/container-storage-interface/spec/lib/go/csi"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "k8s.io/klog/v2"
+ "k8s.io/klog/v2/ktesting"
+
+ cloud "github.com/leaseweb/cloudstack-csi-driver/pkg/cloud/fake"
+ "github.com/leaseweb/cloudstack-csi-driver/pkg/mount"
+ "github.com/leaseweb/cloudstack-csi-driver/pkg/util"
+)
+
+const (
+ //nolint:godox
+ // TODO: Adjusted this to paths in /tmp until https://github.com/kubernetes/kubernetes/pull/128286
+ // is solved.
+ sourceTest = "/tmp/source_test"
+ targetTest = "/tmp/target_test"
+)
+
+func TestNodePublishVolumeIdempotentMount(t *testing.T) {
+ if os.Getuid() != 0 {
+ t.Skip("Test requires root")
+ }
+ logger := ktesting.NewLogger(t, ktesting.NewConfig(ktesting.Verbosity(10), ktesting.BufferLogs(true)))
+ ctx := klog.NewContext(context.Background(), logger)
+
+ driver := &NodeService{
+ connector: cloud.New(),
+ mounter: mount.New(),
+ volumeLocks: util.NewVolumeLocks(),
+ }
+
+ err := driver.mounter.MakeDir(sourceTest)
+ require.NoError(t, err)
+ err = driver.mounter.MakeDir(targetTest)
+ require.NoError(t, err)
+
+ volCapAccessMode := csi.VolumeCapability_AccessMode{Mode: csi.VolumeCapability_AccessMode_SINGLE_NODE_WRITER}
+ volCapAccessType := csi.VolumeCapability_Mount{Mount: &csi.VolumeCapability_MountVolume{}}
+ req := csi.NodePublishVolumeRequest{
+ VolumeCapability: &csi.VolumeCapability{AccessMode: &volCapAccessMode, AccessType: &volCapAccessType},
+ VolumeId: "vol_1",
+ TargetPath: targetTest,
+ StagingTargetPath: sourceTest,
+ Readonly: true,
+ }
+
+ underlyingLogger, ok := logger.GetSink().(ktesting.Underlier)
+ if !ok {
+ t.Fatalf("should have had ktesting LogSink, got %T", logger.GetSink())
+ }
+
+ _, err = driver.NodePublishVolume(ctx, &req)
+ require.NoError(t, err)
+ _, err = driver.NodePublishVolume(ctx, &req)
+ require.NoError(t, err)
+
+ logEntries := underlyingLogger.GetBuffer().String()
+ assert.Contains(t, logEntries, "Target path is already mounted")
+
+ // Ensure the target is not mounted twice.
+ targetAbs, err := filepath.Abs(targetTest)
+ require.NoError(t, err)
+
+ mountList, err := driver.mounter.List()
+ require.NoError(t, err)
+ mountPointNum := 0
+ for _, mountPoint := range mountList {
+ if mountPoint.Path == targetAbs {
+ mountPointNum++
+ }
+ }
+ assert.Equal(t, 1, mountPointNum)
+ err = driver.mounter.Unmount(targetTest)
+ require.NoError(t, err)
+ _ = driver.mounter.Unmount(targetTest)
+ err = os.RemoveAll(sourceTest)
+ require.NoError(t, err)
+ err = os.RemoveAll(targetTest)
+ require.NoError(t, err)
+}
diff --git a/pkg/mount/fake.go b/pkg/mount/fake.go
index 8129033..8253756 100644
--- a/pkg/mount/fake.go
+++ b/pkg/mount/fake.go
@@ -17,8 +17,8 @@ type fakeMounter struct {
 }
 // NewFake creates a fake implementation of the
-// mount.Interface, to be used in tests.
+// mount.Mounter, to be used in tests.
+func NewFake() Mounter { return &fakeMounter{ mount.SafeFormatAndMount{ Interface: mount.NewFakeMounter([]mount.MountPoint{}), @@ -35,7 +35,7 @@ func (m *fakeMounter) GetDevicePath(_ context.Context, _ string) (string, error) return "/dev/sdb", nil } -func (m *fakeMounter) GetDeviceName(mountPath string) (string, int, error) { +func (m *fakeMounter) GetDeviceNameFromMount(mountPath string) (string, int, error) { return mount.GetDeviceNameFromMount(m, mountPath) } @@ -49,8 +49,8 @@ func (*fakeMounter) PathExists(path string) (bool, error) { return true, nil } -func (*fakeMounter) MakeDir(pathname string) error { - err := os.MkdirAll(pathname, os.FileMode(0o755)) +func (*fakeMounter) MakeDir(path string) error { + err := os.MkdirAll(path, os.FileMode(0o755)) if err != nil { if !os.IsExist(err) { return err @@ -60,8 +60,8 @@ func (*fakeMounter) MakeDir(pathname string) error { return nil } -func (*fakeMounter) MakeFile(pathname string) error { - file, err := os.OpenFile(pathname, os.O_CREATE, os.FileMode(0o644)) +func (*fakeMounter) MakeFile(path string) error { + file, err := os.OpenFile(path, os.O_CREATE, os.FileMode(0o644)) if err != nil { if !os.IsExist(err) { return err @@ -74,8 +74,8 @@ func (*fakeMounter) MakeFile(pathname string) error { return nil } -func (m *fakeMounter) GetStatistics(_ string) (volumeStatistics, error) { - return volumeStatistics{ +func (m *fakeMounter) GetStatistics(_ string) (VolumeStatistics, error) { + return VolumeStatistics{ AvailableBytes: 3 * giB, TotalBytes: 10 * giB, UsedBytes: 7 * giB, diff --git a/pkg/mount/mount.go b/pkg/mount/mount.go index 6e45aaf..e507163 100644 --- a/pkg/mount/mount.go +++ b/pkg/mount/mount.go @@ -23,20 +23,20 @@ const ( diskIDPath = "/dev/disk/by-id" ) -// Interface defines the set of methods to allow for +// Mounter defines the set of methods to allow for // mount operations on a system. -type Interface interface { //nolint:interfacebloat +type Mounter interface { //nolint:interfacebloat mount.Interface FormatAndMount(source string, target string, fstype string, options []string) error GetBlockSizeBytes(devicePath string) (int64, error) GetDevicePath(ctx context.Context, volumeID string) (string, error) - GetDeviceName(mountPath string) (string, int, error) - GetStatistics(volumePath string) (volumeStatistics, error) + GetDeviceNameFromMount(mountPath string) (string, int, error) + GetStatistics(volumePath string) (VolumeStatistics, error) IsBlockDevice(devicePath string) (bool, error) IsCorruptedMnt(err error) bool - MakeDir(pathname string) error - MakeFile(pathname string) error + MakeDir(path string) error + MakeFile(path string) error NeedResize(devicePath string, deviceMountPath string) (bool, error) PathExists(path string) (bool, error) Resize(devicePath, deviceMountPath string) (bool, error) @@ -44,18 +44,20 @@ type Interface interface { //nolint:interfacebloat Unstage(path string) error } -type mounter struct { +// NodeMounter implements Mounter. +// A superstruct of SafeFormatAndMount. +type NodeMounter struct { *mount.SafeFormatAndMount } -type volumeStatistics struct { +type VolumeStatistics struct { AvailableBytes, TotalBytes, UsedBytes int64 AvailableInodes, TotalInodes, UsedInodes int64 } -// New creates an implementation of the mount.Interface. -func New() Interface { - return &mounter{ +// New creates an implementation of the mount.Mounter. 
+func New() Mounter { + return &NodeMounter{ &mount.SafeFormatAndMount{ Interface: mount.New(""), Exec: kexec.New(), @@ -64,7 +66,7 @@ func New() Interface { } // GetBlockSizeBytes gets the size of the disk in bytes. -func (m *mounter) GetBlockSizeBytes(devicePath string) (int64, error) { +func (m *NodeMounter) GetBlockSizeBytes(devicePath string) (int64, error) { output, err := m.Exec.Command("blockdev", "--getsize64", devicePath).Output() if err != nil { return -1, fmt.Errorf("error when getting size of block volume at path %s: output: %s, err: %w", devicePath, string(output), err) @@ -78,7 +80,7 @@ func (m *mounter) GetBlockSizeBytes(devicePath string) (int64, error) { return gotSizeBytes, nil } -func (m *mounter) GetDevicePath(ctx context.Context, volumeID string) (string, error) { +func (m *NodeMounter) GetDevicePath(ctx context.Context, volumeID string) (string, error) { backoff := wait.Backoff{ Duration: 1 * time.Second, Factor: 1.1, @@ -110,7 +112,7 @@ func (m *mounter) GetDevicePath(ctx context.Context, volumeID string) (string, e return devicePath, nil } -func (m *mounter) getDevicePathBySerialID(volumeID string) (string, error) { +func (m *NodeMounter) getDevicePathBySerialID(volumeID string) (string, error) { sourcePathPrefixes := []string{"virtio-", "scsi-", "scsi-0QEMU_QEMU_HARDDISK_"} serial := diskUUIDToSerial(volumeID) for _, prefix := range sourcePathPrefixes { @@ -127,7 +129,7 @@ func (m *mounter) getDevicePathBySerialID(volumeID string) (string, error) { return "", nil } -func (m *mounter) probeVolume(ctx context.Context) { +func (m *NodeMounter) probeVolume(ctx context.Context) { logger := klog.FromContext(ctx) logger.V(2).Info("Scanning SCSI host") @@ -153,7 +155,7 @@ func (m *mounter) probeVolume(ctx context.Context) { } } -func (m *mounter) GetDeviceName(mountPath string) (string, int, error) { +func (m *NodeMounter) GetDeviceNameFromMount(mountPath string) (string, int, error) { return mount.GetDeviceNameFromMount(m, mountPath) } @@ -171,12 +173,12 @@ func diskUUIDToSerial(uuid string) string { return uuidWithoutHyphen[:20] } -func (*mounter) PathExists(path string) (bool, error) { +func (*NodeMounter) PathExists(path string) (bool, error) { return mount.PathExists(path) } -func (*mounter) MakeDir(pathname string) error { - err := os.MkdirAll(pathname, os.FileMode(0o755)) +func (*NodeMounter) MakeDir(path string) error { + err := os.MkdirAll(path, os.FileMode(0o755)) if err != nil { if !os.IsExist(err) { return err @@ -186,8 +188,8 @@ func (*mounter) MakeDir(pathname string) error { return nil } -func (*mounter) MakeFile(pathname string) error { - f, err := os.OpenFile(pathname, os.O_CREATE, os.FileMode(0o644)) +func (*NodeMounter) MakeFile(path string) error { + f, err := os.OpenFile(path, os.O_CREATE, os.FileMode(0o644)) if err != nil { if !os.IsExist(err) { return err @@ -201,35 +203,35 @@ func (*mounter) MakeFile(pathname string) error { } // Resize resizes the filesystem of the given devicePath. -func (m *mounter) Resize(devicePath, deviceMountPath string) (bool, error) { +func (m *NodeMounter) Resize(devicePath, deviceMountPath string) (bool, error) { return mount.NewResizeFs(m.Exec).Resize(devicePath, deviceMountPath) } // NeedResize checks if the filesystem of the given devicePath needs to be resized. 
-func (m *mounter) NeedResize(devicePath string, deviceMountPath string) (bool, error) { +func (m *NodeMounter) NeedResize(devicePath string, deviceMountPath string) (bool, error) { return mount.NewResizeFs(m.Exec).NeedResize(devicePath, deviceMountPath) } // GetStatistics gathers statistics on the volume. -func (m *mounter) GetStatistics(volumePath string) (volumeStatistics, error) { +func (m *NodeMounter) GetStatistics(volumePath string) (VolumeStatistics, error) { isBlock, err := m.IsBlockDevice(volumePath) if err != nil { - return volumeStatistics{}, fmt.Errorf("failed to determine if volume %s is block device: %w", volumePath, err) + return VolumeStatistics{}, fmt.Errorf("failed to determine if volume %s is block device: %w", volumePath, err) } if isBlock { // See http://man7.org/linux/man-pages/man8/blockdev.8.html for details output, err := exec.Command("blockdev", "getsize64", volumePath).CombinedOutput() if err != nil { - return volumeStatistics{}, fmt.Errorf("error when getting size of block volume at path %s: output: %s, err: %w", volumePath, string(output), err) + return VolumeStatistics{}, fmt.Errorf("error when getting size of block volume at path %s: output: %s, err: %w", volumePath, string(output), err) } strOut := strings.TrimSpace(string(output)) gotSizeBytes, err := strconv.ParseInt(strOut, 10, 64) if err != nil { - return volumeStatistics{}, fmt.Errorf("failed to parse size %s into int", strOut) + return VolumeStatistics{}, fmt.Errorf("failed to parse size %s into int", strOut) } - return volumeStatistics{ + return VolumeStatistics{ TotalBytes: gotSizeBytes, }, nil } @@ -238,10 +240,10 @@ func (m *mounter) GetStatistics(volumePath string) (volumeStatistics, error) { // See http://man7.org/linux/man-pages/man2/statfs.2.html for details. err = unix.Statfs(volumePath, &statfs) if err != nil { - return volumeStatistics{}, err + return VolumeStatistics{}, err } - volStats := volumeStatistics{ + volStats := VolumeStatistics{ AvailableBytes: int64(statfs.Bavail) * int64(statfs.Bsize), //nolint:unconvert TotalBytes: int64(statfs.Blocks) * int64(statfs.Bsize), //nolint:unconvert UsedBytes: (int64(statfs.Blocks) - int64(statfs.Bfree)) * int64(statfs.Bsize), //nolint:unconvert @@ -255,7 +257,7 @@ func (m *mounter) GetStatistics(volumePath string) (volumeStatistics, error) { } // IsBlockDevice checks if the given path is a block device. -func (m *mounter) IsBlockDevice(devicePath string) (bool, error) { +func (m *NodeMounter) IsBlockDevice(devicePath string) (bool, error) { var stat unix.Stat_t err := unix.Stat(devicePath, &stat) if err != nil { @@ -266,16 +268,16 @@ func (m *mounter) IsBlockDevice(devicePath string) (bool, error) { } // IsCorruptedMnt return true if err is about corrupted mount point. -func (m *mounter) IsCorruptedMnt(err error) bool { +func (m *NodeMounter) IsCorruptedMnt(err error) bool { return mount.IsCorruptedMnt(err) } // Unpublish unmounts the given path. -func (m *mounter) Unpublish(path string) error { +func (m *NodeMounter) Unpublish(path string) error { return m.Unstage(path) } // Unstage unmounts the given path. -func (m *mounter) Unstage(path string) error { +func (m *NodeMounter) Unstage(path string) error { return mount.CleanupMountPoint(path, m, true) } diff --git a/pkg/util/idlocker.go b/pkg/util/idlocker.go index 619a3ae..f5325d5 100644 --- a/pkg/util/idlocker.go +++ b/pkg/util/idlocker.go @@ -22,17 +22,17 @@ import ( ) const ( - // ErrVolumeOperationAlreadyExistsVolumeID is the error msg logged for concurrent operation. 
- ErrVolumeOperationAlreadyExistsVolumeID = "an operation with the given Volume ID already exists" + // VolumeOperationAlreadyExistsVolumeIDFmt string format to return for concurrent operation. + VolumeOperationAlreadyExistsVolumeIDFmt = "an operation with the given volume ID %s already exists" - // ErrVolumeOperationAlreadyExistsVolumeName is the error msg logged for concurrent operation. - ErrVolumeOperationAlreadyExistsVolumeName = "an operation with the given Volume name already exists" - - // VolumeOperationAlreadyExistsFmt string format to return for concurrent operation. - VolumeOperationAlreadyExistsFmt = "an operation with the given Volume ID %s already exists" + // VolumeOperationAlreadyExistsVolumeNameFmt string format to return for concurrent operation. + VolumeOperationAlreadyExistsVolumeNameFmt = "an operation with the given volume name %s already exists" // SnapshotOperationAlreadyExistsFmt string format to return for concurrent operation. - SnapshotOperationAlreadyExistsFmt = "an operation with the given Snapshot ID %s already exists" + SnapshotOperationAlreadyExistsFmt = "an operation with the given snapshot ID %s already exists" + + // TargetPathOperationAlreadyExistsFmt string format to return for concurrent operation on target path. + TargetPathOperationAlreadyExistsFmt = "an operation with the given target path %s already exists" ) // VolumeLocks implements a map with atomic operations. It stores a set of all volume IDs diff --git a/test/sanity/sanity_test.go b/test/sanity/sanity_test.go index 2a6780b..fbf389c 100644 --- a/test/sanity/sanity_test.go +++ b/test/sanity/sanity_test.go @@ -42,11 +42,12 @@ func TestSanity(t *testing.T) { ctx := klog.NewContext(context.Background(), logger) options := driver.Options{ - Mode: driver.AllMode, - Endpoint: endpoint, - NodeName: "node", + Mode: driver.AllMode, + Endpoint: endpoint, + NodeName: "node", + VolumeAttachLimit: 16, } - csiDriver, err := driver.New(ctx, fake.New(), &options, mount.NewFake()) + csiDriver, err := driver.NewDriver(ctx, fake.New(), &options, mount.NewFake()) if err != nil { t.Fatalf("error creating driver: %v", err) }
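The mount-option helpers added to pkg/driver/node.go are small enough to illustrate in isolation. Below is a minimal sketch of a test exercising them; it assumes the FSTypeXfs constant resolves to "xfs" and is illustrative only, not part of this change:

package driver

import (
	"testing"

	"github.com/stretchr/testify/assert"
)

// Sketch: collectMountOptions de-duplicates flags and appends "nouuid" for XFS
// so a volume and its clone or restored snapshot can be mounted on the same node.
// Assumes FSTypeXfs == "xfs".
func TestCollectMountOptionsSketch(t *testing.T) {
	// Duplicate flags collapse to a single entry.
	opts := collectMountOptions("ext4", []string{"noatime", "noatime"})
	assert.Equal(t, []string{"noatime"}, opts)

	// XFS gets the extra "nouuid" option appended.
	opts = collectMountOptions(FSTypeXfs, []string{"noatime"})
	assert.True(t, hasMountOption(opts, "nouuid"))
}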
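The renamed idlocker format strings pair with a consistent acquire/defer-release pattern in the node handlers. The sketch below restates that idiom as a standalone helper; withVolumeLock is hypothetical, and it assumes util.NewVolumeLocks returns a *util.VolumeLocks as wired up in node_test.go:

package driver

import (
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"

	"github.com/leaseweb/cloudstack-csi-driver/pkg/util"
)

// withVolumeLock is an illustrative helper (not part of this diff) showing the
// serialization pattern used by NodeExpandVolume: reject concurrent calls for
// the same volume with codes.Aborted and release the lock when the operation
// finishes. NodeUnpublishVolume applies the same pattern keyed on the target
// path, using TargetPathOperationAlreadyExistsFmt instead.
func withVolumeLock(locks *util.VolumeLocks, volumeID string, operation func() error) error {
	if acquired := locks.TryAcquire(volumeID); !acquired {
		return status.Errorf(codes.Aborted, util.VolumeOperationAlreadyExistsVolumeIDFmt, volumeID)
	}
	defer locks.Release(volumeID)

	return operation()
}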