Expose system message and bump version (#32)
* Make system message controllable

Signed-off-by: Aivin V. Solatorio <[email protected]>

* Bump version to v0.0.10

Signed-off-by: Aivin V. Solatorio <[email protected]>

---------

Signed-off-by: Aivin V. Solatorio <[email protected]>
avsolatorio authored Dec 2, 2024
1 parent cfbaba9 commit 448e6fe
Showing 3 changed files with 8 additions and 4 deletions.
2 changes: 1 addition & 1 deletion llm4data/__init__.py
@@ -7,7 +7,7 @@
 # Do this before importing any other modules.
 dotenv.load_dotenv()
 
-__version__ = "0.0.9"
+__version__ = "0.0.10"
 
 indicator2name = dict(
     wdi=json.load((Path(__file__).parent / "wdi2name.json").open("r"))
8 changes: 6 additions & 2 deletions llm4data/augmentation/microdata/theme_llm.py
@@ -87,7 +87,7 @@ class ThemeLLM(object):
     - an idno of a study in the form of a string. The data dictionary will be retrieved from a specified NADA catalog.
     """
 
-    def __init__(self, idno: str, llm_model_id: str = "gpt-3.5-turbo", data_dictionary: Union[str, Path, dict] = None, catalog_url: str = None, vars_dir: Union[str, Path] = None, desc_dir: Union[str, Path] = None, force: bool = False, persist: bool = True):
+    def __init__(self, idno: str, llm_model_id: str = "gpt-3.5-turbo", data_dictionary: Union[str, Path, dict] = None, catalog_url: str = None, vars_dir: Union[str, Path] = None, desc_dir: Union[str, Path] = None, force: bool = False, persist: bool = True, system_message: str = SYSTEM_MESSAGE):
         """
         Initialize the ThemeLLM object.
 
@@ -127,6 +127,7 @@ def __init__(self, idno: str, llm_model_id: str = "gpt-3.5-turbo", data_dictiona
 
         # Set the LLM model id.
         self.llm_model_id = llm_model_id
+        self.system_message = system_message
 
         # State variables
         self.clusters = None
@@ -251,7 +252,7 @@ def clustering(self, embeddings: np.ndarray, n_clusters: int = 20, n_components:
 
         return aggcl.fit_predict(tsvd.fit_transform(embeddings))
 
-    def generate_prompts(self, force: bool = False, system_message: str = SYSTEM_MESSAGE, max_input_tokens: int = 2500, system_num_tokens: int = 100, special_sep: str = SPECIAL_SEP):
+    def generate_prompts(self, force: bool = False, system_message: str = None, max_input_tokens: int = 2500, system_num_tokens: int = 100, special_sep: str = SPECIAL_SEP):
         """
         Generate the prompts for the microdata variables.
         """
@@ -261,6 +262,9 @@ def generate_prompts(self, force: bool = False, system_message: str = SYSTEM_MES
 
         idno_data = []
 
+        # Set the system message.
+        system_message = system_message or self.system_message
+
         for cluster in tqdm(self.clusters["cluster"].keys()):
             cluster_labels = self.clusters["cluster"][cluster]
             prompt = PromptZeros.build_message(
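
With this change, the system message can be set at two points: once on the ThemeLLM instance, and optionally again per call to generate_prompts(). A minimal usage sketch follows; the study idno, catalog availability, and message text are illustrative assumptions rather than part of this commit, and generate_prompts() still expects the upstream embedding/clustering steps to have populated self.clusters.

from llm4data.augmentation.microdata.theme_llm import ThemeLLM

# Instance-level override: the message passed here is stored as self.system_message
# and is used whenever generate_prompts() is called without an explicit system_message.
theme = ThemeLLM(
    idno="EXAMPLE-IDNO-0001",  # hypothetical study idno resolvable by the configured NADA catalog
    llm_model_id="gpt-3.5-turbo",
    system_message="You are an expert at assigning themes to survey variables.",
)

# Call-level override: a non-empty system_message passed here takes precedence for this call;
# otherwise the instance value is used (system_message or self.system_message).
theme.generate_prompts(system_message="Label each cluster of variables with a short theme.")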
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "llm4data"
-version = "0.0.9"
+version = "0.0.10"
 description = "LLM4Data is a Python library designed to facilitate the application of large language models (LLMs) and artificial intelligence for development data and knowledge discovery."
 authors = ["Aivin V. Solatorio <[email protected]>"]
 readme = "README.md"
