"""
Airflow DAG to load raw data, process it, split it, and store in database.
Author
------
Nicolas Rojas
"""
# imports
from datetime import datetime

import pandas as pd
from sklearn.model_selection import train_test_split

from airflow import DAG
from airflow.operators.python import PythonOperator
from airflow.providers.mysql.hooks.mysql import MySqlHook
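
# NOTE: this DAG assumes two preconfigured Airflow connections matching the
# conn ids used below: "raw_data" (source schema) and "clean_data" (target
# schema), e.g. created with `airflow connections add` or through the
# Connections UI.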
def check_table_exists(table_name: str):
    """Check whether a table exists in the clean_data database; create it if not.

    Parameters
    ----------
    table_name : str
        Name of the table to check.
    """
    # count tables with this name, restricted to the clean_data schema so a
    # same-named table in another schema does not mask a missing one
    query = (
        "SELECT COUNT(*) FROM information_schema.tables "
        f'WHERE table_name="{table_name}" AND table_schema="clean_data"'
    )
    mysql_hook = MySqlHook(mysql_conn_id="clean_data", schema="clean_data")
    connection = mysql_hook.get_conn()
    cursor = connection.cursor()
    cursor.execute(query)
    results = cursor.fetchall()
    # check whether table exists
    if results[0][0] == 0:
        # create table
        print("----- table does not exist, creating it")
        create_sql = f"""CREATE TABLE `{table_name}` (
            `age` SMALLINT,
            `anual_income` BIGINT,
            `credit_score` SMALLINT,
            `loan_amount` BIGINT,
            `loan_duration_years` TINYINT,
            `number_of_open_accounts` SMALLINT,
            `had_past_default` TINYINT,
            `loan_approval` TINYINT
        )"""
        mysql_hook.run(create_sql)
    else:
        # no need to create table
        print("----- table already exists")
    return "Table checked"
def store_data(dataframe: pd.DataFrame, table_name: str):
    """Store dataframe contents in the given table of the clean_data database.

    Parameters
    ----------
    dataframe : pd.DataFrame
        Dataframe to store in the database.
    table_name : str
        Name of the table in which to store the data.
    """
    check_table_exists(table_name)
    # insert every dataframe row into the sql table
    mysql_hook = MySqlHook(mysql_conn_id="clean_data", schema="clean_data")
    sql_column_names = ", ".join(f"`{name}`" for name in dataframe.columns)
    conn = mysql_hook.get_conn()
    cur = conn.cursor()
    # VALUES placeholders: one %s per dataframe column
    placeholders = ", ".join(["%s"] * dataframe.shape[1])
    query = f"INSERT INTO `{table_name}` ({sql_column_names}) VALUES ({placeholders})"
    rows = list(dataframe.itertuples(index=False, name=None))
    cur.executemany(query, rows)
    conn.commit()
    return "Data stored"
def preprocess_data():
    """Preprocess raw data and store it in the clean_data database."""
    # retrieve raw data
    mysql_hook = MySqlHook(mysql_conn_id="raw_data", schema="raw_data")
    conn = mysql_hook.get_conn()
    query = "SELECT * FROM `raw_clients`"
    dataframe = pd.read_sql(query, con=conn)
    # drop the surrogate key column, which carries no predictive signal
    dataframe.drop(columns=["id"], inplace=True)
    # fill empty fields
    dataframe.fillna(0, inplace=True)
    # split data: 70% train, 10% val, 20% test
    df_train, df_test = train_test_split(
        dataframe, test_size=0.2, shuffle=True, random_state=1337
    )
    # 0.125 of the remaining 80% equals 10% of the full dataset
    df_train, df_val = train_test_split(
        df_train, test_size=0.125, shuffle=True, random_state=1337
    )
    # store data partitions in database
    store_data(df_train, "clean_clients_train")
    store_data(df_val, "clean_clients_val")
    store_data(df_test, "clean_clients_test")
    return "Data preprocessed"

with DAG(
    "preprocess_data",
    description="Fetch raw data, preprocess it, and save it in a MySQL database",
    start_date=datetime(2024, 9, 18, 0, 2),
    schedule_interval="@once",
) as dag:
    preprocess_task = PythonOperator(
        task_id="preprocess_data", python_callable=preprocess_data
    )
    preprocess_task
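
# To exercise this DAG locally without a scheduler, one option (assuming a
# working Airflow installation and the two connections noted above) is:
#   airflow dags test preprocess_data 2024-09-18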