#!/usr/bin/env python3
"""
@author: Naman Gala
Script to define the data pipeline as Airflow DAG that performs ETL (Extract Load Transform) tasks such as
scraping tweets from twitter, labelling, cleaning, normalizing and preprocessing the raw data to be used
for analysis and model training on scheduled interval.
"""
import os
import json
import sys
from datetime import datetime
from airflow.decorators import task, dag
from airflow.utils.task_group import TaskGroup
from airflow.operators.python import PythonOperator
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook
from snowflake.connector.pandas_tools import write_pandas
from airflow.models.connection import Connection
from task_definitions.etl_task_definitions import scrap_raw_tweets_from_web, preprocess_tweets
from task_definitions.etl_task_definitions import add_sentiment_labels_to_tweets
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from utils.helper import Config, Connections
from utils.helper import load_dataframe
# Load all configurations from config.toml
config = Config()
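# A minimal sketch of the assumed config.toml layout, limited to the keys this DAG
# actually reads; every value shown is an illustrative placeholder, not the
# project's real setting:
#
#     [aws]
#     connection_id = "s3_connection"
#     s3_bucket_name = "<your-bucket>"
#     temp_data_path = "/tmp/data"
#
#     [tweets-scraping]
#     search_query = "<query>"
#     tweet_limit = 1000
#
#     [files]
#     raw_file_name = "<raw-tweets-file>"
#     labelled_file_name = "<labelled-tweets-file>"
#     preprocessed_file_name = "<preprocessed-tweets-file>"
#
#     [misc]
#     table_name = "<warehouse-table>"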
@dag(dag_id = "etl", start_date = datetime(2023,1,1), schedule_interval = "@monthly", catchup = False)
def twitter_data_pipeline_dag_etl() -> None:
    """
    Data pipeline that performs the ETL tasks producing the data used for analysis and model training.

    Returns
    -------
    None
    """
    @task(task_id = "configure_connections")
    def set_connections() -> None:
        """
        Task 1 => Configure and establish connections to external services such as
        AWS S3 buckets and the Snowflake data warehouse. The credentials are stored as Docker secrets
        in the respective containers and accessed as environment variables, which
        prevents them from leaking into the Docker image or the repository.

        Note:
            AWS credentials are generated using specific IAM users and roles.

        Returns
        -------
        None
        """
        # AWS S3 connection
        aws_access_key_id = os.environ["AWS_ACCESS_KEY_ID"]
        aws_secret_access_key = os.environ["AWS_SECRET_ACCESS_KEY"]
        aws_region_name = os.environ["REGION"]
        s3_credentials = json.dumps(
            dict(
                aws_access_key_id = aws_access_key_id,
                aws_secret_access_key = aws_secret_access_key,
                aws_region_name = aws_region_name,
            )
        )
        s3_connection = Connection(conn_id = "s3_connection",
                                   conn_type = "S3",
                                   extra = s3_credentials
                                   )
        s3_conn_response = Connections(s3_connection).create_connections()

        # Snowflake connection
        login = os.environ["LOGIN"]
        password = os.environ["PASSWORD"]
        host_name = os.environ["HOST"]
        snowflake_connection = Connection(conn_id = "snowflake_conn",
                                          conn_type = "Snowflake",
                                          host = host_name,
                                          login = login,
                                          password = password
                                          )
        snowflake_conn_response = Connections(snowflake_connection).create_connections()
        if not (s3_conn_response and snowflake_conn_response):
            print("Connection not established!!")
    # Instantiating an S3 hook for the respective tasks
    s3_hook = S3Hook(aws_conn_id = config["aws"]["connection_id"])

    # Task 2 => Refer to the respective task definition for documentation
    scrap_raw_tweets_from_web_ = PythonOperator(
        task_id = "scrap_raw_tweets_from_web",
        python_callable = scrap_raw_tweets_from_web,
        op_kwargs = {
            's3_hook': s3_hook,
            'bucket_name': config["aws"]["s3_bucket_name"],
            'search_query': config["tweets-scraping"]["search_query"],
            'tweet_limit': config["tweets-scraping"]["tweet_limit"],
            'raw_file_name': config["files"]["raw_file_name"]
        }
    )
    @task(task_id = "download_from_s3")
    def download_data_from_s3_bucket(temp_data_path: str, file_name: str) -> None:
        """
        Task 3 => Download data stored in the S3 bucket for further usage.

        Parameters
        ----------
        temp_data_path: str
            Path to save the downloaded file.
        file_name: str
            Name of the downloaded file.

        Returns
        -------
        None
        """
        # Using the S3 hook built from the connection configured in task 1
        downloaded_file = s3_hook.download_file(
            key = file_name,
            bucket_name = config["aws"]["s3_bucket_name"],
            local_path = temp_data_path
        )
        # The hook saves the file under a temporary name, so rename it to the expected file name
        os.rename(src = downloaded_file, dst = f"{temp_data_path}/{file_name}")
    with TaskGroup(group_id = "sentiment_labelling") as group1:
        # Task 4 => Refer to the respective task definition for documentation
        add_sentiment_labels_to_scrapped_tweets_ = PythonOperator(
            task_id = "add_sentiment_labels_to_scrapped_tweets",
            python_callable = add_sentiment_labels_to_tweets,
            op_kwargs = {
                's3_hook': s3_hook,
                'bucket_name': config["aws"]["s3_bucket_name"],
                'temp_data_path': config["aws"]["temp_data_path"],
                'raw_file_name': config["files"]["raw_file_name"],
                'labelled_file_name': config["files"]["labelled_file_name"],
            }
        )
        # Ordering the downstream tasks within task group 1
        download_data_from_s3_bucket(config["aws"]["temp_data_path"], config["files"]["raw_file_name"]) >> add_sentiment_labels_to_scrapped_tweets_
    with TaskGroup(group_id = "preprocess_tweets_using_NLP") as group2:
        # Task 5 => Refer to the respective task definition for documentation
        preprocess_tweets_ = PythonOperator(
            task_id = "preprocess_labelled_tweets_using_nlp_techniques",
            python_callable = preprocess_tweets,
            op_kwargs = {
                's3_hook': s3_hook,
                'bucket_name': config["aws"]["s3_bucket_name"],
                'temp_data_path': config["aws"]["temp_data_path"],
                'labelled_file_name': config["files"]["labelled_file_name"],
                'preprocessed_file_name': config["files"]["preprocessed_file_name"]
            }
        )
        # Ordering the downstream tasks within task group 2
        download_data_from_s3_bucket(config["aws"]["temp_data_path"], config["files"]["labelled_file_name"]) >> preprocess_tweets_
    @task(task_id = "load_processed_data_to_datawarehouse")
    def load_processed_data_to_snowflake(processed_file: str, table_name: str) -> None:
        """
        Task 6 => Load the final processed data into the Snowflake data warehouse: the processed
        parquet file is read into a dataframe and written to a database table in the warehouse.

        Parameters
        ----------
        processed_file: str
            Name of the preprocessed parquet file.
        table_name: str
            Name of the database table in the Snowflake data warehouse.

        Returns
        -------
        None
        """
        snowflake_conn = None
        try:
            # Similar to the S3 hook, a Snowflake hook is used for this task
            snowflake_hook = SnowflakeHook(
                snowflake_conn_id = "snowflake_conn",
                account = os.environ["ACCOUNT"],
                warehouse = os.environ["WAREHOUSE"],
                database = os.environ["DATABASE"],
                schema = os.environ["SCHEMA"],
                role = os.environ["ROLE"]
            )
            # write_pandas() needs a raw snowflake-connector connection, not the hook itself
            snowflake_conn = snowflake_hook.get_conn()
            dataframe = load_dataframe(processed_file)
            # Functionality to write any pandas dataframe into snowflake
            write_pandas(
                conn = snowflake_conn,
                df = dataframe,
                table_name = table_name,
                quote_identifiers = False
            )
        except Exception as exc:
            raise ConnectionError("Something went wrong with the snowflake connection. Please check it!!") from exc
        finally:
            # Close the connection only if it was actually opened
            if snowflake_conn is not None:
                snowflake_conn.close()
    # Ordering the downstream tasks for the entire DAG
    set_connections() >> scrap_raw_tweets_from_web_ >> group1 >> group2 >> load_processed_data_to_snowflake(config["files"]["preprocessed_file_name"], config["misc"]["table_name"])
etl_dag = twitter_data_pipeline_dag_etl()
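
# A minimal sketch for exercising the DAG outside the scheduler, assuming Airflow >= 2.5
# (dag.test() runs all tasks in a single process) and that the required credentials are
# already available as environment variables.
if __name__ == "__main__":
    etl_dag.test()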