diff --git a/acls.go b/acls.go index d971d0dce7..0eb350c9db 100644 --- a/acls.go +++ b/acls.go @@ -312,15 +312,31 @@ func (h *Headscale) generateSSHRules(policy *ACLPolicy) ([]*tailcfg.SSHRule, err return rules, nil } -// CreateUserACLPolicy creates an acl policy for the given user. -func (h *Headscale) CreateUserACLPolicy( +// CreateOrUpdateUserACLPolicy creates an acl policy for the given user. +func (h *Headscale) CreateOrUpdateUserACLPolicy( userID uint, policy ACLPolicy, ) (*UserACLPolicy, error) { + existingUserPolicy := UserACLPolicy{UserID: userID} + if err := h.db.Where("user_id = ?", userID).First(&existingUserPolicy).Error; err == nil { + // already exists, just update + existingUserPolicy.ACLPolicy = policy + if err := h.db.Save(&existingUserPolicy).Error; err != nil { + log.Error(). + Str("func", "CreateOrUpdateUserACLPolicy"). + Err(err). + Msg("Could not update user acl policy") + + return nil, err + } + + return &existingUserPolicy, nil + } + userACLPolicy := UserACLPolicy{ACLPolicy: policy, UserID: userID} if err := h.db.Create(&userACLPolicy).Error; err != nil { log.Error(). - Str("func", "CreateUserACLPolicy"). + Str("func", "CreateOrUpdateUserACLPolicy"). Err(err). Msg("Could not create user acl policy") diff --git a/acls_test.go b/acls_test.go index 8618757b7c..0466ccc54d 100644 --- a/acls_test.go +++ b/acls_test.go @@ -128,7 +128,7 @@ func (s *Suite) TestValidExpandTagOwnersInSources(c *check.C) { } app.db.Save(&machine) - _, err = app.CreateUserACLPolicy(user.ID, ACLPolicy{ + _, err = app.CreateOrUpdateUserACLPolicy(user.ID, ACLPolicy{ Groups: Groups{"group:test": []string{"user1", "user2"}}, TagOwners: TagOwners{"tag:test": []string{"user3", "group:test"}}, ACLs: []ACL{ @@ -148,6 +148,37 @@ func (s *Suite) TestValidExpandTagOwnersInSources(c *check.C) { c.Assert(rules[0].SrcIPs[0], check.Equals, "100.64.0.1") } +func (s *Suite) TestCreateUserAclPolicy(c *check.C) { + user, err := app.CreateUser("user1") + c.Assert(err, check.IsNil) + + _, err = app.CreateOrUpdateUserACLPolicy(user.ID, ACLPolicy{ + Groups: Groups{"group:test": []string{"user1", "user2"}}, + TagOwners: TagOwners{"tag:test": []string{"user3", "group:test"}}, + ACLs: []ACL{ + { + Action: "accept", + Sources: []string{"tag:test"}, + Destinations: []string{"*:*"}, + }, + }, + }) + c.Assert(err, check.IsNil) + + _, err = app.CreateOrUpdateUserACLPolicy(user.ID, ACLPolicy{ + Groups: Groups{"group:test": []string{"user1", "user2"}}, + TagOwners: TagOwners{"tag:test": []string{"user3", "group:test"}}, + ACLs: []ACL{ + { + Action: "accept", + Sources: []string{"tag:test"}, + Destinations: []string{"*:80"}, + }, + }, + }) + c.Assert(err, check.IsNil) +} + // this test should validate that we can expand a group in a TagOWner section and // match properly the IP's of the related hosts. The owner is valid and the tag is also valid. // the tag is matched in the Destinations section. @@ -180,7 +211,7 @@ func (s *Suite) TestValidExpandTagOwnersInDestinations(c *check.C) { } app.db.Save(&machine) - _, err = app.CreateUserACLPolicy(user.ID, ACLPolicy{ + _, err = app.CreateOrUpdateUserACLPolicy(user.ID, ACLPolicy{ Groups: Groups{"group:test": []string{"user1", "user2"}}, TagOwners: TagOwners{"tag:test": []string{"user3", "group:test"}}, ACLs: []ACL{ @@ -229,7 +260,7 @@ func (s *Suite) TestInvalidTagValidUser(c *check.C) { } app.db.Save(&machine) - _, err = app.CreateUserACLPolicy(user.ID, ACLPolicy{ + _, err = app.CreateOrUpdateUserACLPolicy(user.ID, ACLPolicy{ TagOwners: TagOwners{"tag:test": []string{"user1"}}, ACLs: []ACL{ { @@ -299,7 +330,7 @@ func (s *Suite) TestValidTagInvalidUser(c *check.C) { } app.db.Save(&machine) - _, err = app.CreateUserACLPolicy(user.ID, ACLPolicy{ + _, err = app.CreateOrUpdateUserACLPolicy(user.ID, ACLPolicy{ TagOwners: TagOwners{"tag:webapp": []string{"user1"}}, ACLs: []ACL{ { diff --git a/grpcv1.go b/grpcv1.go index 2c40434c75..900360f660 100644 --- a/grpcv1.go +++ b/grpcv1.go @@ -3,6 +3,7 @@ package headscale import ( "context" + "errors" "fmt" "strings" "time" @@ -44,6 +45,9 @@ func (api headscaleV1APIServer) CreateUser( ) (*v1.CreateUserResponse, error) { user, err := api.h.CreateUser(request.GetName()) if err != nil { + if errors.Is(err, ErrUserExists) { + return nil, status.Error(codes.AlreadyExists, err.Error()) + } return nil, err } @@ -573,7 +577,7 @@ func (api headscaleV1APIServer) CreateACLPolicy( return nil, err } - _, err = api.h.CreateUserACLPolicy(user.ID, aclPolicy) + _, err = api.h.CreateOrUpdateUserACLPolicy(user.ID, aclPolicy) if err != nil { return nil, err } diff --git a/machine_test.go b/machine_test.go index 41ddb21d8a..701aeb8036 100644 --- a/machine_test.go +++ b/machine_test.go @@ -178,7 +178,7 @@ func (s *Suite) TestListPeers(c *check.C) { user, err := app.CreateUser("test") c.Assert(err, check.IsNil) - _, err = app.CreateUserACLPolicy(user.ID, ACLPolicy{ + _, err = app.CreateOrUpdateUserACLPolicy(user.ID, ACLPolicy{ ACLs: []ACL{ { Action: "accept", @@ -251,7 +251,7 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) { } // - _, err = app.CreateUserACLPolicy(user.ID, ACLPolicy{ + _, err = app.CreateOrUpdateUserACLPolicy(user.ID, ACLPolicy{ ACLs: []ACL{ { Action: "accept", @@ -1193,7 +1193,7 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) { user, err := app.CreateUser("test") c.Assert(err, check.IsNil) - _, err = app.CreateUserACLPolicy(user.ID, ACLPolicy{ + _, err = app.CreateOrUpdateUserACLPolicy(user.ID, ACLPolicy{ Groups: Groups{"group:test": []string{"test"}}, TagOwners: TagOwners{"tag:exit": []string{"test"}}, ACLs: []ACL{