Skip to content

Commit

Permalink
Merge pull request #124 from mindflowai/model-org
Browse files Browse the repository at this point in the history
Improve ConfiguredModel and MindFlowModel classes to be more efficient, generic, and extensible.
  • Loading branch information
steegecs authored Jun 20, 2023
2 parents 5a32999 + ff9741d commit 9e8f139
Show file tree
Hide file tree
Showing 9 changed files with 82 additions and 126 deletions.
2 changes: 1 addition & 1 deletion mindflow/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.5.1"
__version__ = "0.5.2"
8 changes: 0 additions & 8 deletions mindflow/core/types/definitions/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,11 @@


class ConversationParameterKey(Enum):
    """
    Keys used to address the stored fields of a conversation record.

    NOTE(review): inferred from member names only — ID / MESSAGES /
    TOTAL_TOKENS look like the serialization keys of a persisted
    conversation; confirm against the store that reads and writes them.
    """

    # NOTE: the ``: str`` annotations describe the member *values*;
    # the members themselves are ConversationParameterKey instances.
    ID: str = "id"
    MESSAGES: str = "messages"
    TOTAL_TOKENS: str = "total_tokens"


class ConversationID(Enum):
    """
    Well-known identifiers of the built-in conversation threads.

    NOTE(review): presumably one persistent conversation per top-level
    feature (chat, code generation) — confirm with the callers that
    look these IDs up.
    """

    # NOTE: the ``: str`` annotations describe the member *values*;
    # the members themselves are ConversationID instances.
    CHAT_0: str = "chat_0"
    CODE_GEN_0: str = "code_gen_0"
4 changes: 0 additions & 4 deletions mindflow/core/types/definitions/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,6 @@


class DocumentType(Enum):
"""
Document type enum
"""

FILE: str = "file"
SLACK: str = "slack"
GITHUB: str = "github"
Expand Down
14 changes: 0 additions & 14 deletions mindflow/core/types/definitions/mind_flow_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,17 +67,3 @@ class MindFlowModelDescription(Enum):
MindFlowModelParameterKey.DESCRIPTION.value: MindFlowModelDescription.EMBEDDING.value,
},
}

# MindFlowModelUnion = Union[
# MindFlowModelID,
# MindFlowModelDefaults,
# MindFlowModelName,
# MindFlowModelType,
# MindFlowModelDescription,
# ]


# def get_mind_flow_model_static(
# static: Type[MindFlowModelUnion], key: MindFlowModelUnion
# ) -> MindFlowModelUnion:
# return static.__members__[key.name]
13 changes: 0 additions & 13 deletions mindflow/core/types/definitions/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,16 +282,3 @@ class ModelAnthropic(Enum):
ModelParameterKey.CONFIG_DESCRIPTION.value: ModelConfigDescription.TEXT_EMBEDDING_ADA_002.value,
},
}


# ModelUnion = Union[
# ModelID,
# ModelParameterKey,
# ModelName,
# ModelHardTokenLimit,
# ModelDescription,
# ]


# def get_model_static(static: Type[ModelUnion], key: ModelUnion) -> ModelUnion:
# return static.__members__[key.name]
37 changes: 0 additions & 37 deletions mindflow/core/types/definitions/object.py

This file was deleted.

16 changes: 0 additions & 16 deletions mindflow/core/types/definitions/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,19 +118,3 @@ class ServiceModel(Enum):
ServiceParameterKey.API_URL.value: ServiceURL.PINECONE.value,
},
}


# ServiceUnion = Union[
# ServiceID,
# ServiceParameterKey,
# ServiceConfigParameterKey,
# ServiceName,
# ServiceURL,
# ServiceModel,
# ServiceModelTypeTextEmbedding,
# ServiceModelTypeTextCompletion,
# ]


# def get_service_static(static: Type[ServiceUnion], key: ServiceUnion) -> ServiceUnion:
# return static.__members__[key.name]
26 changes: 21 additions & 5 deletions mindflow/core/types/mindflow_model.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
import sys
from typing import Dict
from typing import Dict, Generic, TypeVar, cast
from mindflow.core.types.definitions.model import ModelID
from mindflow.core.types.store_traits.static import StaticStore
from mindflow.core.types.store_traits.json import JsonStore

from mindflow.core.types.model import ConfiguredModel
from mindflow.core.types.model import (
ConfiguredModel,
ConfiguredOpenAIChatCompletionModel,
ConfiguredAnthropicChatCompletionModel,
ConfiguredOpenAITextEmbeddingModel,
)
from mindflow.core.types.service import ConfiguredServices
from mindflow.core.types.definitions.mind_flow_model import MindFlowModelID
from mindflow.core.types.definitions.service import (
Expand All @@ -24,11 +30,14 @@ class MindFlowModelConfig(JsonStore):
model: str


class ConfiguredMindFlowModel:
T = TypeVar("T", bound="ConfiguredModel")


class ConfiguredMindFlowModel(Generic[T]):
id: str # index, query, embedding
name: str
defaults: Dict[str, str]
model: ConfiguredModel
model: T

def __init__(self, mindflow_model_id: str, configured_services: ConfiguredServices):
self.id = mindflow_model_id
Expand All @@ -44,7 +53,14 @@ def __init__(self, mindflow_model_id: str, configured_services: ConfiguredServic
) is None:
model_id = self.get_default_model_id(mindflow_model_id, configured_services)

self.model = ConfiguredModel(model_id)
if model_id in [ModelID.GPT_3_5_TURBO.value, ModelID.GPT_4.value]:
self.model = cast(T, ConfiguredOpenAIChatCompletionModel(model_id))
elif model_id in [ModelID.CLAUDE_INSTANT_V1.value, ModelID.CLAUDE_V1.value]:
self.model = cast(T, ConfiguredAnthropicChatCompletionModel(model_id))
elif model_id == ModelID.TEXT_EMBEDDING_ADA_002.value:
self.model = cast(T, ConfiguredOpenAITextEmbeddingModel(model_id))
else:
raise Exception("Unsupported model: " + model_id)

def get_default_model_id(
self, mindflow_model_id: str, configured_services: ConfiguredServices
Expand Down
88 changes: 60 additions & 28 deletions mindflow/core/types/model.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from abc import ABC, abstractmethod
import time
from typing import Optional, Union

Expand All @@ -6,9 +7,6 @@

import numpy as np
from traitlets import Callable

from mindflow.core.types.definitions.model_type import ModelType

import tiktoken

from mindflow.core.types.store_traits.json import JsonStore
Expand Down Expand Up @@ -40,7 +38,7 @@ class ModelConfig(JsonStore):
soft_token_limit: int


class ConfiguredModel(Callable):
class ConfiguredModel(ABC, Callable):
id: str
name: str
service: str
Expand Down Expand Up @@ -80,7 +78,28 @@ def __init__(self, model_id: str):
except NameError:
pass

def openai_chat_completion(
@abstractmethod
def __call__(self, *args, **kwargs):
pass


class ConfiguredOpenAIChatCompletionModel(ConfiguredModel):
id: str
name: str
service: str
model_type: str

tokenizer: tiktoken.Encoding

hard_token_limit: int
token_cost: int
token_cost_unit: str

# Config
soft_token_limit: int
api_key: str

def __call__(
self,
messages: list,
temperature: float = 0.0,
Expand All @@ -106,7 +125,24 @@ def openai_chat_completion(

return ModelError(error_message)

def anthropic_chat_completion(

class ConfiguredAnthropicChatCompletionModel(ConfiguredModel):
id: str
name: str
service: str
model_type: str

tokenizer: tiktoken.Encoding

hard_token_limit: int
token_cost: int
token_cost_unit: str

# Config
soft_token_limit: int
api_key: str

def __call__(
self,
prompt: str,
temperature: float = 0.0,
Expand All @@ -131,7 +167,24 @@ def anthropic_chat_completion(

return ModelError(error_message)

def openai_embedding(self, text: str) -> Union[np.ndarray, ModelError]:

class ConfiguredOpenAITextEmbeddingModel(ConfiguredModel):
id: str
name: str
service: str
model_type: str

tokenizer: tiktoken.Encoding

hard_token_limit: int
token_cost: int
token_cost_unit: str

# Config
soft_token_limit: int
api_key: str

def __call__(self, text: str) -> Union[np.ndarray, ModelError]:
try_count = 0
error_message = ""
while try_count < 5:
Expand All @@ -146,24 +199,3 @@ def openai_embedding(self, text: str) -> Union[np.ndarray, ModelError]:
time.sleep(5)

return ModelError(error_message)

def __call__(self, prompt, *args, **kwargs):
service_model_mapping = {
(
ServiceID.OPENAI.value,
ModelType.TEXT_COMPLETION.value,
): self.openai_chat_completion,
(
ServiceID.OPENAI.value,
ModelType.TEXT_EMBEDDING.value,
): self.openai_embedding,
(
ServiceID.ANTHROPIC.value,
ModelType.TEXT_COMPLETION.value,
): self.anthropic_chat_completion,
}
if (
func := service_model_mapping.get((self.service, self.model_type))
) is not None:
return func(prompt, *args, **kwargs)
raise NotImplementedError(f"Service {self.service} not implemented.")

0 comments on commit 9e8f139

Please sign in to comment.