Skip to content

Commit

Permalink
chore: update the user API to be usable
Browse files Browse the repository at this point in the history
The user API was migrated from the Gateway and wasn't used outside
getting and listing. This change ensures that update and delete work,
especially in regard to changing user roles.

One note: this change does make it possible to change the "nobody" user
when authentication is turned off. However, the role will be switched
back to admin automatically.

Signed-off-by: Donnie Adams <[email protected]>
  • Loading branch information
thedadams committed Jan 8, 2025
1 parent 1a2e386 commit c1f35a6
Show file tree
Hide file tree
Showing 7 changed files with 134 additions and 97 deletions.
2 changes: 1 addition & 1 deletion apiclient/types/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ const (
type Role int

func (u Role) HasRole(role Role) bool {
return role >= u
return u != RoleUnknown && role >= u
}

type User struct {
Expand Down
1 change: 1 addition & 0 deletions pkg/api/authz/authz.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ var staticRules = map[string][]string{
"/api/oauth/redirect/{service}",
"/api/assistants/{path...}",
"GET /api/me",
"PATCH /api/users/{id}",
"POST /api/llm-proxy/",
"POST /api/prompt",
"GET /api/models",
Expand Down
14 changes: 9 additions & 5 deletions pkg/gateway/client/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,15 @@ func (u UserDecorator) AuthenticateRequest(req *http.Request) (*authenticator.Re
return nil, false, nil
}

gatewayUser, err := u.client.EnsureIdentity(req.Context(), &types.Identity{
Email: firstValue(resp.User.GetExtra(), "email"),
AuthProviderID: uint(firstValueAsInt(resp.User.GetExtra(), "auth_provider_id")),
ProviderUsername: resp.User.GetName(),
}, req.Header.Get("X-Obot-User-Timezone"))
gatewayUser, err := u.client.EnsureIdentity(
req.Context(),
&types.Identity{
Email: firstValue(resp.User.GetExtra(), "email"),
AuthProviderID: uint(firstValueAsInt(resp.User.GetExtra(), "auth_provider_id")),
ProviderUsername: resp.User.GetName(),
},
req.Header.Get("X-Obot-User-Timezone"),
)
if err != nil {
return nil, false, err
}
Expand Down
15 changes: 15 additions & 0 deletions pkg/gateway/client/error.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package client

type LastAdminError struct{}

func (e *LastAdminError) Error() string {
return "last admin"
}

type AlreadyExistsError struct {
name string
}

func (e *AlreadyExistsError) Error() string {
return e.name + " already exists"
}
10 changes: 5 additions & 5 deletions pkg/gateway/client/identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (

// EnsureIdentity ensures that the given identity exists in the database, and returns the user associated with it.
func (c *Client) EnsureIdentity(ctx context.Context, id *types.Identity, timezone string) (*types.User, error) {
role := types2.RoleBasic
var role types2.Role
if _, ok := c.adminEmails[id.Email]; ok {
role = types2.RoleAdmin
}
Expand All @@ -24,7 +24,7 @@ func (c *Client) EnsureIdentityWithRole(ctx context.Context, id *types.Identity,
var user *types.User
if err := c.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
var err error
user, err = EnsureIdentity(tx, id, timezone, role)
user, err = ensureIdentity(tx, id, timezone, role)
return err
}); err != nil {
return nil, err
Expand All @@ -33,8 +33,8 @@ func (c *Client) EnsureIdentityWithRole(ctx context.Context, id *types.Identity,
return user, nil
}

// EnsureIdentity ensures that the given identity exists in the database, and returns the user associated with it.
func EnsureIdentity(tx *gorm.DB, id *types.Identity, timezone string, role types2.Role) (*types.User, error) {
// ensureIdentity ensures that the given identity exists in the database, and returns the user associated with it.
func ensureIdentity(tx *gorm.DB, id *types.Identity, timezone string, role types2.Role) (*types.User, error) {
email := id.Email
if err := tx.First(id).Error; errors.Is(err, gorm.ErrRecordNotFound) {
if err = tx.Create(id).Error; err != nil {
Expand Down Expand Up @@ -75,7 +75,7 @@ func EnsureIdentity(tx *gorm.DB, id *types.Identity, timezone string, role types
}

var userChanged bool
if user.Role != role {
if role != types2.RoleUnknown && user.Role != role {
user.Role = role
userChanged = true
}
Expand Down
77 changes: 77 additions & 0 deletions pkg/gateway/client/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,22 @@ package client
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"time"

types2 "github.com/obot-platform/obot/apiclient/types"
"github.com/obot-platform/obot/pkg/gateway/types"
"github.com/obot-platform/obot/pkg/proxy"
"gorm.io/gorm"
)

func (c *Client) Users(ctx context.Context, query types.UserQuery) ([]types.User, error) {
var users []types.User
return users, c.db.WithContext(ctx).Scopes(query.Scope).Find(&users).Error
}

func (c *Client) User(ctx context.Context, username string) (*types.User, error) {
u := new(types.User)
return u, c.db.WithContext(ctx).Where("username = ?", username).First(u).Error
Expand All @@ -22,6 +29,76 @@ func (c *Client) UserByID(ctx context.Context, id string) (*types.User, error) {
return u, c.db.WithContext(ctx).Where("id = ?", id).First(u).Error
}

func (c *Client) DeleteUser(ctx context.Context, username string) (*types.User, error) {
existingUser := new(types.User)
return existingUser, c.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
if err := tx.Where("username = ?", username).First(existingUser).Error; err != nil {
return err
}

if existingUser.Role.HasRole(types2.RoleAdmin) {
var adminCount int64
if err := tx.Model(new(types.User)).Where("role = ?", types2.RoleAdmin).Count(&adminCount).Error; err != nil {
return err
}

if adminCount <= 1 {
return new(LastAdminError)
}
}

if err := tx.Where("user_id = ?", existingUser.ID).Delete(new(types.Identity)).Error; err != nil {
return err
}

return tx.Delete(existingUser).Error
})
}

func (c *Client) UpdateUser(ctx context.Context, actingUserIsAdmin bool, updatedUser *types.User, username string) (*types.User, error) {
existingUser := new(types.User)
return existingUser, c.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
if err := tx.Where("username = ?", username).First(existingUser).Error; err != nil {
return err
}

// If the username is being changed, then ensure that a user with that name doesn't already exist.
if updatedUser.Username != "" && updatedUser.Username != username {
if err := tx.Model(updatedUser).Where("username = ?", updatedUser.Username).First(new(types.User)).Error; err == nil {
return &AlreadyExistsError{name: fmt.Sprintf("user with username %q", updatedUser.Username)}
} else if !errors.Is(err, gorm.ErrRecordNotFound) {
return err
}

existingUser.Username = updatedUser.Username
}

// Anyone can update their timezone
if updatedUser.Timezone != "" {
existingUser.Timezone = updatedUser.Timezone
}

// Only admins can change user roles.
if actingUserIsAdmin {
// If the role is being changed from admin to non-admin, then ensure that this isn't the last admin.
if updatedUser.Role > 0 && existingUser.Role.HasRole(types2.RoleAdmin) && !updatedUser.Role.HasRole(types2.RoleAdmin) {
var adminCount int64
if err := tx.Model(new(types.User)).Where("role = ?", types2.RoleAdmin).Count(&adminCount).Error; err != nil {
return err
}

if adminCount <= 1 {
return new(LastAdminError)
}
}

existingUser.Role = updatedUser.Role
}

return tx.Updates(existingUser).Error
})
}

func (c *Client) UpdateProfileIconIfNeeded(ctx context.Context, user *types.User, authProviderID uint) error {
if authProviderID == 0 {
return nil
Expand Down
112 changes: 26 additions & 86 deletions pkg/gateway/server/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/gptscript-ai/gptscript/pkg/mvl"
types2 "github.com/obot-platform/obot/apiclient/types"
"github.com/obot-platform/obot/pkg/api"
"github.com/obot-platform/obot/pkg/gateway/client"
"github.com/obot-platform/obot/pkg/gateway/types"
"gorm.io/gorm"
)
Expand All @@ -18,16 +19,8 @@ var pkgLog = mvl.Package()
func (s *Server) getCurrentUser(apiContext api.Context) error {
user, err := s.client.User(apiContext.Context(), apiContext.User.GetName())
if errors.Is(err, gorm.ErrRecordNotFound) {
// The only reason this would happen is if auth is turned off.
role := types2.RoleBasic
if apiContext.UserIsAdmin() {
role = types2.RoleAdmin
}
return apiContext.Write(types2.User{
Username: apiContext.User.GetName(),
Role: role,
Timezone: apiContext.UserTimezone(),
})
// This shouldn't happen, but, if it does, then the user would be unauthorized because we can't identify them.
return types2.NewErrHttp(http.StatusUnauthorized, "unauthorized")
} else if err != nil {
return err
}
Expand All @@ -40,10 +33,8 @@ func (s *Server) getCurrentUser(apiContext api.Context) error {
}

func (s *Server) getUsers(apiContext api.Context) error {
userQuery := types.NewUserQuery(apiContext.URL.Query())

var users []types.User
if err := s.db.WithContext(apiContext.Context()).Scopes(userQuery.Scope).Find(&users).Error; err != nil {
users, err := s.client.Users(apiContext.Context(), types.NewUserQuery(apiContext.URL.Query()))
if err != nil {
return fmt.Errorf("failed to get users: %v", err)
}

Expand All @@ -61,8 +52,8 @@ func (s *Server) getUser(apiContext api.Context) error {
return types2.NewErrHttp(http.StatusBadRequest, "username path parameter is required")
}

user := new(types.User)
if err := s.db.WithContext(apiContext.Context()).Where("username = ?", username).First(user).Error; err != nil {
user, err := s.client.User(apiContext.Context(), username)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return types2.NewErrNotFound("user %s not found", username)
}
Expand All @@ -74,14 +65,14 @@ func (s *Server) getUser(apiContext api.Context) error {

func (s *Server) updateUser(apiContext api.Context) error {
requestingUsername := apiContext.User.GetName()
userIsAdmin := apiContext.UserIsAdmin()
actingUserIsAdmin := apiContext.UserIsAdmin()

username := apiContext.PathValue("username")
if username == "" {
return types2.NewErrHttp(http.StatusBadRequest, "username path parameter is required")
}

if !userIsAdmin && requestingUsername != username {
if !actingUserIsAdmin && requestingUsername != username {
return types2.NewErrHttp(http.StatusForbidden, "only admins can update other users")
}

Expand All @@ -90,53 +81,22 @@ func (s *Server) updateUser(apiContext api.Context) error {
return types2.NewErrHttp(http.StatusBadRequest, "invalid user request body")
}

existingUser := new(types.User)
status := http.StatusInternalServerError
if err := s.db.WithContext(apiContext.Context()).Transaction(func(tx *gorm.DB) error {
if err := tx.Where("username = ?", username).First(existingUser).Error; err != nil {
return err
}

// If the username is being changed, then ensure that a user with that name doesn't already exist.
if user.Username != "" && user.Username != username {
if err := tx.Model(user).Where("username = ?", user.Username).First(new(types.User)).Error; err == nil {
status = http.StatusConflict
return fmt.Errorf("user with username %q already exists", user.Username)
} else if !errors.Is(err, gorm.ErrRecordNotFound) {
return err
}

existingUser.Username = user.Username
}

// Anyone can update their timezone
if user.Timezone != "" && user.Timezone != existingUser.Timezone {
if _, err := time.LoadLocation(user.Timezone); err != nil {
return types2.NewErrHttp(http.StatusBadRequest, "invalid timezone")
}
existingUser.Timezone = user.Timezone
if user.Timezone != "" {
if _, err := time.LoadLocation(user.Timezone); err != nil {
return types2.NewErrHttp(http.StatusBadRequest, "invalid timezone")
}
}

// Only admins can change user roles.
if userIsAdmin {
// If the role is being changed from admin to non-admin, then ensure that this isn't the last admin.
if user.Role > 0 && existingUser.Role.HasRole(types2.RoleAdmin) && !user.Role.HasRole(types2.RoleAdmin) {
var adminCount int64
if err := tx.Model(new(types.User)).Count(&adminCount).Error; err != nil {
return err
}

if adminCount <= 1 {
status = http.StatusBadRequest
return fmt.Errorf("cannot remove last admin")
}
}

existingUser.Role = user.Role
status := http.StatusInternalServerError
existingUser, err := s.client.UpdateUser(apiContext.Context(), actingUserIsAdmin, user, username)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
status = http.StatusNotFound
} else if lae := (*client.LastAdminError)(nil); errors.As(err, &lae) {
status = http.StatusBadRequest
} else if ae := (*client.AlreadyExistsError)(nil); errors.As(err, &ae) {
status = http.StatusConflict
}

return tx.Updates(existingUser).Error
}); err != nil {
return types2.NewErrHttp(status, fmt.Sprintf("failed to update user: %v", err))
}

Expand All @@ -149,33 +109,13 @@ func (s *Server) deleteUser(apiContext api.Context) error {
return types2.NewErrHttp(http.StatusBadRequest, "username path parameter is required")
}

existingUser := new(types.User)
status := http.StatusInternalServerError
if err := s.db.WithContext(apiContext.Context()).Transaction(func(tx *gorm.DB) error {
if err := tx.Where("username = ?", username).First(existingUser).Error; err != nil {
return err
}

if existingUser.Role.HasRole(types2.RoleAdmin) {
var adminCount int64
if err := tx.Model(new(types.User)).Count(&adminCount).Error; err != nil {
return err
}

if adminCount <= 1 {
status = http.StatusBadRequest
return fmt.Errorf("cannot remove last admin")
}
}

if err := tx.Where("user_id = ?", existingUser.ID).Delete(new(types.Identity)).Error; err != nil {
return err
}

return tx.Delete(existingUser).Error
}); err != nil {
existingUser, err := s.client.DeleteUser(apiContext.Context(), username)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
status = http.StatusNotFound
} else if lae := (*client.LastAdminError)(nil); errors.As(err, &lae) {
status = http.StatusBadRequest
}
return types2.NewErrHttp(status, fmt.Sprintf("failed to delete user: %v", err))
}
Expand Down

0 comments on commit c1f35a6

Please sign in to comment.