Skip to content

Commit

Permalink
enhance: remove models for model providers that have been deconfigured
Browse files Browse the repository at this point in the history
Signed-off-by: Donnie Adams <[email protected]>
  • Loading branch information
thedadams committed Dec 26, 2024
1 parent db96799 commit 9002e67
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 20 deletions.
11 changes: 10 additions & 1 deletion pkg/api/handlers/modelprovider.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,16 @@ func (mp *ModelProviderHandler) Deconfigure(req api.Context) error {
// Stop the model provider so that the credential is completely removed from the system.
mp.dispatcher.StopModelProvider(ref.Namespace, ref.Name)

return nil
if ref.Annotations[v1.ModelProviderSyncAnnotation] == "" {
if ref.Annotations == nil {
ref.Annotations = make(map[string]string, 1)
}
ref.Annotations[v1.ModelProviderSyncAnnotation] = "true"
} else {
delete(ref.Annotations, v1.ModelProviderSyncAnnotation)
}

return req.Update(&ref)
}

func (mp *ModelProviderHandler) Reveal(req api.Context) error {
Expand Down
44 changes: 25 additions & 19 deletions pkg/controller/handlers/toolreference/toolreference.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package toolreference
import (
"context"
"crypto/sha256"
"errors"
"fmt"
"net/url"
"os"
Expand Down Expand Up @@ -370,8 +371,8 @@ func (h *Handler) BackPopulateModels(req router.Request, _ router.Response) erro
cred, err := h.gptClient.RevealCredential(req.Ctx, []string{string(toolRef.UID)}, toolRef.Name)
if err != nil {
if strings.Contains(err.Error(), "credential not found") {
// Model provider is not configured, don't error
return nil
// Unable to find credential, ensure all models remove for this model provider
return removeModelsForProvider(req.Ctx, req.Client, req.Namespace, req.Name)
}
return err
}
Expand Down Expand Up @@ -420,6 +421,27 @@ func (h *Handler) BackPopulateModels(req router.Request, _ router.Response) erro
return nil
}

func removeModelsForProvider(ctx context.Context, c client.Client, namespace, name string) error {
var models v1.ModelList
if err := c.List(ctx, &models, &client.ListOptions{
Namespace: namespace,
FieldSelector: fields.SelectorFromSet(fields.Set{
"spec.manifest.modelProvider": name,
}),
}); err != nil {
return fmt.Errorf("failed to list models for model provider %q for cleanup: %w", name, err)
}

var errs []error
for _, model := range models.Items {
if err := client.IgnoreNotFound(c.Delete(ctx, &model)); err != nil {
errs = append(errs, fmt.Errorf("failed to delete model %q for cleanup: %w", model.Name, err))
}
}

return errors.Join(errs...)
}

func (h *Handler) CleanupModelProvider(req router.Request, _ router.Response) error {
toolRef := req.Object.(*v1.ToolReference)
if toolRef.Spec.Type != types.ToolReferenceTypeModelProvider || toolRef.Status.Tool == nil {
Expand All @@ -432,23 +454,7 @@ func (h *Handler) CleanupModelProvider(req router.Request, _ router.Response) er
}
}

var models v1.ModelList
if err := req.List(&models, &client.ListOptions{
Namespace: req.Namespace,
FieldSelector: fields.SelectorFromSet(fields.Set{
"spec.manifest.modelProvider": toolRef.Name,
}),
}); err != nil {
return fmt.Errorf("failed to list models for model provider %q for cleanup: %w", toolRef.Name, err)
}

for _, model := range models.Items {
if err := client.IgnoreNotFound(req.Delete(&model)); err != nil {
return fmt.Errorf("failed to delete model %q for cleanup: %w", model.Name, err)
}
}

return nil
return removeModelsForProvider(req.Ctx, req.Client, req.Namespace, req.Name)
}

func modelName(modelProviderName, modelName string) string {
Expand Down

0 comments on commit 9002e67

Please sign in to comment.