Skip to content

Commit

Permalink
Merge branch 'main' into vllm-model-provider
Browse files Browse the repository at this point in the history
  • Loading branch information
sanjay920 authored Dec 20, 2024
2 parents 4b90d8e + 4856ebb commit 860267e
Show file tree
Hide file tree
Showing 106 changed files with 2,943 additions and 5,178 deletions.
3 changes: 3 additions & 0 deletions apiclient/types/oauthapp.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ const (
OAuthAppTypeHubSpot OAuthAppType = "hubspot"
OAuthAppTypeGitHub OAuthAppType = "github"
OAuthAppTypeGoogle OAuthAppType = "google"
OAuthAppTypeSalesforce OAuthAppType = "salesforce"
OAuthAppTypeCustom OAuthAppType = "custom"
)

Expand Down Expand Up @@ -36,6 +37,8 @@ type OAuthAppManifest struct {
Integration string `json:"integration,omitempty"`
// Global indicates if the OAuth app is globally applied to all agents.
Global *bool `json:"global,omitempty"`
// This field is only used by Salesforce
InstanceURL string `json:"instanceURL,omitempty"`
}

type OAuthAppList List[OAuthApp]
Expand Down
1 change: 1 addition & 0 deletions apiclient/types/tasks.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ type TaskRun struct {
Task TaskManifest `json:"task,omitempty"`
StartTime *Time `json:"startTime,omitempty"`
EndTime *Time `json:"endTime,omitempty"`
Error string `json:"error,omitempty"`
}

type TaskRunList List[TaskRun]
22 changes: 11 additions & 11 deletions docs/docs/05-configuration/03-auth-providers.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@ The following environment variables are required for all authentication provider
Setting the Client ID and Client Secret will mean that the authentication provider is enabled.
The remaining configuration will be validated on startup.

- `OBOT_AUTH_CLIENT_ID`: The client ID of the authentication provider.
- `OBOT_AUTH_CLIENT_SECRET`: The client secret of the authentication provider.
- `OBOT_AUTH_COOKIE_SECRET`: The secret used to encrypt the authentication cookie. Must be of size 16, 24, or 32 bytes.
- `OBOT_AUTH_ADMIN_EMAILS`: A comma-separated list of the email addresses of the admin users.
- `OBOT_SERVER_AUTH_CLIENT_ID`: The client ID of the authentication provider.
- `OBOT_SERVER_AUTH_CLIENT_SECRET`: The client secret of the authentication provider.
- `OBOT_SERVER_AUTH_COOKIE_SECRET`: The secret used to encrypt the authentication cookie. Must be of size 16, 24, or 32 bytes.
- `OBOT_SERVER_AUTH_ADMIN_EMAILS`: A comma-separated list of the email addresses of the admin users.

The following environment variables are optional for all authentication providers:
- `OBOT_AUTH_EMAIL_DOMAINS`: A comma-separated list of email domains allowed for authentication. Ignored if not set.
- `OBOT_AUTH_CONFIG_TYPE`: The type of the authentication provider. For example, `google` or `github`. Defaults to `google`.
- `OBOT_SERVER_AUTH_EMAIL_DOMAINS`: A comma-separated list of email domains allowed for authentication. Ignored if not set.
- `OBOT_SERVER_AUTH_CONFIG_TYPE`: The type of the authentication provider. For example, `google` or `github`. Defaults to `google`.

## Google

Expand All @@ -25,8 +25,8 @@ Google is the default authentication provider. There are currently no additional

GitHub authentication has the following optional configuration:

- `OBOT_AUTH_GITHUB_ORG`: The name of the organization allowed for authentication. Ignored if not set.
- `OBOT_AUTH_GITHUB_TEAM`: The name of the team allowed for authentication. Ignored if not set.
- `OBOT_AUTH_GITHUB_REPO`: Restrict logins to collaborators of this repository formatted as `orgname/repo`. Ignored if not set.
- `OBOT_AUTH_GITHUB_TOKEN`: The token to use when verifying repository collaborators (must have push access to the repository).
- `OBOT_AUTH_GITHUB_ALLOW_USERS`: A comma-separated list of users allowed to log in even if they don't belong to the organization or team.
- `OBOT_SERVER_AUTH_GITHUB_ORG`: The name of the organization allowed for authentication. Ignored if not set.
- `OBOT_SERVER_AUTH_GITHUB_TEAM`: The name of the team allowed for authentication. Ignored if not set.
- `OBOT_SERVER_AUTH_GITHUB_REPO`: Restrict logins to collaborators of this repository formatted as `orgname/repo`. Ignored if not set.
- `OBOT_SERVER_AUTH_GITHUB_TOKEN`: The token to use when verifying repository collaborators (must have push access to the repository).
- `OBOT_SERVER_AUTH_GITHUB_ALLOW_USERS`: A comma-separated list of users allowed to log in even if they don't belong to the organization or team.
16 changes: 12 additions & 4 deletions pkg/api/authz/authz.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"slices"

"k8s.io/apiserver/pkg/authentication/user"
kclient "sigs.k8s.io/controller-runtime/pkg/client"
)

const (
Expand Down Expand Up @@ -52,16 +53,19 @@ var staticRules = map[string][]string{
"POST /api/llm-proxy/",
"POST /api/prompt",
"GET /api/models",
"GET /api/version",
},
}

type Authorizer struct {
rules []rule
rules []rule
storage kclient.Client
}

func NewAuthorizer() *Authorizer {
func NewAuthorizer(storage kclient.Client) *Authorizer {
return &Authorizer{
rules: defaultRules(),
rules: defaultRules(),
storage: storage,
}
}

Expand All @@ -75,7 +79,11 @@ func (a *Authorizer) Authorize(req *http.Request, user user.Info) bool {
}
}

return authorizeThread(req, user)
if authorizeThread(req, user) {
return true
}

return a.authorizeThreadFileDownload(req, user)
}

type rule struct {
Expand Down
55 changes: 55 additions & 0 deletions pkg/api/authz/thread.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ import (
"strings"

"github.com/gptscript-ai/gptscript/pkg/types"
"github.com/obot-platform/nah/pkg/router"
v1 "github.com/obot-platform/obot/pkg/storage/apis/otto.otto8.ai/v1"
"github.com/obot-platform/obot/pkg/system"
"k8s.io/apiserver/pkg/authentication/user"
)

Expand All @@ -23,3 +26,55 @@ func authorizeThread(req *http.Request, user user.Info) bool {

return false
}

func (a *Authorizer) authorizeThreadFileDownload(req *http.Request, user user.Info) bool {
if req.Method != http.MethodGet {
return false
}

if !strings.HasPrefix(req.URL.Path, "/api/threads/") {
return false
}

parts := strings.Split(req.URL.Path, "/")
if len(parts) < 6 {
return false
}
if parts[0] != "" ||
parts[1] != "api" ||
parts[2] != "threads" ||
parts[4] != "file" {
return false
}

var (
id = parts[3]
thread v1.Thread
)
if err := a.storage.Get(req.Context(), router.Key(system.DefaultNamespace, id), &thread); err != nil {
return false
}

if thread.Spec.UserUID == user.GetUID() {
return true
}

if thread.Spec.WorkflowName == "" {
return false
}

var workflow v1.Workflow
if err := a.storage.Get(req.Context(), router.Key(thread.Namespace, thread.Spec.WorkflowName), &workflow); err != nil {
return false
}

if workflow.Spec.ThreadName == "" {
return false
}

if err := a.storage.Get(req.Context(), router.Key(system.DefaultNamespace, workflow.Spec.ThreadName), &thread); err != nil {
return false
}

return thread.Spec.UserUID == user.GetUID()
}
14 changes: 3 additions & 11 deletions pkg/api/handlers/assistants.go
Original file line number Diff line number Diff line change
Expand Up @@ -281,29 +281,21 @@ func (a *AssistantHandler) GetFile(req api.Context) error {
}

func (a *AssistantHandler) UploadFile(req api.Context) error {
var (
id = req.PathValue("id")
)

thread, err := getUserThread(req, id)
thread, err := getThreadForScope(req)
if err != nil {
return err
}

if thread.Status.WorkspaceID == "" {
return types.NewErrHttp(http.StatusTooEarly, fmt.Sprintf("no workspace found for assistant %s", id))
return types.NewErrNotFound("no workspace found")
}

_, err = uploadFileToWorkspace(req.Context(), req, a.gptScript, thread.Status.WorkspaceID, "files/")
return err
}

func (a *AssistantHandler) DeleteFile(req api.Context) error {
var (
id = req.PathValue("id")
)

thread, err := getUserThread(req, id)
thread, err := getThreadForScope(req)
if err != nil {
return err
}
Expand Down
10 changes: 6 additions & 4 deletions pkg/api/handlers/files.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package handlers
import (
"context"
"fmt"
"io"
"net/http"
"path/filepath"
"strings"
Expand Down Expand Up @@ -90,7 +89,10 @@ func listKnowledgeFiles(req api.Context, agentName, threadName, knowledgeSetName
func uploadKnowledgeToWorkspace(req api.Context, gClient *gptscript.GPTScript, ws *v1.Workspace, agentName, threadName, knowledgeSetName string) error {
filename := req.PathValue("file")

size, err := uploadFileToWorkspace(req.Context(), req, gClient, ws.Status.WorkspaceID, "")
size, err := uploadFileToWorkspace(req.Context(), req, gClient, ws.Status.WorkspaceID, "", api.BodyOptions{
// 100MB
MaxBytes: 100 * 1024 * 1024,
})
if err != nil {
return err
}
Expand Down Expand Up @@ -180,13 +182,13 @@ func getFileInWorkspace(ctx context.Context, req api.Context, gClient *gptscript
return err
}

func uploadFileToWorkspace(ctx context.Context, req api.Context, gClient *gptscript.GPTScript, workspaceID, prefix string) (int, error) {
func uploadFileToWorkspace(ctx context.Context, req api.Context, gClient *gptscript.GPTScript, workspaceID, prefix string, opts ...api.BodyOptions) (int, error) {
file := req.PathValue("file")
if file == "" {
return 0, fmt.Errorf("file path parameter is required")
}

contents, err := io.ReadAll(req.Request.Body)
contents, err := req.Body(opts...)
if err != nil {
return 0, fmt.Errorf("failed to read request body: %w", err)
}
Expand Down
11 changes: 7 additions & 4 deletions pkg/api/handlers/tasks.go
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,7 @@ func (t *TaskHandler) Run(req api.Context) error {
WorkflowExecutionName: editorWFE(req, workflow.Name),
OwningThreadName: userThread.Name,
StepID: stepID,
ThreadCredentialScope: new(bool),
})
if err != nil {
return err
Expand All @@ -288,6 +289,7 @@ func convertTaskRun(workflow *v1.Workflow, wfe *v1.WorkflowExecution) types.Task
Task: convertTaskManifest(wfe.Status.WorkflowManifest),
StartTime: types.NewTime(wfe.CreationTimestamp.Time),
EndTime: endTime,
Error: wfe.Status.Error,
}
}

Expand Down Expand Up @@ -548,10 +550,11 @@ func (t *TaskHandler) Create(req api.Context) error {
Namespace: req.Namespace(),
},
Spec: v1.WorkflowSpec{
ThreadName: thread.Name,
Manifest: workflowManifest,
KnowledgeSetNames: thread.Status.KnowledgeSetNames,
WorkspaceName: workspace.Name,
ThreadName: thread.Name,
Manifest: workflowManifest,
KnowledgeSetNames: thread.Status.KnowledgeSetNames,
WorkspaceName: workspace.Name,
CredentialContextID: thread.Name,
},
}

Expand Down
24 changes: 22 additions & 2 deletions pkg/api/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package api

import (
"encoding/json"
"errors"
"io"
"net/http"
"slices"
Expand Down Expand Up @@ -95,8 +96,27 @@ func (r *Context) Read(obj any) error {
return json.Unmarshal(data, obj)
}

func (r *Context) Body() ([]byte, error) {
return io.ReadAll(io.LimitReader(r.Request.Body, 1<<20))
type BodyOptions struct {
MaxBytes int64
}

func (r *Context) Body(opts ...BodyOptions) (_ []byte, err error) {
defer func() {
if maxErr := (*http.MaxBytesError)(nil); errors.As(err, &maxErr) {
err = types.NewErrHttp(http.StatusRequestEntityTooLarge, "request body too large")
}
_, _ = io.Copy(io.Discard, r.Request.Body)
}()
var opt BodyOptions
for _, o := range opts {
if o.MaxBytes > 0 {
opt.MaxBytes = o.MaxBytes
}
}
if opt.MaxBytes == 0 {
opt.MaxBytes = 8 * 1024 * 1024
}
return io.ReadAll(http.MaxBytesReader(r.ResponseWriter, r.Request.Body, opt.MaxBytes))
}

func (r *Context) WriteCreated(obj any) error {
Expand Down
6 changes: 4 additions & 2 deletions pkg/api/router/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ func Router(services *services.Services) (http.Handler, error) {
// Assistant files
mux.HandleFunc("GET /api/assistants/{assistant_id}/files", assistants.Files)
mux.HandleFunc("GET /api/assistants/{assistant_id}/file/{file...}", assistants.GetFile)
mux.HandleFunc("POST /api/assistants/{id}/files/{file...}", assistants.UploadFile)
mux.HandleFunc("DELETE /api/assistants/{id}/files/{file...}", assistants.DeleteFile)
mux.HandleFunc("POST /api/assistants/{assistant_id}/file/{file...}", assistants.UploadFile)
mux.HandleFunc("DELETE /api/assistants/{assistant_id}/files/{file...}", assistants.DeleteFile)
// Assistant knowledge files
mux.HandleFunc("GET /api/assistants/{id}/knowledge", assistants.Knowledge)
mux.HandleFunc("POST /api/assistants/{id}/knowledge/{file}", assistants.UploadKnowledge)
Expand All @@ -86,6 +86,8 @@ func Router(services *services.Services) (http.Handler, error) {
mux.HandleFunc("DELETE /api/assistants/{assistant_id}/tasks/{id}/runs/{run_id}", tasks.DeleteRun)
mux.HandleFunc("GET /api/assistants/{assistant_id}/tasks/{task_id}/runs/{run_id}/files", assistants.Files)
mux.HandleFunc("GET /api/assistants/{assistant_id}/tasks/{task_id}/runs/{run_id}/file/{file...}", assistants.GetFile)
mux.HandleFunc("POST /api/assistants/{assistant_id}/tasks/{task_id}/runs/{run_id}/file/{file...}", assistants.UploadFile)
mux.HandleFunc("DELETE /api/assistants/{assistant_id}/tasks/{task_id}/runs/{run_id}/files/{file...}", assistants.DeleteFile)
mux.HandleFunc("GET /api/assistants/{assistant_id}/tasks/{id}/events", tasks.Events)
mux.HandleFunc("POST /api/assistants/{assistant_id}/tasks/{id}/events", tasks.Abort)
mux.HandleFunc("GET /api/assistants/{assistant_id}/tasks/{id}/runs/{run_id}/events", tasks.Events)
Expand Down
25 changes: 25 additions & 0 deletions pkg/gateway/server/oauth_apps.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"io"
"net/http"
"net/url"
"strconv"
"strings"
"time"

Expand Down Expand Up @@ -445,6 +446,30 @@ func (s *Server) callbackOAuthApp(apiContext api.Context) error {
CreatedAt: time.Now(),
RefreshToken: googleTokenResp.RefreshToken,
}
case types2.OAuthAppTypeSalesforce:
salesforceTokenResp := new(types.SalesforceOAuthTokenResponse)
if err := json.NewDecoder(resp.Body).Decode(salesforceTokenResp); err != nil {
return fmt.Errorf("failed to parse token response: %w", err)
}
issuedAt, err := strconv.ParseInt(salesforceTokenResp.IssuedAt, 10, 64)
if err != nil {
return fmt.Errorf("failed to parse token response: %w", err)
}
createdAt := time.Unix(issuedAt/1000, (issuedAt%1000)*1000000)

tokenResp = &types.OAuthTokenResponse{
State: state,
TokenType: salesforceTokenResp.TokenType,
Scope: salesforceTokenResp.Scope,
AccessToken: salesforceTokenResp.AccessToken,
ExpiresIn: 7200, // Relies on Salesforce admin not overriding the default 2 hours
Ok: true, // Assuming true if no error is present
CreatedAt: createdAt,
RefreshToken: salesforceTokenResp.RefreshToken,
Extras: map[string]string{
"GPTSCRIPT_SALESFORCE_URL": salesforceTokenResp.InstanceURL,
},
}
default:
if err := json.NewDecoder(resp.Body).Decode(tokenResp); err != nil {
return fmt.Errorf("failed to parse token response: %w", err)
Expand Down
Loading

0 comments on commit 860267e

Please sign in to comment.