Skip to content

Commit

Permalink
Increasing UT coverage (#346)
Browse files Browse the repository at this point in the history
Co-authored-by: karthikk92 <[email protected]>
  • Loading branch information
KshitijaKakde and karthikk92 authored Feb 25, 2025
1 parent 2ffa8a3 commit 1bc83a3
Show file tree
Hide file tree
Showing 36 changed files with 7,627 additions and 244 deletions.
26 changes: 18 additions & 8 deletions common/k8sutils/k8sutils.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
package k8sutils

/*
Copyright (c) 2020-2022 Dell Inc, or its subsidiaries.
Copyright (c) 2020-2025 Dell Inc, or its subsidiaries.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand All @@ -16,6 +14,8 @@ package k8sutils
limitations under the License.
*/

package k8sutils

import (
"context"
"fmt"
Expand All @@ -31,6 +31,16 @@ import (
"k8s.io/client-go/tools/clientcmd"
)

var (
buildConfigFromFlags = clientcmd.BuildConfigFromFlags
newForConfig = kubernetes.NewForConfig
inClusterConfig = rest.InClusterConfig
)

var fsInfo = func(ctx context.Context, path string) (int64, int64, int64, int64, int64, int64, error) {
return gofsutil.FsInfo(ctx, path)
}

type leaderElection interface {
Run() error
WithNamespace(namespace string)
Expand All @@ -41,22 +51,22 @@ func CreateKubeClientSet(kubeconfig string) (*kubernetes.Clientset, error) {
var clientset *kubernetes.Clientset
if kubeconfig != "" {
// use the current context in kubeconfig
config, err := clientcmd.BuildConfigFromFlags("", kubeconfig)
config, err := buildConfigFromFlags("", kubeconfig)
if err != nil {
return nil, err
}
// create the clientset
clientset, err = kubernetes.NewForConfig(config)
clientset, err = newForConfig(config)
if err != nil {
return nil, err
}
} else {
config, err := rest.InClusterConfig()
config, err := inClusterConfig()
if err != nil {
return nil, err
}
// creates the clientset
clientset, err = kubernetes.NewForConfig(config)
clientset, err = newForConfig(config)
if err != nil {
return nil, err
}
Expand All @@ -81,7 +91,7 @@ func LeaderElection(clientset *kubernetes.Clientset, lockName string, namespace

// GetStats - Returns the stats for the volume mounted on given volume path
func GetStats(ctx context.Context, volumePath string) (int64, int64, int64, int64, int64, int64, error) {
availableBytes, totalBytes, usedBytes, totalInodes, freeInodes, usedInodes, err := gofsutil.FsInfo(ctx, volumePath)
availableBytes, totalBytes, usedBytes, totalInodes, freeInodes, usedInodes, err := fsInfo(ctx, volumePath)
if err != nil {
return 0, 0, 0, 0, 0, 0, status.Error(codes.Internal, fmt.Sprintf(
"failed to get volume stats: %s", err))
Expand Down
236 changes: 236 additions & 0 deletions common/k8sutils/k8sutils_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,236 @@
/*
Copyright (c) 2025 Dell Inc, or its subsidiaries.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package k8sutils

import (
"context"
"errors"
"fmt"
"os"
"testing"
"time"

"github.com/stretchr/testify/mock"
"k8s.io/client-go/kubernetes"
"k8s.io/client-go/rest"
)

var exitFunc = os.Exit

type MockLeaderElection struct {
mock.Mock
}

func (m *MockLeaderElection) Run() error {
args := m.Called()
return args.Error(0)
}

func TestCreateKubeClientSet(t *testing.T) {
// Test cases
tests := []struct {
name string
kubeconfig string
configErr error
clientErr error
wantErr bool
}{
{
name: "Valid kubeconfig",
kubeconfig: "valid_kubeconfig",
configErr: nil,
clientErr: nil,
wantErr: false,
},
{
name: "Invalid kubeconfig",
kubeconfig: "invalid_kubeconfig",
configErr: errors.New("config error"),
clientErr: nil,
wantErr: true,
},
{
name: "In-cluster config",
kubeconfig: "",
configErr: nil,
clientErr: nil,
wantErr: false,
},
{
name: "In-cluster config error",
kubeconfig: "",
configErr: errors.New("config error"),
clientErr: nil,
wantErr: true,
},
{
name: "New for config error",
kubeconfig: "",
configErr: nil,
clientErr: errors.New("client error"),
wantErr: true,
},
{
name: "New for config error",
kubeconfig: "invalid_kubeconfig",
configErr: nil,
clientErr: errors.New("client error"),
wantErr: true,
},
}

// Save original functions
origBuildConfigFromFlags := buildConfigFromFlags
origInClusterConfig := inClusterConfig
origNewForConfig := newForConfig

// Restore original functions after tests
defer func() {
buildConfigFromFlags = origBuildConfigFromFlags
inClusterConfig = origInClusterConfig
newForConfig = origNewForConfig
}()

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Mock functions
buildConfigFromFlags = func(_, _ string) (*rest.Config, error) {
return &rest.Config{}, tt.configErr
}
inClusterConfig = func() (*rest.Config, error) {
return &rest.Config{}, tt.configErr
}
newForConfig = func(_ *rest.Config) (*kubernetes.Clientset, error) {
if tt.clientErr != nil {
return nil, tt.clientErr
}
return &kubernetes.Clientset{}, nil
}

clientset, err := CreateKubeClientSet(tt.kubeconfig)
if (err != nil) != tt.wantErr {
t.Errorf("CreateKubeClientSet() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !tt.wantErr && clientset == nil {
t.Errorf("CreateKubeClientSet() = nil, want non-nil")
}
})
}
}

func TestGetStats(t *testing.T) {
// Set up the necessary dependencies
ctx := context.Background()
volumePath := "/path/to/volume"
availableBytes, totalBytes, usedBytes, totalInodes, freeInodes, usedInodes, _ := GetStats(ctx, volumePath)

expectedAvailableBytes := int64(0)
expectedTotalBytes := int64(0)
expectedUsedBytes := int64(0)
expectedTotalInodes := int64(0)
expectedFreeInodes := int64(0)
expectedUsedInodes := int64(0)
if availableBytes != expectedAvailableBytes {
t.Errorf("Expected availableBytes to be %d, but got %d", expectedAvailableBytes, availableBytes)
}
if totalBytes != expectedTotalBytes {
t.Errorf("Expected totalBytes to be %d, but got %d", expectedTotalBytes, totalBytes)
}
if usedBytes != expectedUsedBytes {
t.Errorf("Expected usedBytes to be %d, but got %d", expectedUsedBytes, usedBytes)
}
if totalInodes != expectedTotalInodes {
t.Errorf("Expected totalInodes to be %d, but got %d", expectedTotalInodes, totalInodes)
}
if freeInodes != expectedFreeInodes {
t.Errorf("Expected freeInodes to be %d, but got %d", expectedFreeInodes, freeInodes)
}
if usedInodes != expectedUsedInodes {
t.Errorf("Expected usedInodes to be %d, but got %d", expectedUsedInodes, usedInodes)
}
}

func TestGetStatsNoError(t *testing.T) {
// Set up the necessary dependencies
defaultFsInfo := fsInfo
fsInfo = func(_ context.Context, _ string) (int64, int64, int64, int64, int64, int64, error) {
return 1, 1, 1, 1, 1, 1, nil
}
defer func() {
fsInfo = defaultFsInfo
}()

ctx := context.Background()
volumePath := "/path/to/volume"
availableBytes, totalBytes, usedBytes, totalInodes, freeInodes, usedInodes, _ := GetStats(ctx, volumePath)

expectedAvailableBytes := int64(1)
expectedTotalBytes := int64(1)
expectedUsedBytes := int64(1)
expectedTotalInodes := int64(1)
expectedFreeInodes := int64(1)
expectedUsedInodes := int64(1)
if availableBytes != expectedAvailableBytes {
t.Errorf("Expected availableBytes to be %d, but got %d", expectedAvailableBytes, availableBytes)
}
if totalBytes != expectedTotalBytes {
t.Errorf("Expected totalBytes to be %d, but got %d", expectedTotalBytes, totalBytes)
}
if usedBytes != expectedUsedBytes {
t.Errorf("Expected usedBytes to be %d, but got %d", expectedUsedBytes, usedBytes)
}
if totalInodes != expectedTotalInodes {
t.Errorf("Expected totalInodes to be %d, but got %d", expectedTotalInodes, totalInodes)
}
if freeInodes != expectedFreeInodes {
t.Errorf("Expected freeInodes to be %d, but got %d", expectedFreeInodes, freeInodes)
}
if usedInodes != expectedUsedInodes {
t.Errorf("Expected usedInodes to be %d, but got %d", expectedUsedInodes, usedInodes)
}
}

func TestLeaderElection(t *testing.T) {
clientset := &kubernetes.Clientset{} // Mock or use a fake clientset if needed
runFunc := func(_ context.Context) {
fmt.Println("Running leader function")
}

mockLE := new(MockLeaderElection)
mockLE.On("Run").Return(nil) // Mocking a successful run

// Override exitFunc to prevent test from exiting
exitCalled := false
oldExit := exitFunc
defer func() { recover(); exitFunc = oldExit }()
exitFunc = func(_ int) { exitCalled = true; panic("exitFunc called") }

// Simulate LeaderElection function
func() {
defer func() {
if r := recover(); r != nil {
exitCalled = true
}
}()
LeaderElection(clientset, "test-lock", "test-namespace", time.Second, time.Second*2, time.Second*3, runFunc)
}()

if !exitCalled {
t.Errorf("exitFunc was called unexpectedly")
}
}
10 changes: 6 additions & 4 deletions common/utils/logging.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
package utils

/*
Copyright (c) 2021-2022 Dell Inc, or its subsidiaries.
Copyright (c) 2021-2025 Dell Inc, or its subsidiaries.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand All @@ -16,6 +14,8 @@ package utils
limitations under the License.
*/

package utils

import (
"context"
"fmt"
Expand Down Expand Up @@ -170,7 +170,9 @@ func ParseLogLevel(lvl string) (logrus.Level, error) {
}

// UpdateLogLevel updates the log level
func UpdateLogLevel(lvl logrus.Level) {
func UpdateLogLevel(lvl logrus.Level, mu *sync.Mutex) {
mu.Lock()
defer mu.Unlock()
singletonLog.Level = lvl
}

Expand Down
8 changes: 4 additions & 4 deletions common/utils/utils.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
package utils

/*
Copyright (c) 2019-2022 Dell Inc, or its subsidiaries.
Copyright (c) 2019-2025 Dell Inc, or its subsidiaries.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand All @@ -16,6 +14,8 @@ package utils
limitations under the License.
*/

package utils

import (
"context"
"errors"
Expand Down Expand Up @@ -94,7 +94,7 @@ func ParseInt64FromContext(ctx context.Context, key string) (int64, error) {
}

// RemoveExistingCSISockFile When the sock file that the gRPC server is going to be listening on already exists, error will be thrown saying the address is already in use, thus remove it first
func RemoveExistingCSISockFile() error {
var RemoveExistingCSISockFile = func() error {
log := GetLogger()
protoAddr := os.Getenv(constants.EnvCSIEndpoint)

Expand Down
Loading

0 comments on commit 1bc83a3

Please sign in to comment.