From 403e35d2ffca3f78ef13b1e45a69570d9ba5b46b Mon Sep 17 00:00:00 2001 From: Devashish Date: Wed, 17 Apr 2024 22:28:02 -0400 Subject: [PATCH] Add thread-safe add and get methods for plugin storage * Add general code improvements --- packer/build.go | 6 ++-- packer/plugin.go | 72 ++++++++++++++++++++++++++++++++++++++++++++---- 2 files changed, 69 insertions(+), 9 deletions(-) diff --git a/packer/build.go b/packer/build.go index 9dca9610ac8..8b62ec53799 100644 --- a/packer/build.go +++ b/packer/build.go @@ -60,14 +60,14 @@ type BuildMetadata struct { func (b *CoreBuild) getPluginsMetadata() map[string]PluginDetails { resp := map[string]PluginDetails{} - builderPlugin, builderPluginOk := PluginsDetailsStorage[fmt.Sprintf("%q-%q", PluginComponentBuilder, b.BuilderType)] + builderPlugin, builderPluginOk := GlobalPluginsDetailsStore.GetBuilder(b.BuilderType) if builderPluginOk { resp[builderPlugin.Name] = builderPlugin } for _, pp := range b.PostProcessors { for _, p := range pp { - postprocessorsPlugin, postprocessorsPluginOk := PluginsDetailsStorage[fmt.Sprintf("%q-%q", PluginComponentPostProcessor, p.PType)] + postprocessorsPlugin, postprocessorsPluginOk := GlobalPluginsDetailsStore.GetPostProcessor(p.PType) if postprocessorsPluginOk { resp[postprocessorsPlugin.Name] = postprocessorsPlugin } @@ -75,7 +75,7 @@ func (b *CoreBuild) getPluginsMetadata() map[string]PluginDetails { } for _, pv := range b.Provisioners { - provisionerPlugin, provisionerPluginOk := PluginsDetailsStorage[fmt.Sprintf("%q-%q", PluginComponentProvisioner, pv.PType)] + provisionerPlugin, provisionerPluginOk := GlobalPluginsDetailsStore.GetProvisioner(pv.PType) if provisionerPluginOk { resp[provisionerPlugin.Name] = provisionerPlugin } diff --git a/packer/plugin.go b/packer/plugin.go index 513694f639f..de85071d7ad 100644 --- a/packer/plugin.go +++ b/packer/plugin.go @@ -15,6 +15,7 @@ import ( "regexp" "runtime" "strings" + "sync" packersdk "github.com/hashicorp/packer-plugin-sdk/packer" pluginsdk "github.com/hashicorp/packer-plugin-sdk/plugin" @@ -151,8 +152,7 @@ func (c *PluginConfig) DiscoverMultiPlugin(pluginName, pluginPath string) error c.Builders.Set(key, func() (packersdk.Builder, error) { return c.Client(pluginPath, "start", "builder", builderName).Builder() }) - PluginsDetailsStorage[fmt.Sprintf("%q-%q", PluginComponentBuilder, key)] = pluginDetails - + GlobalPluginsDetailsStore.SetBuilder(key, pluginDetails) } if len(desc.Builders) > 0 { @@ -168,7 +168,7 @@ func (c *PluginConfig) DiscoverMultiPlugin(pluginName, pluginPath string) error c.PostProcessors.Set(key, func() (packersdk.PostProcessor, error) { return c.Client(pluginPath, "start", "post-processor", postProcessorName).PostProcessor() }) - PluginsDetailsStorage[fmt.Sprintf("%q-%q", PluginComponentPostProcessor, key)] = pluginDetails + GlobalPluginsDetailsStore.SetPostProcessor(key, pluginDetails) } if len(desc.PostProcessors) > 0 { @@ -184,7 +184,7 @@ func (c *PluginConfig) DiscoverMultiPlugin(pluginName, pluginPath string) error c.Provisioners.Set(key, func() (packersdk.Provisioner, error) { return c.Client(pluginPath, "start", "provisioner", provisionerName).Provisioner() }) - PluginsDetailsStorage[fmt.Sprintf("%q-%q", PluginComponentProvisioner, key)] = pluginDetails + GlobalPluginsDetailsStore.SetProvisioner(key, pluginDetails) } if len(desc.Provisioners) > 0 { @@ -200,7 +200,7 @@ func (c *PluginConfig) DiscoverMultiPlugin(pluginName, pluginPath string) error c.DataSources.Set(key, func() (packersdk.Datasource, error) { return c.Client(pluginPath, "start", "datasource", datasourceName).Datasource() }) - PluginsDetailsStorage[fmt.Sprintf("%q-%q", PluginComponentDataSource, key)] = pluginDetails + GlobalPluginsDetailsStore.SetDataSource(key, pluginDetails) } if len(desc.Datasources) > 0 { log.Printf("found external %v datasource from %s plugin", desc.Datasources, pluginName) @@ -268,4 +268,64 @@ type PluginDetails struct { PluginPath string } -var PluginsDetailsStorage = map[string]PluginDetails{} +type pluginsDetailsStorage struct { + rwMutex sync.RWMutex + data map[string]PluginDetails +} + +var GlobalPluginsDetailsStore = &pluginsDetailsStorage{ + data: make(map[string]PluginDetails), +} + +func (pds *pluginsDetailsStorage) set(key string, plugin PluginDetails) { + pds.rwMutex.Lock() + defer pds.rwMutex.Unlock() + pds.data[key] = plugin +} + +func (pds *pluginsDetailsStorage) get(key string) (PluginDetails, bool) { + pds.rwMutex.RLock() + defer pds.rwMutex.RUnlock() + plugin, exists := pds.data[key] + return plugin, exists +} + +func (pds *pluginsDetailsStorage) SetBuilder(name string, plugin PluginDetails) { + key := fmt.Sprintf("%q-%q", PluginComponentBuilder, name) + pds.set(key, plugin) +} + +func (pds *pluginsDetailsStorage) GetBuilder(name string) (PluginDetails, bool) { + key := fmt.Sprintf("%q-%q", PluginComponentBuilder, name) + return pds.get(key) +} + +func (pds *pluginsDetailsStorage) SetPostProcessor(name string, plugin PluginDetails) { + key := fmt.Sprintf("%q-%q", PluginComponentPostProcessor, name) + pds.set(key, plugin) +} + +func (pds *pluginsDetailsStorage) GetPostProcessor(name string) (PluginDetails, bool) { + key := fmt.Sprintf("%q-%q", PluginComponentPostProcessor, name) + return pds.get(key) +} + +func (pds *pluginsDetailsStorage) SetProvisioner(name string, plugin PluginDetails) { + key := fmt.Sprintf("%q-%q", PluginComponentProvisioner, name) + pds.set(key, plugin) +} + +func (pds *pluginsDetailsStorage) GetProvisioner(name string) (PluginDetails, bool) { + key := fmt.Sprintf("%q-%q", PluginComponentProvisioner, name) + return pds.get(key) +} + +func (pds *pluginsDetailsStorage) SetDataSource(name string, plugin PluginDetails) { + key := fmt.Sprintf("%q-%q", PluginComponentDataSource, name) + pds.set(key, plugin) +} + +func (pds *pluginsDetailsStorage) GetDataSource(name string) (PluginDetails, bool) { + key := fmt.Sprintf("%q-%q", PluginComponentDataSource, name) + return pds.get(key) +}