From 5952920380debc0493f6890175ba9657ce0e9a17 Mon Sep 17 00:00:00 2001 From: Albin Kerouanton Date: Wed, 6 Mar 2024 18:23:16 +0100 Subject: [PATCH] libnet/d/remote: replace errorWithRollback Use defer funcs instead. For no apparant reasons, a few error cases in the Join method were not triggering a rollback. This is now fixed. Signed-off-by: Albin Kerouanton --- libnetwork/drivers/remote/driver.go | 46 ++++++++++++++++++----------- 1 file changed, 29 insertions(+), 17 deletions(-) diff --git a/libnetwork/drivers/remote/driver.go b/libnetwork/drivers/remote/driver.go index f7e83a709ca1c..3e29412413a01 100644 --- a/libnetwork/drivers/remote/driver.go +++ b/libnetwork/drivers/remote/driver.go @@ -172,7 +172,7 @@ func (d *driver) DeleteNetwork(nid string) error { return d.call("DeleteNetwork", &api.DeleteNetworkRequest{NetworkID: nid}, &api.DeleteNetworkResponse{}) } -func (d *driver) CreateEndpoint(nid, eid string, ifInfo driverapi.InterfaceInfo, epOptions map[string]interface{}) error { +func (d *driver) CreateEndpoint(nid, eid string, ifInfo driverapi.InterfaceInfo, epOptions map[string]interface{}) (retErr error) { if ifInfo == nil { return errors.New("must not be called with nil InterfaceInfo") } @@ -199,6 +199,16 @@ func (d *driver) CreateEndpoint(nid, eid string, ifInfo driverapi.InterfaceInfo, return err } + defer func() { + if retErr != nil { + if err := d.DeleteEndpoint(nid, eid); err != nil { + retErr = fmt.Errorf("%w; failed to roll back: %w", err, retErr) + } else { + retErr = fmt.Errorf("%w; rolled back", retErr) + } + } + }() + inIface, err := parseInterface(res) if err != nil { return err @@ -210,31 +220,23 @@ func (d *driver) CreateEndpoint(nid, eid string, ifInfo driverapi.InterfaceInfo, if inIface.MacAddress != nil { if err := ifInfo.SetMacAddress(inIface.MacAddress); err != nil { - return errorWithRollback(fmt.Sprintf("driver modified interface MAC address: %v", err), d.DeleteEndpoint(nid, eid)) + return fmt.Errorf("driver modified interface MAC address: %v", err) } } if inIface.Address != nil { if err := ifInfo.SetIPAddress(inIface.Address); err != nil { - return errorWithRollback(fmt.Sprintf("driver modified interface address: %v", err), d.DeleteEndpoint(nid, eid)) + return fmt.Errorf("driver modified interface address: %v", err) } } if inIface.AddressIPv6 != nil { if err := ifInfo.SetIPAddress(inIface.AddressIPv6); err != nil { - return errorWithRollback(fmt.Sprintf("driver modified interface address: %v", err), d.DeleteEndpoint(nid, eid)) + return fmt.Errorf("driver modified interface address: %v", err) } } return nil } -func errorWithRollback(msg string, err error) error { - rollback := "rolled back" - if err != nil { - rollback = "failed to roll back: " + err.Error() - } - return fmt.Errorf("%s; %s", msg, rollback) -} - func (d *driver) DeleteEndpoint(nid, eid string) error { deleteRequest := &api.DeleteEndpointRequest{ NetworkID: nid, @@ -256,7 +258,7 @@ func (d *driver) EndpointOperInfo(nid, eid string) (map[string]interface{}, erro } // Join method is invoked when a Sandbox is attached to an endpoint. -func (d *driver) Join(nid, eid string, sboxKey string, jinfo driverapi.JoinInfo, options map[string]interface{}) error { +func (d *driver) Join(nid, eid string, sboxKey string, jinfo driverapi.JoinInfo, options map[string]interface{}) (retErr error) { join := &api.JoinRequest{ NetworkID: nid, EndpointID: eid, @@ -271,10 +273,20 @@ func (d *driver) Join(nid, eid string, sboxKey string, jinfo driverapi.JoinInfo, return err } + defer func() { + if retErr != nil { + if err := d.Leave(nid, eid); err != nil { + retErr = fmt.Errorf("%w; failed to roll back: %w", err, retErr) + } else { + retErr = fmt.Errorf("%w; rolled back", retErr) + } + } + }() + ifaceName := res.InterfaceName if iface := jinfo.InterfaceName(); iface != nil && ifaceName != nil { if err := iface.SetNames(ifaceName.SrcName, ifaceName.DstPrefix); err != nil { - return errorWithRollback(fmt.Sprintf("failed to set interface name: %s", err), d.Leave(nid, eid)) + return fmt.Errorf("failed to set interface name: %s", err) } } @@ -284,7 +296,7 @@ func (d *driver) Join(nid, eid string, sboxKey string, jinfo driverapi.JoinInfo, return fmt.Errorf(`unable to parse Gateway "%s"`, res.Gateway) } if jinfo.SetGateway(addr) != nil { - return errorWithRollback(fmt.Sprintf("failed to set gateway: %v", addr), d.Leave(nid, eid)) + return fmt.Errorf("failed to set gateway: %v", addr) } } if res.GatewayIPv6 != "" { @@ -292,7 +304,7 @@ func (d *driver) Join(nid, eid string, sboxKey string, jinfo driverapi.JoinInfo, return fmt.Errorf(`unable to parse GatewayIPv6 "%s"`, res.GatewayIPv6) } if jinfo.SetGatewayIPv6(addr) != nil { - return errorWithRollback(fmt.Sprintf("failed to set gateway IPv6: %v", addr), d.Leave(nid, eid)) + return fmt.Errorf("failed to set gateway IPv6: %v", addr) } } if len(res.StaticRoutes) > 0 { @@ -302,7 +314,7 @@ func (d *driver) Join(nid, eid string, sboxKey string, jinfo driverapi.JoinInfo, } for _, route := range routes { if jinfo.AddStaticRoute(route.Destination, route.RouteType, route.NextHop) != nil { - return errorWithRollback(fmt.Sprintf("failed to set static route: %v", route), d.Leave(nid, eid)) + return fmt.Errorf("failed to set static route: %v", route) } } }