Skip to content

Commit

Permalink
Refine documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
kiroah committed Nov 16, 2024
1 parent 71f360f commit 452e44e
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 81 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
},
"source": [
"# Why You Can't Add All Features Into Causal Inference Even With ML\n",
"# Part 3 - Reverse causal\n",
"# Part 2 - Reverse causal\n",
"Hiro Naito"
]
},
Expand All @@ -22,7 +22,7 @@
},
{
"cell_type": "code",
"execution_count": 187,
"execution_count": 4,
"id": "7438faea-b9f1-4088-9903-69b44b03557a",
"metadata": {
"tags": []
Expand Down Expand Up @@ -59,16 +59,16 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 15,
"id": "8df78103-e2f2-4207-b884-e88f686a0274",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"def create_graph(causal_matrix: pd.DataFrame): \n",
"\"\"\" Helper function to create an nx graph given a pandas dataframe storing the cause (rows) and effect (in columns)\n",
"\"\"\"\n",
" \"\"\" Helper function to create an nx graph given a pandas dataframe storing the cause (rows) and effect (in columns)\n",
" \"\"\"\n",
" G = nx.DiGraph()\n",
" G.add_nodes_from(causal_matrix.columns.to_list())\n",
" for i, row in enumerate(causal_matrix.index):\n",
Expand All @@ -80,8 +80,8 @@
" return G\n",
"\n",
"def create_synth_data(G : nx.DiGraph, causal_matrix: pd.DataFrame, n: int, binaries: list[str] = []):\n",
"\"\"\" Helper function to create synthetic data given causal graph + nx graph. \n",
"\"\"\"\n",
" \"\"\" Helper function to create synthetic data given causal graph + nx graph. \n",
" \"\"\"\n",
" #Get list of nodes in topological order\n",
" topological_sorted = list(nx.topological_sort(G))\n",
"\n",
Expand All @@ -103,7 +103,7 @@
},
{
"cell_type": "code",
"execution_count": 25,
"execution_count": 17,
"id": "6c881ae2-ee10-4582-a4eb-d23a693ece9b",
"metadata": {
"tags": []
Expand Down Expand Up @@ -399,7 +399,7 @@
},
{
"cell_type": "code",
"execution_count": 203,
"execution_count": 19,
"id": "8b5c69aa-4f07-447e-a8ac-6ffa6ace826a",
"metadata": {
"tags": []
Expand Down Expand Up @@ -430,7 +430,7 @@
},
{
"cell_type": "code",
"execution_count": 220,
"execution_count": 21,
"id": "501c5193-fc4f-4967-ae67-4ab30eae2fde",
"metadata": {
"tags": []
Expand Down Expand Up @@ -477,7 +477,7 @@
},
{
"cell_type": "code",
"execution_count": 226,
"execution_count": 22,
"id": "73c1e247-f0ab-48b6-948e-5b49cdaab0a8",
"metadata": {
"tags": []
Expand Down Expand Up @@ -533,7 +533,7 @@
},
{
"cell_type": "code",
"execution_count": 207,
"execution_count": 24,
"id": "006a0d82-4303-42c5-983f-1c3678eecce9",
"metadata": {
"tags": []
Expand All @@ -548,17 +548,17 @@
"Dep. Variable: Y No. Observations: 100000\n",
"Model: GLM Df Residuals: 99998\n",
"Model Family: Gaussian Df Model: 1\n",
"Link Function: Identity Scale: 1.0008\n",
"Method: IRLS Log-Likelihood: -1.4193e+05\n",
"Date: Wed, 02 Oct 2024 Deviance: 1.0008e+05\n",
"Time: 09:07:23 Pearson chi2: 1.00e+05\n",
"No. Iterations: 3 Pseudo R-squ. (CS): 0.6310\n",
"Link Function: Identity Scale: 0.99526\n",
"Method: IRLS Log-Likelihood: -1.4166e+05\n",
"Date: Thu, 03 Oct 2024 Deviance: 99524.\n",
"Time: 12:11:12 Pearson chi2: 9.95e+04\n",
"No. Iterations: 3 Pseudo R-squ. (CS): 0.6347\n",
"Covariance Type: nonrobust \n",
"==============================================================================\n",
" coef std err z P>|z| [0.025 0.975]\n",
"------------------------------------------------------------------------------\n",
"Intercept 0.0026 0.004 0.579 0.563 -0.006 0.011\n",
"T 1.9978 0.006 315.756 0.000 1.985 2.010\n",
"Intercept 0.0003 0.004 0.075 0.941 -0.008 0.009\n",
"T 2.0023 0.006 317.347 0.000 1.990 2.015\n",
"==============================================================================\n"
]
}
Expand All @@ -574,7 +574,7 @@
},
{
"cell_type": "code",
"execution_count": 228,
"execution_count": 29,
"id": "b387abc0-b07e-40d4-a4ea-55745853488e",
"metadata": {
"tags": []
Expand All @@ -589,19 +589,19 @@
"Dep. Variable: Y No. Observations: 100000\n",
"Model: GLM Df Residuals: 99996\n",
"Model Family: Gaussian Df Model: 3\n",
"Link Function: Identity Scale: 0.047664\n",
"Method: IRLS Log-Likelihood: 10287.\n",
"Date: Wed, 02 Oct 2024 Deviance: 4766.2\n",
"Time: 09:22:51 Pearson chi2: 4.77e+03\n",
"Link Function: Identity Scale: 0.047738\n",
"Method: IRLS Log-Likelihood: 10209.\n",
"Date: Thu, 03 Oct 2024 Deviance: 4773.6\n",
"Time: 12:11:21 Pearson chi2: 4.77e+03\n",
"No. Iterations: 3 Pseudo R-squ. (CS): 1.000\n",
"Covariance Type: nonrobust \n",
"==============================================================================\n",
" coef std err z P>|z| [0.025 0.975]\n",
"------------------------------------------------------------------------------\n",
"Intercept 0.0014 0.001 1.431 0.152 -0.001 0.003\n",
"T 0.0944 0.002 48.947 0.000 0.091 0.098\n",
"R1 -0.0969 0.002 -44.869 0.000 -0.101 -0.093\n",
"R2 0.0957 0.001 154.142 0.000 0.094 0.097\n",
"Intercept -0.0016 0.001 -1.613 0.107 -0.003 0.000\n",
"T 0.0964 0.002 49.823 0.000 0.093 0.100\n",
"R1 -0.0919 0.002 -42.393 0.000 -0.096 -0.088\n",
"R2 0.0943 0.001 151.496 0.000 0.093 0.096\n",
"==============================================================================\n"
]
}
Expand All @@ -617,7 +617,7 @@
},
{
"cell_type": "code",
"execution_count": 68,
"execution_count": 33,
"id": "ead6991b-2231-435a-88ad-00bb8ac1ee5c",
"metadata": {
"tags": []
Expand All @@ -632,18 +632,18 @@
"Dep. Variable: Y No. Observations: 100000\n",
"Model: GLM Df Residuals: 99997\n",
"Model Family: Gaussian Df Model: 2\n",
"Link Function: Identity Scale: 0.10016\n",
"Method: IRLS Log-Likelihood: -26842.\n",
"Date: Tue, 24 Sep 2024 Deviance: 10015.\n",
"Time: 11:03:33 Pearson chi2: 1.00e+04\n",
"Link Function: Identity Scale: 0.048596\n",
"Method: IRLS Log-Likelihood: 9318.8\n",
"Date: Thu, 03 Oct 2024 Deviance: 4859.4\n",
"Time: 12:11:33 Pearson chi2: 4.86e+03\n",
"No. Iterations: 3 Pseudo R-squ. (CS): 1.000\n",
"Covariance Type: nonrobust \n",
"==============================================================================\n",
" coef std err z P>|z| [0.025 0.975]\n",
"------------------------------------------------------------------------------\n",
"Intercept 0.0010 0.001 0.673 0.501 -0.002 0.004\n",
"T 0.2025 0.003 73.164 0.000 0.197 0.208\n",
"R2 0.2998 0.000 948.329 0.000 0.299 0.300\n",
"Intercept -0.0015 0.001 -1.512 0.130 -0.003 0.000\n",
"T 0.0978 0.002 50.124 0.000 0.094 0.102\n",
"R2 0.0680 4.87e-05 1395.713 0.000 0.068 0.068\n",
"==============================================================================\n"
]
}
Expand All @@ -667,16 +667,14 @@
},
{
"cell_type": "code",
"execution_count": 352,
"execution_count": 375,
"id": "b3681867-2fe7-42ba-846a-7f49a29b1787",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# Import causal matrix and create graph + synthetic data based on the matrix\n",
"\n",
"\n",
"causal_matrix = pd.read_csv(\"Dont add all features - Reverse causal.csv\", index_col=0)\n",
"causal_matrix.fillna(0.0,inplace=True)\n",
"#Define the directional relationship of the nodes as a matrix. 0.0 indicates there's no relationship\n",
Expand Down Expand Up @@ -741,7 +739,7 @@
},
{
"cell_type": "code",
"execution_count": 354,
"execution_count": 376,
"id": "9b039e3a-e7ce-4a46-a4f9-2b7aafdb32e0",
"metadata": {
"tags": []
Expand All @@ -756,17 +754,17 @@
"Dep. Variable: Y No. Observations: 100000\n",
"Model: GLM Df Residuals: 99998\n",
"Model Family: Gaussian Df Model: 1\n",
"Link Function: Identity Scale: 863.18\n",
"Method: IRLS Log-Likelihood: -4.7992e+05\n",
"Date: Wed, 02 Oct 2024 Deviance: 8.6316e+07\n",
"Time: 09:44:40 Pearson chi2: 8.63e+07\n",
"No. Iterations: 3 Pseudo R-squ. (CS): 0.0009399\n",
"Link Function: Identity Scale: 857.12\n",
"Method: IRLS Log-Likelihood: -4.7957e+05\n",
"Date: Wed, 02 Oct 2024 Deviance: 8.5710e+07\n",
"Time: 11:18:25 Pearson chi2: 8.57e+07\n",
"No. Iterations: 3 Pseudo R-squ. (CS): 0.0009932\n",
"Covariance Type: nonrobust \n",
"==============================================================================\n",
" coef std err z P>|z| [0.025 0.975]\n",
"------------------------------------------------------------------------------\n",
"Intercept 0.1218 0.131 0.926 0.354 -0.136 0.379\n",
"T 1.8018 0.186 9.697 0.000 1.438 2.166\n",
"Intercept 0.1818 0.131 1.388 0.165 -0.075 0.438\n",
"T 1.8458 0.185 9.968 0.000 1.483 2.209\n",
"==============================================================================\n"
]
}
Expand All @@ -782,7 +780,7 @@
},
{
"cell_type": "code",
"execution_count": 314,
"execution_count": 377,
"id": "32764fe1-f29c-49ba-baf0-3ac8e2f28b87",
"metadata": {
"tags": []
Expand All @@ -797,31 +795,31 @@
"Dep. Variable: Y No. Observations: 100000\n",
"Model: GLM Df Residuals: 99984\n",
"Model Family: Gaussian Df Model: 15\n",
"Link Function: Identity Scale: 0.047400\n",
"Method: IRLS Log-Likelihood: 10571.\n",
"Date: Wed, 02 Oct 2024 Deviance: 4739.2\n",
"Time: 09:28:17 Pearson chi2: 4.74e+03\n",
"Link Function: Identity Scale: 0.048072\n",
"Method: IRLS Log-Likelihood: 9867.0\n",
"Date: Wed, 02 Oct 2024 Deviance: 4806.4\n",
"Time: 11:18:26 Pearson chi2: 4.81e+03\n",
"No. Iterations: 3 Pseudo R-squ. (CS): 1.000\n",
"Covariance Type: nonrobust \n",
"==============================================================================\n",
" coef std err z P>|z| [0.025 0.975]\n",
"------------------------------------------------------------------------------\n",
"Intercept -0.0009 0.001 -0.733 0.464 -0.003 0.002\n",
"T 0.0970 0.003 38.077 0.000 0.092 0.102\n",
"R1 -0.0977 0.002 -45.351 0.000 -0.102 -0.093\n",
"R2 0.0959 0.001 155.053 0.000 0.095 0.097\n",
"X1 0.2398 0.003 69.819 0.000 0.233 0.247\n",
"X2 0.1199 0.002 66.038 0.000 0.116 0.123\n",
"X3 -0.0013 0.002 -0.726 0.468 -0.005 0.002\n",
"X4 -0.0020 0.004 -0.575 0.565 -0.009 0.005\n",
"X5 -0.5702 0.001 -467.954 0.000 -0.573 -0.568\n",
"X6 -0.0956 0.001 -83.261 0.000 -0.098 -0.093\n",
"X7 -0.1903 0.001 -249.529 0.000 -0.192 -0.189\n",
"X8 -0.0014 0.001 -1.831 0.067 -0.003 0.000\n",
"X9 0.0010 0.002 0.637 0.524 -0.002 0.004\n",
"X10 -0.0954 0.002 -57.433 0.000 -0.099 -0.092\n",
"X11 -0.0955 0.001 -80.588 0.000 -0.098 -0.093\n",
"X12 0.0941 0.002 62.138 0.000 0.091 0.097\n",
"Intercept -0.0007 0.001 -0.511 0.609 -0.003 0.002\n",
"T 0.0945 0.003 36.817 0.000 0.089 0.100\n",
"R1 -0.0946 0.002 -43.354 0.000 -0.099 -0.090\n",
"R2 0.0950 0.001 151.670 0.000 0.094 0.096\n",
"X1 0.2383 0.003 68.706 0.000 0.231 0.245\n",
"X2 0.1193 0.002 65.069 0.000 0.116 0.123\n",
"X3 -0.0009 0.002 -0.473 0.636 -0.005 0.003\n",
"X4 -0.0013 0.004 -0.368 0.713 -0.008 0.006\n",
"X5 -0.5709 0.001 -463.560 0.000 -0.573 -0.569\n",
"X6 -0.0945 0.001 -81.093 0.000 -0.097 -0.092\n",
"X7 -0.1909 0.001 -247.158 0.000 -0.192 -0.189\n",
"X8 -0.0005 0.001 -0.626 0.532 -0.002 0.001\n",
"X9 -0.0009 0.002 -0.594 0.553 -0.004 0.002\n",
"X10 -0.0944 0.002 -56.175 0.000 -0.098 -0.091\n",
"X11 -0.0953 0.001 -79.683 0.000 -0.098 -0.093\n",
"X12 0.0953 0.002 62.512 0.000 0.092 0.098\n",
"==============================================================================\n"
]
}
Expand All @@ -837,7 +835,7 @@
},
{
"cell_type": "code",
"execution_count": 26,
"execution_count": 378,
"id": "b30e4890-fb53-4963-b816-dbe2fe5fc5cf",
"metadata": {},
"outputs": [],
Expand All @@ -848,26 +846,26 @@
},
{
"cell_type": "code",
"execution_count": 27,
"execution_count": 379,
"id": "20dd8f37-0081-4dfb-bff0-6d0299e0b2f8",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.002021 seconds.\n",
"[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000721 seconds.\n",
"You can set `force_col_wise=true` to remove the overhead.\n",
"[LightGBM] [Info] Total Bins 3570\n",
"[LightGBM] [Info] Number of data points in the train set: 40133, number of used features: 14\n",
"[LightGBM] [Info] Start training from score 0.158939\n",
"[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.001765 seconds.\n",
"[LightGBM] [Info] Number of data points in the train set: 40008, number of used features: 14\n",
"[LightGBM] [Info] Start training from score 0.157211\n",
"[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000811 seconds.\n",
"You can set `force_col_wise=true` to remove the overhead.\n",
"[LightGBM] [Info] Total Bins 3570\n",
"[LightGBM] [Info] Number of data points in the train set: 39867, number of used features: 14\n",
"[LightGBM] [Info] Start training from score 2.323433\n",
"[LightGBM] [Info] Number of data points in the train set: 39992, number of used features: 14\n",
"[LightGBM] [Info] Start training from score 2.022516\n",
"\n",
"Estimated Conditional Average Treatment Effect (ATE): 0.02095223713821671\n"
"Estimated Conditional Average Treatment Effect (ATE): -0.0192182084241143\n"
]
}
],
Expand All @@ -888,7 +886,7 @@
},
{
"cell_type": "code",
"execution_count": 28,
"execution_count": 386,
"id": "7c62e59d-a8f0-4b76-819a-02cad7a9ae15",
"metadata": {
"tags": []
Expand All @@ -899,7 +897,7 @@
"output_type": "stream",
"text": [
"\n",
"Estimated Conditional Average Treatment Effect (ATE): 0.0037685241626691664\n"
"Estimated Conditional Average Treatment Effect (ATE): 0.003892689833916613\n"
]
},
{
Expand All @@ -911,7 +909,7 @@
}
],
"source": [
"#Estimate using DML and using light GBM as the underlying model\n",
"#Estimate using DML and using Random forest as the underlying model\n",
"\n",
"#Define DML (Double Machine Learning) setup\n",
"dml = DML(\n",
Expand Down
2 changes: 1 addition & 1 deletion blog_content/dont_add_all_features_reverse_causal.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
},
"source": [
"# Why You Can't Add All Features Into Causal Inference Even With ML\n",
"# Part 3 - Reverse causal\n",
"# Part 2 - Reverse causal\n",
"Hiro Naito"
]
},
Expand Down

0 comments on commit 452e44e

Please sign in to comment.