From 824cac2723532dab0025fd19ae9f05d66a4783ec Mon Sep 17 00:00:00 2001
From: Federico Di Pierro <nierro92@gmail.com>
Date: Mon, 6 Nov 2023 11:00:38 +0100
Subject: [PATCH] new(pkg/driver): implemented kmod Prepare method.

Signed-off-by: Federico Di Pierro <nierro92@gmail.com>
---
 cmd/driver/prepare/prepare.go |  4 +-
 pkg/driver/distro/distro.go   |  4 ++
 pkg/driver/type/bpf.go        |  5 +-
 pkg/driver/type/kmod.go       | 97 ++++++++++++++++++++++++++++++++++-
 pkg/driver/type/modernbpf.go  |  4 +-
 pkg/driver/type/type.go       |  8 ++-
 6 files changed, 114 insertions(+), 8 deletions(-)

diff --git a/cmd/driver/prepare/prepare.go b/cmd/driver/prepare/prepare.go
index faf9af19..1297a808 100644
--- a/cmd/driver/prepare/prepare.go
+++ b/cmd/driver/prepare/prepare.go
@@ -63,7 +63,7 @@ func NewDriverPrepareCmd(ctx context.Context, opt *options.Common) *cobra.Comman
 	cmd.Flags().BoolVar(&o.Download, "download", true, "Whether to enable download of drivers")
 	cmd.Flags().BoolVar(&o.Build, "build", true, "Whether to enable build of drivers")
 	cmd.Flags().StringVar(&o.DriverVersion, "driver-version", "", "Driver version to be built")
-	cmd.Flags().StringSliceVar(&o.DriverRepos, "driver-repo", []string{"https://download.falco.org/driver"}, "Specify different URL(s) where to look for prebuilt drivers")
+	cmd.Flags().StringSliceVar(&o.DriverRepos, "driver-repo", []string{driverdistro.DefaultFalcoRepo}, "Specify different URL(s) where to look for prebuilt drivers")
 
 	if err := cmd.MarkFlagRequired("driver-version"); err != nil {
 		output.ExitOnErr(o.Printer, fmt.Errorf("unable to mark flag \"driver-version\" as required"))
@@ -106,7 +106,7 @@ func (o *driverPrepareOptions) RunDriverPrepare(_ context.Context, _ []string) e
 	}
 	o.Printer.Logger.Info("found distro", o.Printer.Logger.Args("target", d.GetTargetID(info)))
 
-	err = driver.Type.Prepare()
+	err = driver.Type.Prepare(o.Printer, o.Name)
 	if err != nil {
 		return err
 	}
diff --git a/pkg/driver/distro/distro.go b/pkg/driver/distro/distro.go
index 22879824..358f534f 100644
--- a/pkg/driver/distro/distro.go
+++ b/pkg/driver/distro/distro.go
@@ -27,6 +27,10 @@ import (
 	"github.com/falcosecurity/falcoctl/pkg/output"
 )
 
+const (
+	DefaultFalcoRepo = "https://download.falco.org/driver"
+)
+
 var distros = map[string]Distro{}
 
 // ErrUnsupported is the error returned when the target distro is not supported.
diff --git a/pkg/driver/type/bpf.go b/pkg/driver/type/bpf.go
index 8f0a8348..c1056054 100644
--- a/pkg/driver/type/bpf.go
+++ b/pkg/driver/type/bpf.go
@@ -17,6 +17,8 @@ package drivertype
 
 import (
 	"k8s.io/utils/mount"
+
+	"github.com/falcosecurity/falcoctl/pkg/output"
 )
 
 func init() {
@@ -29,11 +31,12 @@ func (b *bpf) String() string {
 	return "bpf"
 }
 
-func (b *bpf) Prepare() error {
+func (b *bpf) Prepare(printer *output.Printer, _ string) error {
 	// Mount /sys/kernel/debug that is needed on old (pre 4.17) kernel releases,
 	// since these releases still did not support raw tracepoints.
 	// BPF_PROG_TYPE_RAW_TRACEPOINT was introduced in 4.17 indeed:
 	// https://github.com/torvalds/linux/commit/c4f6699dfcb8558d138fe838f741b2c10f416cf9
+	printer.Logger.Info("Mounting debugfs for bpf driver.")
 	mounter := mount.New("/bin/mount")
 	return mounter.Mount("debugfs", "/sys/kernel/debug", "debugfs", []string{"nodev"})
 }
diff --git a/pkg/driver/type/kmod.go b/pkg/driver/type/kmod.go
index a44b93f7..a3edcf97 100644
--- a/pkg/driver/type/kmod.go
+++ b/pkg/driver/type/kmod.go
@@ -15,6 +15,29 @@
 
 package drivertype
 
+import (
+	"fmt"
+	"os/exec"
+	"strings"
+	"time"
+
+	"github.com/falcosecurity/falcoctl/pkg/output"
+)
+
+const (
+	maxRmmodWait  = 10
+	rmmodWaitTime = 5 * time.Second
+)
+
+type ErrorMissingDep struct {
+	program string
+	reason  error
+}
+
+func (e *ErrorMissingDep) Error() string {
+	return fmt.Sprintf("This program requires %s (%s)", e.program, e.reason.Error())
+}
+
 func init() {
 	driverTypes["kmod"] = &kmod{}
 }
@@ -25,8 +48,78 @@ func (k *kmod) String() string {
 	return "kmod"
 }
 
-func (k *kmod) Prepare() error {
-	// todo remove previously loaded modules
+// Prepare for kmod does a cleanup of existing kernel modules.
+// First thing, it tries to rmmod the loaded kmod, if present.
+// Then, using dkms, it tries to fetch all
+// dkms-installed versions of the module to clean them up.
+func (k *kmod) Prepare(printer *output.Printer, driverName string) error {
+	_, err := exec.Command("bash", "-c", "hash lsmod").Output()
+	if err != nil {
+		return &ErrorMissingDep{program: "lsmod", reason: err}
+	}
+	_, err = exec.Command("bash", "-c", "hash rmmod").Output()
+	if err != nil {
+		return &ErrorMissingDep{program: "rmmod", reason: err}
+	}
+
+	kmodName := strings.ReplaceAll(driverName, "-", "_")
+	args := printer.Logger.Args("kmod", kmodName)
+
+	printer.Logger.Info("Check if kernel module is still loaded.", args)
+	lsmodCmdArgs := fmt.Sprintf(`lsmod | cut -d' ' -f1 | grep -qx "%s"`, kmodName)
+	_, err = exec.Command("bash", "-c", lsmodCmdArgs).Output()
+	if err == nil {
+		unloaded := false
+		// Module is still loaded, try to remove it
+		for i := 0; i < maxRmmodWait; i++ {
+			printer.Logger.Info("Kernel module is still loaded.", args)
+			printer.Logger.Info("Trying to unload it with 'rmmod'.", args)
+			if _, err = exec.Command("rmmod", kmodName).Output(); err == nil {
+				printer.Logger.Info("OK! Unloading module succeeded.", args)
+				unloaded = true
+				break
+			}
+			printer.Logger.Info("Nothing to do...'falcoctl' will wait until you remove the kernel module to have a clean termination.", args)
+			printer.Logger.Info("Check that no process is using the kernel module with 'lsmod'.", args)
+			printer.Logger.Info("Sleep 5 seconds...", args)
+			time.Sleep(rmmodWaitTime)
+		}
+		if !unloaded {
+			printer.Logger.Warn("Kernel module is still loaded, you could have incompatibility issues.", args)
+		}
+	} else {
+		printer.Logger.Info("OK! There is no module loaded.", args)
+	}
+
+	_, err = exec.Command("bash", "-c", "hash dkms").Output()
+	if err != nil {
+		printer.Logger.Info("Skipping dkms remove (dkms not found).", args)
+		return nil
+	}
+
+	printer.Logger.Info("Check all versions of kernel module in dkms.", args)
+	dkmsLsCmdArgs := fmt.Sprintf(`dkms status -m "%s" | tr -d "," | tr -d ":" | tr "/" " " | cut -d' ' -f2`, kmodName)
+	out, err := exec.Command("bash", "-c", dkmsLsCmdArgs).Output()
+	if err != nil {
+		printer.Logger.Warn("Listing kernel module versions failed.", args, printer.Logger.Args("reason", err))
+		return nil
+	}
+	if len(out) == 0 {
+		printer.Logger.Info("OK! There are no module versions in dkms.", args)
+	} else {
+		driverVersions := strings.Split(string(out), "\n")
+		printer.Logger.Info("There are some module versions in dkms.", args)
+		printer.Logger.Info("Removing all the following versions from dkms.", args, printer.Logger.Args("versions", driverVersions))
+		for _, dVer := range driverVersions {
+			dkmsRmCmdArgs := fmt.Sprintf(`dkms remove -m %s -v "%s" --all`, kmodName, dVer)
+			_, err = exec.Command("bash", "-c", dkmsRmCmdArgs).Output()
+			if err == nil {
+				printer.Logger.Info("OK! Removing succeeded.", args, printer.Logger.Args("version", dVer))
+			} else {
+				printer.Logger.Warn("Removing failed.", args, printer.Logger.Args("version", dVer))
+			}
+		}
+	}
 	return nil
 }
 
diff --git a/pkg/driver/type/modernbpf.go b/pkg/driver/type/modernbpf.go
index 0e3b49ed..3b760c63 100644
--- a/pkg/driver/type/modernbpf.go
+++ b/pkg/driver/type/modernbpf.go
@@ -15,6 +15,8 @@
 
 package drivertype
 
+import "github.com/falcosecurity/falcoctl/pkg/output"
+
 func init() {
 	driverTypes["modern-bpf"] = &modernBpf{}
 }
@@ -25,7 +27,7 @@ func (m *modernBpf) String() string {
 	return "modern-bpf"
 }
 
-func (m *modernBpf) Prepare() error {
+func (m *modernBpf) Prepare(printer *output.Printer, _ string) error {
 	return nil
 }
 
diff --git a/pkg/driver/type/type.go b/pkg/driver/type/type.go
index 4bf2e698..4343cf68 100644
--- a/pkg/driver/type/type.go
+++ b/pkg/driver/type/type.go
@@ -15,14 +15,18 @@
 
 package drivertype
 
-import "fmt"
+import (
+	"fmt"
+
+	"github.com/falcosecurity/falcoctl/pkg/output"
+)
 
 var driverTypes = map[string]DriverType{}
 
 // DriverType is the interface that wraps driver types.
 type DriverType interface {
 	String() string
-	Prepare() error
+	Prepare(printer *output.Printer, driverName string) error
 	Extension() string
 	HasArtifacts() bool
 }