Skip to content

Commit

Permalink
add Controller and Node instance as an interface field in the Driver …
Browse files Browse the repository at this point in the history
…struct
  • Loading branch information
dhij committed Sep 11, 2022
1 parent 8b884e9 commit f009c37
Show file tree
Hide file tree
Showing 3 changed files with 139 additions and 96 deletions.
56 changes: 39 additions & 17 deletions driver/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"net/http"
"strconv"
"strings"
"sync"
"time"

"github.com/container-storage-interface/spec/lib/go/csi"
Expand Down Expand Up @@ -78,9 +79,30 @@ var (
}
)

type Controller struct {
// publishInfoVolumeName is used to pass the volume name from
// `ControllerPublishVolume` to `NodeStageVolume or `NodePublishVolume`
publishInfoVolumeName string
region string
doTag string
defaultVolumesPageSize uint

storage godo.StorageService
storageActions godo.StorageActionsService
droplets godo.DropletsService
snapshots godo.SnapshotsService
account godo.AccountService
tags godo.TagsService

healthChecker *HealthChecker
log *logrus.Entry

readyMu sync.Mutex
}

// CreateVolume creates a new volume from the given request. The function is
// idempotent.
func (d *Driver) CreateVolume(ctx context.Context, req *csi.CreateVolumeRequest) (*csi.CreateVolumeResponse, error) {
func (d *Controller) CreateVolume(ctx context.Context, req *csi.CreateVolumeRequest) (*csi.CreateVolumeResponse, error) {
if req.Name == "" {
return nil, status.Error(codes.InvalidArgument, "CreateVolume Name must be provided")
}
Expand Down Expand Up @@ -230,7 +252,7 @@ func (d *Driver) CreateVolume(ctx context.Context, req *csi.CreateVolumeRequest)
}

// DeleteVolume deletes the given volume. The function is idempotent.
func (d *Driver) DeleteVolume(ctx context.Context, req *csi.DeleteVolumeRequest) (*csi.DeleteVolumeResponse, error) {
func (d *Controller) DeleteVolume(ctx context.Context, req *csi.DeleteVolumeRequest) (*csi.DeleteVolumeResponse, error) {
if req.VolumeId == "" {
return nil, status.Error(codes.InvalidArgument, "DeleteVolume Volume ID must be provided")
}
Expand Down Expand Up @@ -259,7 +281,7 @@ func (d *Driver) DeleteVolume(ctx context.Context, req *csi.DeleteVolumeRequest)
}

// ControllerPublishVolume attaches the given volume to the node
func (d *Driver) ControllerPublishVolume(ctx context.Context, req *csi.ControllerPublishVolumeRequest) (*csi.ControllerPublishVolumeResponse, error) {
func (d *Controller) ControllerPublishVolume(ctx context.Context, req *csi.ControllerPublishVolumeRequest) (*csi.ControllerPublishVolumeResponse, error) {
if req.VolumeId == "" {
return nil, status.Error(codes.InvalidArgument, "ControllerPublishVolume Volume ID must be provided")
}
Expand Down Expand Up @@ -389,7 +411,7 @@ func (d *Driver) ControllerPublishVolume(ctx context.Context, req *csi.Controlle
}

// ControllerUnpublishVolume deattaches the given volume from the node
func (d *Driver) ControllerUnpublishVolume(ctx context.Context, req *csi.ControllerUnpublishVolumeRequest) (*csi.ControllerUnpublishVolumeResponse, error) {
func (d *Controller) ControllerUnpublishVolume(ctx context.Context, req *csi.ControllerUnpublishVolumeRequest) (*csi.ControllerUnpublishVolumeResponse, error) {
if req.VolumeId == "" {
return nil, status.Error(codes.InvalidArgument, "ControllerUnpublishVolume Volume ID must be provided")
}
Expand Down Expand Up @@ -475,7 +497,7 @@ func (d *Driver) ControllerUnpublishVolume(ctx context.Context, req *csi.Control

// ValidateVolumeCapabilities checks whether the volume capabilities requested
// are supported.
func (d *Driver) ValidateVolumeCapabilities(ctx context.Context, req *csi.ValidateVolumeCapabilitiesRequest) (*csi.ValidateVolumeCapabilitiesResponse, error) {
func (d *Controller) ValidateVolumeCapabilities(ctx context.Context, req *csi.ValidateVolumeCapabilitiesRequest) (*csi.ValidateVolumeCapabilitiesResponse, error) {
if req.VolumeId == "" {
return nil, status.Error(codes.InvalidArgument, "ValidateVolumeCapabilities Volume ID must be provided")
}
Expand Down Expand Up @@ -517,7 +539,7 @@ func (d *Driver) ValidateVolumeCapabilities(ctx context.Context, req *csi.Valida
}

// ListVolumes returns a list of all requested volumes
func (d *Driver) ListVolumes(ctx context.Context, req *csi.ListVolumesRequest) (*csi.ListVolumesResponse, error) {
func (d *Controller) ListVolumes(ctx context.Context, req *csi.ListVolumesRequest) (*csi.ListVolumesResponse, error) {
maxEntries := req.MaxEntries
if maxEntries == 0 && d.defaultVolumesPageSize > 0 {
maxEntries = int32(d.defaultVolumesPageSize)
Expand Down Expand Up @@ -596,7 +618,7 @@ func (d *Driver) ListVolumes(ctx context.Context, req *csi.ListVolumesRequest) (
}

// GetCapacity returns the capacity of the storage pool
func (d *Driver) GetCapacity(ctx context.Context, req *csi.GetCapacityRequest) (*csi.GetCapacityResponse, error) {
func (d *Controller) GetCapacity(ctx context.Context, req *csi.GetCapacityRequest) (*csi.GetCapacityResponse, error) {
// TODO(arslan): check if we can provide this information somehow
d.log.WithFields(logrus.Fields{
"params": req.Parameters,
Expand All @@ -606,7 +628,7 @@ func (d *Driver) GetCapacity(ctx context.Context, req *csi.GetCapacityRequest) (
}

// ControllerGetCapabilities returns the capabilities of the controller service.
func (d *Driver) ControllerGetCapabilities(ctx context.Context, req *csi.ControllerGetCapabilitiesRequest) (*csi.ControllerGetCapabilitiesResponse, error) {
func (d *Controller) ControllerGetCapabilities(ctx context.Context, req *csi.ControllerGetCapabilitiesRequest) (*csi.ControllerGetCapabilitiesResponse, error) {
newCap := func(cap csi.ControllerServiceCapability_RPC_Type) *csi.ControllerServiceCapability {
return &csi.ControllerServiceCapability{
Type: &csi.ControllerServiceCapability_Rpc{
Expand Down Expand Up @@ -643,7 +665,7 @@ func (d *Driver) ControllerGetCapabilities(ctx context.Context, req *csi.Control

// CreateSnapshot will be called by the CO to create a new snapshot from a
// source volume on behalf of a user.
func (d *Driver) CreateSnapshot(ctx context.Context, req *csi.CreateSnapshotRequest) (*csi.CreateSnapshotResponse, error) {
func (d *Controller) CreateSnapshot(ctx context.Context, req *csi.CreateSnapshotRequest) (*csi.CreateSnapshotResponse, error) {
if req.GetName() == "" {
return nil, status.Error(codes.InvalidArgument, "CreateSnapshot Name must be provided")
}
Expand Down Expand Up @@ -739,7 +761,7 @@ func (d *Driver) CreateSnapshot(ctx context.Context, req *csi.CreateSnapshotRequ
}

// DeleteSnapshot will be called by the CO to delete a snapshot.
func (d *Driver) DeleteSnapshot(ctx context.Context, req *csi.DeleteSnapshotRequest) (*csi.DeleteSnapshotResponse, error) {
func (d *Controller) DeleteSnapshot(ctx context.Context, req *csi.DeleteSnapshotRequest) (*csi.DeleteSnapshotResponse, error) {
log := d.log.WithFields(logrus.Fields{
"req_snapshot_id": req.GetSnapshotId(),
"method": "delete_snapshot",
Expand Down Expand Up @@ -772,7 +794,7 @@ func (d *Driver) DeleteSnapshot(ctx context.Context, req *csi.DeleteSnapshotRequ
// system within the given parameters regardless of how they were created.
// ListSnapshots shold not list a snapshot that is being created but has not
// been cut successfully yet.
func (d *Driver) ListSnapshots(ctx context.Context, req *csi.ListSnapshotsRequest) (*csi.ListSnapshotsResponse, error) {
func (d *Controller) ListSnapshots(ctx context.Context, req *csi.ListSnapshotsRequest) (*csi.ListSnapshotsResponse, error) {
listResp := &csi.ListSnapshotsResponse{}
log := d.log.WithFields(logrus.Fields{
"snapshot_id": req.SnapshotId,
Expand Down Expand Up @@ -862,7 +884,7 @@ func (d *Driver) ListSnapshots(ctx context.Context, req *csi.ListSnapshotsReques
}

// ControllerExpandVolume is called from the resizer to increase the volume size.
func (d *Driver) ControllerExpandVolume(ctx context.Context, req *csi.ControllerExpandVolumeRequest) (*csi.ControllerExpandVolumeResponse, error) {
func (d *Controller) ControllerExpandVolume(ctx context.Context, req *csi.ControllerExpandVolumeRequest) (*csi.ControllerExpandVolumeResponse, error) {
volID := req.GetVolumeId()

if len(volID) == 0 {
Expand Down Expand Up @@ -928,15 +950,15 @@ func (d *Driver) ControllerExpandVolume(ctx context.Context, req *csi.Controller
// The call is used for the CSI health check feature
// (https://github.com/kubernetes/enhancements/pull/1077) which we do not
// support yet.
func (d *Driver) ControllerGetVolume(ctx context.Context, req *csi.ControllerGetVolumeRequest) (*csi.ControllerGetVolumeResponse, error) {
func (d *Controller) ControllerGetVolume(ctx context.Context, req *csi.ControllerGetVolumeRequest) (*csi.ControllerGetVolumeResponse, error) {
return nil, status.Error(codes.Unimplemented, "")
}

// extractStorage extracts the storage size in bytes from the given capacity
// range. If the capacity range is not satisfied it returns the default volume
// size. If the capacity range is above supported sizes, it returns an
// error. If the capacity range is below supported size, it returns the minimum supported size
func (d *Driver) extractStorage(capRange *csi.CapacityRange) (int64, error) {
func (d *Controller) extractStorage(capRange *csi.CapacityRange) (int64, error) {
if capRange == nil {
return defaultVolumeSizeInBytes, nil
}
Expand Down Expand Up @@ -1016,7 +1038,7 @@ func formatBytes(inputBytes int64) string {
}

// waitAction waits until the given action for the volume has completed.
func (d *Driver) waitAction(ctx context.Context, log *logrus.Entry, volumeID string, actionID int) error {
func (d *Controller) waitAction(ctx context.Context, log *logrus.Entry, volumeID string, actionID int) error {
err := wait.PollUntil(1*time.Second, func() (done bool, err error) {
action, _, err := d.storageActions.Get(ctx, volumeID, actionID)
if err != nil {
Expand Down Expand Up @@ -1057,7 +1079,7 @@ type limitDetails struct {
}

// checkLimit checks whether the user hit their account volume limit.
func (d *Driver) checkLimit(ctx context.Context) (*limitDetails, error) {
func (d *Controller) checkLimit(ctx context.Context) (*limitDetails, error) {
// only one provisioner runs, we can make sure to prevent burst creation
d.readyMu.Lock()
defer d.readyMu.Unlock()
Expand Down Expand Up @@ -1144,7 +1166,7 @@ func validateCapabilities(caps []*csi.VolumeCapability) []string {
return violations.List()
}

func (d *Driver) tagVolume(parentCtx context.Context, vol *godo.Volume) error {
func (d *Controller) tagVolume(parentCtx context.Context, vol *godo.Volume) error {
for _, tag := range vol.Tags {
if tag == d.doTag {
return nil
Expand Down
148 changes: 79 additions & 69 deletions driver/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,37 +57,22 @@ var (
// csi.NodeServer
//
type Driver struct {
name string
// publishInfoVolumeName is used to pass the volume name from
// `ControllerPublishVolume` to `NodeStageVolume or `NodePublishVolume`
publishInfoVolumeName string

endpoint string
debugAddr string
hostID func() string
region string
doTag string
isController bool
defaultVolumesPageSize uint
name string
endpoint string
debugAddr string
isController bool

srv *grpc.Server
httpSrv *http.Server
log *logrus.Entry
mounter Mounter

storage godo.StorageService
storageActions godo.StorageActionsService
droplets godo.DropletsService
snapshots godo.SnapshotsService
account godo.AccountService
tags godo.TagsService

healthChecker *HealthChecker

// ready defines whether the driver is ready to function. This value will
// be used by the `Identity` service via the `Probe()` method.
readyMu sync.Mutex // protects ready
ready bool

csi.NodeServer
csi.ControllerServer
}

// NewDriverParams defines the parameters that can be passed to NewDriver.
Expand Down Expand Up @@ -131,52 +116,75 @@ func NewDriver(p NewDriverParams) (*Driver, error) {
}
hostID := strconv.Itoa(hostIDInt)

var opts []godo.ClientOpt
opts = append(opts, godo.SetBaseURL(p.URL))

if version == "" {
version = "dev"
}
opts = append(opts, godo.SetUserAgent("csi-digitalocean/"+version))

doClient, err := godo.New(oauthClient, opts...)
if err != nil {
return nil, fmt.Errorf("couldn't initialize DigitalOcean client: %s", err)
}

healthChecker := NewHealthChecker(&doHealthChecker{account: doClient.Account})

log := logrus.New().WithFields(logrus.Fields{
"region": region,
"host_id": hostID,
"version": version,
})

return &Driver{
name: driverName,
publishInfoVolumeName: driverName + "/volume-name",

doTag: p.DOTag,
endpoint: p.Endpoint,
debugAddr: p.DebugAddr,
defaultVolumesPageSize: p.DefaultVolumesPageSize,

hostID: func() string { return hostID },
region: region,
mounter: newMounter(log),
log: log,
// we're assuming only the controller has a non-empty token.
isController: p.Token != "",

storage: doClient.Storage,
storageActions: doClient.StorageActions,
droplets: doClient.Droplets,
snapshots: doClient.Snapshots,
account: doClient.Account,
tags: doClient.Tags,

healthChecker: healthChecker,
}, nil
var driver *Driver
// we're assuming only the controller has a non-empty token.
if p.Token != "" {
var opts []godo.ClientOpt
opts = append(opts, godo.SetBaseURL(p.URL))

if version == "" {
version = "dev"
}
opts = append(opts, godo.SetUserAgent("csi-digitalocean/"+version))

doClient, err := godo.New(oauthClient, opts...)
if err != nil {
return nil, fmt.Errorf("couldn't initialize DigitalOcean client: %s", err)
}

healthChecker := NewHealthChecker(&doHealthChecker{account: doClient.Account})

controller := &Controller{
publishInfoVolumeName: driverName + "/volume-name",
region: region,
doTag: p.DOTag,
defaultVolumesPageSize: p.DefaultVolumesPageSize,

storage: doClient.Storage,
storageActions: doClient.StorageActions,
droplets: doClient.Droplets,
snapshots: doClient.Snapshots,
account: doClient.Account,
tags: doClient.Tags,

healthChecker: healthChecker,
log: log,
}

driver = &Driver{
name: driverName,
endpoint: p.Endpoint,
debugAddr: p.DebugAddr,
isController: p.Token != "",

ControllerServer: controller,
}
} else {
node := &Node{
publishInfoVolumeName: driverName + "/volume-name",
region: region,
hostID: func() string { return hostID },
log: log,
mounter: newMounter(log),
}

driver = &Driver{
name: driverName,
endpoint: p.Endpoint,
debugAddr: p.DebugAddr,
isController: p.Token != "",

NodeServer: node,
}
}

return driver, nil
}

// Run starts the CSI plugin by communication over the given endpoint
Expand Down Expand Up @@ -217,11 +225,15 @@ func (d *Driver) Run(ctx context.Context) error {
return resp, err
}

d.srv = grpc.NewServer(grpc.UnaryInterceptor(errHandler))
csi.RegisterIdentityServer(d.srv, d)

// warn the user, it'll not propagate to the user but at least we see if
// something is wrong in the logs. Only check if the driver is running with
// a token (i.e: controller)
if d.isController {
details, err := d.checkLimit(context.Background())
controller := d.ControllerServer.(*Controller)
details, err := controller.checkLimit(context.Background())
if err != nil {
return fmt.Errorf("failed to check volumes limits on startup: %s", err)
}
Expand All @@ -235,7 +247,7 @@ func (d *Driver) Run(ctx context.Context) error {
if d.debugAddr != "" {
mux := http.NewServeMux()
mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
err := d.healthChecker.Check(r.Context())
err := controller.healthChecker.Check(r.Context())
if err != nil {
d.log.WithError(err).Error("executing health check")
http.Error(w, err.Error(), http.StatusInternalServerError)
Expand All @@ -248,13 +260,11 @@ func (d *Driver) Run(ctx context.Context) error {
Handler: mux,
}
}
csi.RegisterControllerServer(d.srv, controller)
} else {
csi.RegisterNodeServer(d.srv, d.NodeServer.(*Node))
}

d.srv = grpc.NewServer(grpc.UnaryInterceptor(errHandler))
csi.RegisterIdentityServer(d.srv, d)
csi.RegisterControllerServer(d.srv, d)
csi.RegisterNodeServer(d.srv, d)

d.ready = true // we're now ready to go!
d.log.WithFields(logrus.Fields{
"grpc_addr": grpcAddr,
Expand Down
Loading

0 comments on commit f009c37

Please sign in to comment.