-
Notifications
You must be signed in to change notification settings - Fork 1
/
train.py
28 lines (25 loc) · 900 Bytes
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from src.database_trainer import DatabaseTrainer
import config
############################################################################################
# NOTE: To execute this code, you will need to create a file called config.py that defines
# the following constants (refer to the `sample_config.py` file for an example):
# - VANNA_API_KEY
# - MODEL
# - POSTGRES_HOST
# - POSTGRES_PORT
# - POSTGRES_USER
# - POSTGRES_PWD
# - POSTGRES_DB
# - DDL
############################################################################################
# Connection parameters handed to DatabaseTrainer, each read from the
# POSTGRES_* constants defined in config.py.
db_creds = dict(
    host=config.POSTGRES_HOST,
    port=config.POSTGRES_PORT,
    user=config.POSTGRES_USER,
    password=config.POSTGRES_PWD,
    dbname=config.POSTGRES_DB,
)

# Build the trainer from the Vanna API key, the configured model name and the
# credentials above, then train it on the DDL supplied by config.py.
database_trainer = DatabaseTrainer(
    api_key=config.VANNA_API_KEY,
    model=config.MODEL,
    db_creds=db_creds,
)
database_trainer.train(ddl=config.DDL)