Skip to content

Commit

Permalink
libnet/d/remote: replace errorWithRollback
Browse files Browse the repository at this point in the history
Use defer funcs instead.

For no apparant reasons, a few error cases in the Join method were not
triggering a rollback. This is now fixed.

Signed-off-by: Albin Kerouanton <[email protected]>
  • Loading branch information
akerouanton committed May 10, 2024
1 parent 4d525c9 commit 5952920
Showing 1 changed file with 29 additions and 17 deletions.
46 changes: 29 additions & 17 deletions libnetwork/drivers/remote/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ func (d *driver) DeleteNetwork(nid string) error {
return d.call("DeleteNetwork", &api.DeleteNetworkRequest{NetworkID: nid}, &api.DeleteNetworkResponse{})
}

func (d *driver) CreateEndpoint(nid, eid string, ifInfo driverapi.InterfaceInfo, epOptions map[string]interface{}) error {
func (d *driver) CreateEndpoint(nid, eid string, ifInfo driverapi.InterfaceInfo, epOptions map[string]interface{}) (retErr error) {
if ifInfo == nil {
return errors.New("must not be called with nil InterfaceInfo")
}
Expand All @@ -199,6 +199,16 @@ func (d *driver) CreateEndpoint(nid, eid string, ifInfo driverapi.InterfaceInfo,
return err
}

defer func() {
if retErr != nil {
if err := d.DeleteEndpoint(nid, eid); err != nil {
retErr = fmt.Errorf("%w; failed to roll back: %w", err, retErr)
} else {
retErr = fmt.Errorf("%w; rolled back", retErr)
}
}
}()

inIface, err := parseInterface(res)
if err != nil {
return err
Expand All @@ -210,31 +220,23 @@ func (d *driver) CreateEndpoint(nid, eid string, ifInfo driverapi.InterfaceInfo,

if inIface.MacAddress != nil {
if err := ifInfo.SetMacAddress(inIface.MacAddress); err != nil {
return errorWithRollback(fmt.Sprintf("driver modified interface MAC address: %v", err), d.DeleteEndpoint(nid, eid))
return fmt.Errorf("driver modified interface MAC address: %v", err)
}
}
if inIface.Address != nil {
if err := ifInfo.SetIPAddress(inIface.Address); err != nil {
return errorWithRollback(fmt.Sprintf("driver modified interface address: %v", err), d.DeleteEndpoint(nid, eid))
return fmt.Errorf("driver modified interface address: %v", err)
}
}
if inIface.AddressIPv6 != nil {
if err := ifInfo.SetIPAddress(inIface.AddressIPv6); err != nil {
return errorWithRollback(fmt.Sprintf("driver modified interface address: %v", err), d.DeleteEndpoint(nid, eid))
return fmt.Errorf("driver modified interface address: %v", err)
}
}

return nil
}

func errorWithRollback(msg string, err error) error {
rollback := "rolled back"
if err != nil {
rollback = "failed to roll back: " + err.Error()
}
return fmt.Errorf("%s; %s", msg, rollback)
}

func (d *driver) DeleteEndpoint(nid, eid string) error {
deleteRequest := &api.DeleteEndpointRequest{
NetworkID: nid,
Expand All @@ -256,7 +258,7 @@ func (d *driver) EndpointOperInfo(nid, eid string) (map[string]interface{}, erro
}

// Join method is invoked when a Sandbox is attached to an endpoint.
func (d *driver) Join(nid, eid string, sboxKey string, jinfo driverapi.JoinInfo, options map[string]interface{}) error {
func (d *driver) Join(nid, eid string, sboxKey string, jinfo driverapi.JoinInfo, options map[string]interface{}) (retErr error) {
join := &api.JoinRequest{
NetworkID: nid,
EndpointID: eid,
Expand All @@ -271,10 +273,20 @@ func (d *driver) Join(nid, eid string, sboxKey string, jinfo driverapi.JoinInfo,
return err
}

defer func() {
if retErr != nil {
if err := d.Leave(nid, eid); err != nil {
retErr = fmt.Errorf("%w; failed to roll back: %w", err, retErr)
} else {
retErr = fmt.Errorf("%w; rolled back", retErr)
}
}
}()

ifaceName := res.InterfaceName
if iface := jinfo.InterfaceName(); iface != nil && ifaceName != nil {
if err := iface.SetNames(ifaceName.SrcName, ifaceName.DstPrefix); err != nil {
return errorWithRollback(fmt.Sprintf("failed to set interface name: %s", err), d.Leave(nid, eid))
return fmt.Errorf("failed to set interface name: %s", err)
}
}

Expand All @@ -284,15 +296,15 @@ func (d *driver) Join(nid, eid string, sboxKey string, jinfo driverapi.JoinInfo,
return fmt.Errorf(`unable to parse Gateway "%s"`, res.Gateway)
}
if jinfo.SetGateway(addr) != nil {
return errorWithRollback(fmt.Sprintf("failed to set gateway: %v", addr), d.Leave(nid, eid))
return fmt.Errorf("failed to set gateway: %v", addr)
}
}
if res.GatewayIPv6 != "" {
if addr = net.ParseIP(res.GatewayIPv6); addr == nil {
return fmt.Errorf(`unable to parse GatewayIPv6 "%s"`, res.GatewayIPv6)
}
if jinfo.SetGatewayIPv6(addr) != nil {
return errorWithRollback(fmt.Sprintf("failed to set gateway IPv6: %v", addr), d.Leave(nid, eid))
return fmt.Errorf("failed to set gateway IPv6: %v", addr)
}
}
if len(res.StaticRoutes) > 0 {
Expand All @@ -302,7 +314,7 @@ func (d *driver) Join(nid, eid string, sboxKey string, jinfo driverapi.JoinInfo,
}
for _, route := range routes {
if jinfo.AddStaticRoute(route.Destination, route.RouteType, route.NextHop) != nil {
return errorWithRollback(fmt.Sprintf("failed to set static route: %v", route), d.Leave(nid, eid))
return fmt.Errorf("failed to set static route: %v", route)
}
}
}
Expand Down

0 comments on commit 5952920

Please sign in to comment.