From 0e4b4a308cfd19539b43b9d8b8358ed7dd7c1c3e Mon Sep 17 00:00:00 2001 From: Starttoaster Date: Tue, 21 May 2024 13:01:44 -0700 Subject: [PATCH] Add flag to create db --- cmd/root.go | 7 +++++++ internal/db/db.go | 9 ++++++++- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/cmd/root.go b/cmd/root.go index 455917d..ddfc65c 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -19,6 +19,7 @@ var rootCmd = &cobra.Command{ Run: func(cmd *cobra.Command, args []string) { // Init db package err := db.Init( + viper.GetBool("create-db"), viper.GetString("mysql-host"), viper.GetString("mysql-database"), viper.GetString("mysql-user"), @@ -149,6 +150,7 @@ func init() { rootCmd.PersistentFlags().Uint32("records", 0, "The number of records to send (defaults to 0)") rootCmd.PersistentFlags().Uint32("max-concurrent", 1, "The max number of records to send concurrently (in individual requests.) (defaults to 1)") rootCmd.PersistentFlags().Bool("reset", false, "This resets the mysqlpunch table at the beginning of a run, deleting all records in it and resetting the ID counter. (defaults to false)") + rootCmd.PersistentFlags().Bool("create-db", false, "When set to true, this will handle creating the database in your mysql server. (defaults to false)") err := viper.BindPFlag("log-level", rootCmd.PersistentFlags().Lookup("log-level")) if err != nil { @@ -189,4 +191,9 @@ func init() { if err != nil { log.Fatalln(err.Error()) } + + err = viper.BindPFlag("create-db", rootCmd.PersistentFlags().Lookup("create-db")) + if err != nil { + log.Fatalln(err.Error()) + } } diff --git a/internal/db/db.go b/internal/db/db.go index f937c35..ab270c0 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -15,7 +15,7 @@ var db *sql.DB // Init accepts authentication parameters for a mysql db and creates a client // This function may also be configured to create tables in the db on behalf of the application for setup purposes. -func Init(host, database, user, passwd string) error { +func Init(createDB bool, host, database, user, passwd string) error { // Create db client var err error db, err = sql.Open("mysql", assembleDataSourceName(host, database, user, passwd)) @@ -23,6 +23,13 @@ func Init(host, database, user, passwd string) error { return fmt.Errorf("creating database client: %v", err) } + if createDB { + _, err = db.Exec(`CREATE DATABASE IF NOT EXISTS mysqlpunch;`) + if err != nil { + return fmt.Errorf("creating creating mysqlpunch database (if it didn't exist): %v", err) + } + } + log.Debug("Creating table in mysql db if they don't already exist") err = initTable()