From e1c012606fbff21c256ad6dc84e82e3a473ba4cb Mon Sep 17 00:00:00 2001 From: Xuxiang Sun <47307188+ryansun117@users.noreply.github.com> Date: Wed, 14 Sep 2022 17:12:09 -0500 Subject: [PATCH] Add SQL database to graph conversion tool (Db2Graph) (#99) * Migrate from internal * Change to delta and limit size: - delta change to avoid div by zero error - limit size max 1 billion as too laarge values lead to postgres error * Adding the workflow and pytest for db2graph * For verification testing * Workflow verified branch db2graph removed * Changing the position because affecting build test * Updated file name and setup.cfg; cleanup dup code, typos * Small changes after previous commit Renamed the file to marius_db2graph to align with commands like marius_preprocess; Created two new functions for get_fetch_size to avoid duplicate code; Added the marius_db2graph command to setup.cfg (but haven't tested it because I'm not sure if pip installing it right now would work); Added 'my-sql' as an option to use mysql-connector because this wasn't added previously; * Changing installation via marius path * Updating testing setup * hydra-core only 1.1.2 version working for tests * Updated the documentation with marius_db2graph * Updated the code to you only omegaconf * Resolving review comments * pushed repetative code to a function * hydra-core added right now till other PR merges * Removed edges_entity_feature_values * Removed generate_uuid and related parts * Resolving other review comments * Workflow naming restrictions * wrong file name * Updated the workflow using matrix * Correcting a typo * Matched naming in code and documentation Renamed edge_entity_entity_queries to edge_queries, edge_entity_entity_queries_list to edge_queries_list, and edge_entity_entity_rel_list to edge_rel_list * Sample dockerfile for the Sakila dataset * Added end-to-end example; Modified basicConfig usage Added dockerfile, run.sh, sakila.yaml Modified basicConfig in marius_db2graph.py to avoid Python version issue * Updated script to correctly set password; Fixed typo * Updated validation for better info & moved files * Moved file to proper folder * Moved by mistake * Updated doc to reflect parsing and logging changes; Changed dockerfile to install from marius main * Fix doc typos * apply autoformatter * add optional db2graph dependency * fix linter issues * include test dependencies in github actions * update github actions Co-authored-by: mohilp1998 Co-authored-by: Roger Waleffe Co-authored-by: Jason Mohoney --- .github/workflows/db2graph_test_postgres.yml | 48 ++ docs/db2graph/db2graph.rst | 512 +++++++++++++++++++ examples/configuration/sakila.yaml | 48 ++ examples/db2graph/dockerfile | 48 ++ examples/db2graph/run.sh | 12 + setup.cfg | 4 + src/python/tools/db2graph/marius_db2graph.py | 419 +++++++++++++++ test/db2graph/test_postgres.py | 296 +++++++++++ 8 files changed, 1387 insertions(+) create mode 100644 .github/workflows/db2graph_test_postgres.yml create mode 100644 docs/db2graph/db2graph.rst create mode 100644 examples/configuration/sakila.yaml create mode 100644 examples/db2graph/dockerfile create mode 100644 examples/db2graph/run.sh create mode 100644 src/python/tools/db2graph/marius_db2graph.py create mode 100644 test/db2graph/test_postgres.py diff --git a/.github/workflows/db2graph_test_postgres.yml b/.github/workflows/db2graph_test_postgres.yml new file mode 100644 index 00000000..eb402bb7 --- /dev/null +++ b/.github/workflows/db2graph_test_postgres.yml @@ -0,0 +1,48 @@ +name: Testing DB2GRAPH using postgres 
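+# Runs the Db2Graph integration test against a PostgreSQL service container
+# for each supported Python version (3.6 through 3.10).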
+on:
+  push:
+    branches:
+      - main
+  pull_request:
+    branches:
+      - main
+
+jobs:
+
+  db2graph:
+    runs-on: ubuntu-latest
+    container: ${{ matrix.python_container }}
+    strategy:
+      matrix:
+        python_container: ["python:3.6", "python:3.7", "python:3.8", "python:3.9", "python:3.10"]
+
+    services:
+      postgres:
+        # Docker Hub image
+        image: postgres
+        # Provide the password for postgres
+        env:
+          POSTGRES_PASSWORD: postgres
+        # Set health checks to wait until postgres has started
+        options: >-
+          --health-cmd pg_isready
+          --health-interval 10s
+          --health-timeout 5s
+          --health-retries 5
+
+    steps:
+      # Downloads a copy of the code in your repository before running CI tests
+      - name: Check out repository code
+        uses: actions/checkout@v3
+
+      - name: Installing dependencies
+        run: MARIUS_NO_BINDINGS=1 python3 -m pip install .[db2graph,tests]
+
+      - name: Running pytest
+        run: MARIUS_NO_BINDINGS=1 pytest -s test/db2graph/test_postgres.py
+        # Environment variables used in the test
+        env:
+          # The hostname used to communicate with the PostgreSQL service container
+          POSTGRES_HOST: postgres
+          # The default PostgreSQL port
+          POSTGRES_PORT: 5432
\ No newline at end of file
diff --git a/docs/db2graph/db2graph.rst b/docs/db2graph/db2graph.rst
new file mode 100644
index 00000000..ca8923cb
--- /dev/null
+++ b/docs/db2graph/db2graph.rst
@@ -0,0 +1,512 @@
+Db2Graph: Database to Graph conversion tool
+============================================
+
+Introduction
+""""""""""""""""""""
+
+**Db2Graph** converts **relational databases** into **graphs as sets of triples** that can be used as **input datasets for Marius**, allowing streamlined preprocessing from a database into Marius. Db2Graph ships with Marius but can also be used as a standalone tool. Conversion with Db2Graph is achieved in the following steps:
+
+#. Users import/create the database locally
+
+#. Users define the configuration file and edge SQL SELECT queries
+
+#. Db2Graph executes the SQL SELECT queries
+
+#. Db2Graph transforms the result set of the queries into sets of triples
+
+Below we lay out the requirements, definitions, and steps for using Db2Graph, followed by a real example use case.
+
+Requirements
+""""""""""""""""""""
+
+Db2Graph currently supports graph conversion from three relational database management systems: **MySQL**, **MariaDB**, and **PostgreSQL**. Db2Graph requires no additional installation, as all the required Python packages are part of the Marius installation. Please refer to `marius installation `_ for installing the required packages.
+
+System Design
+""""""""""""""""""""
+
+Db2Graph classifies graph elements into the following two types:
+
+* Entity Nodes: Nodes that are globally unique. Global uniqueness is ensured by prefixing each value with its table and column name, producing node names of the form ``table-name_col-name_value``. In a graph, entity nodes either point to other entity nodes or are pointed to by other entity nodes.
+* Edges of Entity Node to Entity Node: Directed edges where both source and destination are entity nodes.
+
+During the conversion, we assume that all nodes are **case insensitive**. We ignore the following set of **invalid node names**: ``"0", None, "", 0, "not reported", "None", "none"``.
+
+Db2Graph outputs a set of triples in the format ``[source node] [edge] [destination node]``, where the elements of each triple are delimited by a single tab. This output format matches the input format expected by Marius, allowing streamlined preprocessing from a database into Marius.
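+
+For illustration, here are a few output triples taken from the Sakila example covered later in this document; each node carries its ``table-name_col-name`` prefix and the three fields are separated by a single tab:
+
+.. code-block:: text
+
+   actor_first_name_rock     acted_in    film_title_academy dinosaur
+   actor_first_name_mary     acted_in    film_title_academy dinosaur
+   actor_first_name_oprah    acted_in    film_title_academy dinosaur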
+
+How to Use
+""""""""""""""""""""
+
+First, make sure Marius is installed with the optional db2graph dependencies: ``python3 -m pip install .[db2graph]``.
+
+Assuming that a database has already been created, graph conversion with Db2Graph can be achieved in the following steps:
+
+#. | First, create a YAML configuration file ``config.yaml`` and a query definition file containing the SQL SELECT queries for ``edges_queries``. Assume that the config file and query file are placed in a ``./conf/`` directory.
+
+   .. code-block:: bash
+
+      $ ls -l .
+      conf/
+         config.yaml        # config file
+         edges_queries.txt  # defines edges_queries
+
+   | Define the configuration file in ``config.yaml``. Below is a sample configuration file. Note that all fields are required. An error will be thrown if the query file does not exist.
+
+   .. code-block:: yaml
+
+      db_server: postgre-sql
+      db_name: sample_db
+      db_user: sample_user
+      db_password: sample_password
+      db_host: localhost
+      edges_queries: conf/edges_queries.txt
+
+   .. list-table::
+      :widths: 15 10 50 15
+      :header-rows: 1
+
+      * - Key
+        - Type
+        - Description
+        - Required
+      * - db_server
+        - String
+        - Denotes the RDBMS to use. Options: ["maria-db", "postgre-sql", "my-sql"].
+        - Yes
+      * - db_name
+        - String
+        - Denotes the name of the database.
+        - Yes
+      * - db_user
+        - String
+        - Denotes the user name used to access the database.
+        - Yes
+      * - db_password
+        - String
+        - Password used to access the database.
+        - Yes
+      * - db_host
+        - String
+        - Denotes the hostname of the database.
+        - Yes
+      * - edges_queries
+        - String
+        - Path to the text file that contains the SQL SELECT queries fetching edges from entity nodes to entity nodes.
+        - Yes
+
+#. | Next, define the SQL SELECT queries. Assume the file ``conf/edges_queries.txt`` has been created. In it, define queries in the following format, with no empty lines between them. Each edge is described by two rows: a single ``relation_name`` row followed by a row containing the SQL SELECT query. Note that you can include any SQL keywords after the WHERE clause.
+
+   .. code-block:: sql
+
+      relation_name_A_to_B -- this is the name of the edge from A to B
+      SELECT table1_name.column_name_A, table2_name.column_name_B FROM table1_name, table2_name WHERE ...; -- this row represents an edge from source entity node A to destination entity node B
+      relation_name_B_to_C -- this is the name of the edge from B to C
+      SELECT table1_name.column_name_B, table2_name.column_name_C FROM table1_name, table2_name WHERE ...; -- this row represents an edge from source entity node B to destination entity node C
+
+   | The user can expand or shorten the list of queries in the above query definition file to query a certain subset of data from the database.
+
+   .. note::
+      Db2Graph validates the format of each query, but it does not validate the queries themselves. That is, it assumes that all column names and table names exist in the database schema provided by the user. An error will be thrown if the format validation check fails.
+
+   .. note::
+      ``AS`` aliases cannot be used within the queries. Any alias breaks Db2Graph's parsing of the queries.
+
+#. | Lastly, execute Db2Graph with the following command. Two flags are required. Note that both errors and general information are printed to the console and also logged to ``./output_dir/marius_db2graph.log``:
+
+   .. code-block:: bash
+
+      $ marius_db2graph --config_path conf/config.yaml --output_directory output_dir/
+      Starting marius_db2graph conversion tool for config: conf/config.yaml
+      ...
+      Edge file written to output_dir/edges.txt
+
+   | The ``--config_path`` flag specifies the path to the configuration file created by the user.
+
+   | The ``--output_directory`` flag specifies where the data will be output and is set by the user. In this example, assume we have not created the output_dir directory; ``marius_db2graph`` will create it for us.
+
+   | The conversion result will be written to ``edges.txt`` in the newly created directory ``./output_dir``:
+
+   .. code-block:: bash
+
+      $ ls -l .
+      output_dir/
+         edges.txt              # generated file with sets of triples
+         marius_db2graph.log    # output log file
+      conf/
+         config.yaml            # config file
+         edges_queries.txt      # defines edges_queries
+      $ cat output_dir/edges.txt
+      column_name_A   relation_name_A_to_B   column_name_B
+      column_name_B   relation_name_B_to_C   column_name_C
+
+End-to-end Example Use Case
+"""""""""""""""""""""""""""""
+
+We use `the Sakila DVD store database `_ from MySQL to demonstrate an end-to-end example, from converting a database into a graph with Db2Graph to preprocessing and training the resulting dataset with Marius. For simplicity, we have provided a dockerfile and a bash script which install Marius along with Db2Graph and initialize the Sakila database for you.
+
+#. | First, download and place the provided ``dockerfile`` and ``run.sh`` in the current working directory. Create and run a docker container using the dockerfile. This dockerfile pre-installs Marius and all dependencies needed for this end-to-end example. It also copies ``run.sh`` into the container.
+
+   .. code-block:: bash
+
+      $ docker build -t db2graph_image .                            # Builds a docker image named db2graph_image
+      $ docker run --name db2graph_container -itd db2graph_image    # Creates the container named db2graph_container
+      $ docker exec -it db2graph_container bash                     # Runs bash interactively inside the container
+
+   | In the root directory of the container, execute ``run.sh``. This script downloads and initializes the Sakila database. Note that the script runs MySQL as ``root``, the database name is ``sakila``, and it creates a database user ``sakila_user`` with password ``sakila_password``.
+
+   .. code-block:: bash
+
+      $ run.sh
+      $ cd marius/
+
+   | To verify that the database has been installed correctly:
+
+   .. code-block:: bash
+
+      $ mysql
+      mysql> USE sakila;
+      mysql> SHOW FULL tables;
+      +----------------------------+------------+
+      | Tables_in_sakila           | Table_type |
+      +----------------------------+------------+
+      | actor                      | BASE TABLE |
+      | actor_info                 | VIEW       |
+      ...
+      23 rows in set (0.01 sec)
+
+   .. note::
+
+      If you see an error of type ``ERROR 2002 (HY000): Can't connect to local MySQL server through socket '/var/run/mysqld/mysqld.sock' (111)``, run the command ``systemctl start mysql`` and retry.
+
+#. | Next, create the configuration file for using Db2Graph. Assuming we are in the ``marius/`` root directory, create and navigate to the ``datasets/sakila`` directory. Create the ``conf/config.yaml`` and ``conf/edges_queries.txt`` files if they have not been created.
+
+   .. code-block:: bash
+
+      $ mkdir -p datasets/sakila/conf/
+      $ vi datasets/sakila/conf/config.yaml
+      $ vi datasets/sakila/conf/edges_queries.txt
+
+   | In ``datasets/sakila/conf/config.yaml``, define the following fields:
+
+   ..
code-block:: yaml + + db_server: my-sql + db_name: sakila + db_user: sakila_user + db_password: sakila_password + db_host: 127.0.0.1 + edges_queries: datasets/sakila/conf/edges_queries.txt + + | In ``datasets/sakila/conf/edges_queries.txt``, define the following queries. Note that we create three edges/relationships: An actor acted in a film; A film sold by a store; A film categorized as a category. + + .. code-block:: sql + + acted_in + SELECT actor.first_name, film.title FROM actor, film_actor, film WHERE actor.actor_id = film_actor.actor_id AND film_actor.film_id = film.film_id ORDER BY film.title ASC; + sold_by + SELECT film.title, address.address FROM film, inventory, store, address WHERE film.film_id = inventory.film_id AND inventory.store_id = store.store_id AND store.address_id = address.address_id ORDER BY film.title ASC; + categorized_as + SELECT film.title, category.name FROM film, film_category, category WHERE film.film_id = film_category.film_id AND film_category.category_id = category.category_id ORDER BY film.title ASC; + + | For simplicity, we limit the queries to focus on the film table. The user can expand or shorten the list of queries in each of the above query definition files to query a certain subset of data from the database. For the Sakila database structure, please refer to `this MySQL documentation `_. + + .. note:: + + The queries above have ``ORDER BY`` clause at the end, which is not compulsory (and can have performance impact). We have kept it for the example because it will ensure same output across multiple runs. For optimal performance remove the ``ORDER BY`` clause. + +#. | Lastly, execute Db2Graph with the following script: + + .. code-block:: bash + + $ marius_db2graph --config_path datasets/sakila/conf/config.yaml --output_directory datasets/sakila/ + Starting marius_db2graph conversion tool for config: datasets/sakila/conf/config.yaml + ... + Total execution time: 0.382 seconds + Edge file written to datasets/sakila/edges.txt + + | The conversion result was written to ``edges.txt`` in the specified directory ``datasets/sakila/``. In ``edges.txt``, there should be 7915 edges representing the three relationships we defined earlier: + + .. code-block:: bash + + $ ls -1 datasets/sakila/ + edges.txt # generated file with sets of triples + marius_db2graph.log # output log file + conf/ + ... + $ cat datasets/sakila/edges.txt + actor_first_name_rock acted_in film_title_academy dinosaur + actor_first_name_mary acted_in film_title_academy dinosaur + actor_first_name_oprah acted_in film_title_academy dinosaur + ... + + .. note:: + + This concludes the example for using Db2Graph. For an end-to-end example of using Db2Graph with Marius, continue through the sections below. For example, for a custom link prediction example, follow `Custom Link Prediction example `_ from the docs. Please refer to docs/examples to see all the examples. + +#. | Preprocessing and training a custom dataset like the Sakila database is straightforward with the ``marius_preprocess`` and ``marius_train`` commands. These commands come with ``marius`` when ``marius`` is installed. + + .. 
code-block:: bash
+
+      $ marius_preprocess --output_dir datasets/sakila/ --edges datasets/sakila/edges.txt --dataset_split 0.8 0.1 0.1 --delim="\t"
+      Preprocess custom dataset
+      Reading edges
+      /usr/local/lib/python3.8/dist-packages/marius/tools/preprocess/converters/readers/pandas_readers.py:55: ParserWarning: Falling back to the 'python' engine because the 'c' engine does not support regex separators (separators > 1 char and different from '\s+' are interpreted as regex); you can avoid this warning by specifying engine='python'.
+        train_edges_df = pd.read_csv(self.train_edges, delimiter=self.delim, skiprows=self.header_length, header=None)
+      Remapping Edges
+      Node mapping written to: datasets/sakila/nodes/node_mapping.txt
+      Relation mapping written to: datasets/sakila/edges/relation_mapping.txt
+      Splitting into: 0.8/0.1/0.1 fractions
+      Dataset statistics written to: datasets/sakila/dataset.yaml
+
+   | In the above command, we set ``dataset_split`` to a list of ``0.8 0.1 0.1``. Under the hood, this splits ``edges.txt`` into ``edges/train_edges.bin``, ``edges/validation_edges.bin``, and ``edges/test_edges.bin`` based on the given list of fractions.
+
+   | Note that ``edges.txt`` contains three columns delimited by tabs, so we set ``--delim="\t"``.
+
+   | The ``--edges`` flag specifies the raw edge list file that ``marius_preprocess`` will preprocess (and that we will train on later).
+
+   | The ``--output_dir`` flag specifies where the preprocessed graph will be output and is set by the user. In this example we reuse the ``datasets/sakila/`` directory; ``marius_preprocess`` will create it if it does not exist.
+
+   | For detailed usage of ``marius_preprocess``, please execute the following command:
+
+   .. code-block:: bash
+
+      $ marius_preprocess -h
+
+   | Let's check again what was created inside the ``datasets/sakila/`` directory:
+
+   .. code-block:: bash
+
+      $ ls -1 datasets/sakila/
+      dataset.yaml               # input dataset statistics
+      nodes/
+         node_mapping.txt        # mapping of raw node ids to integer uuids
+      edges/
+         relation_mapping.txt    # mapping of relations to integer uuids
+         test_edges.bin          # preprocessed testing edge list
+         train_edges.bin         # preprocessed training edge list
+         validation_edges.bin    # preprocessed validation edge list
+      conf/                      # directory containing config files
+      ...
+
+   | Let's check what is inside the generated ``dataset.yaml`` file:
+
+   .. code-block:: bash
+
+      $ cat datasets/sakila/dataset.yaml
+      dataset_dir: /marius/datasets/sakila/
+      num_edges: 6332
+      num_nodes: 1146
+      num_relations: 3
+      num_train: 6332
+      num_valid: 791
+      num_test: 792
+      node_feature_dim: -1
+      rel_feature_dim: -1
+      num_classes: -1
+      initialized: false
+
+   .. note::
+      If the above ``marius_preprocess`` command fails due to any missing directory errors, please create the missing ``edges/`` and ``nodes/`` subdirectories inside the output directory as a workaround.
+
+   | To train a model, we need to define a YAML configuration file based on the information produced by ``marius_preprocess``. An example YAML configuration file for the Sakila database (a link prediction model with DistMult) is given in ``examples/configuration/sakila.yaml``. Note that the ``dataset_dir`` is set to the preprocessing output directory, in our example ``datasets/sakila/``.
+
+   | Let's create the same YAML configuration file for the Sakila database from scratch. We follow the structure of the configuration file and create each of the four sections one by one. In a YAML file, indentation is used to denote nesting and all parameters are in the format of key-value pairs.
+
+   ..
code-block:: bash + + $ vi datasets/sakila/sakila.yaml + + .. note:: + String values in the configuration file are case insensitive but we use capital letters for convention. + + | First, we define the **model**. We begin by setting all required parameters. This includes ``learning_task``, ``encoder``, ``decoder``, and ``loss``. The rest of the configurations can be fine-tuned by the user. + + .. code-block:: yaml + + model: + learning_task: LINK_PREDICTION # set the learning task to link prediction + encoder: + layers: + - - type: EMBEDDING # set the encoder to be an embedding table with 50-dimensional embeddings + output_dim: 50 + decoder: + type: DISTMULT # set the decoder to DistMult + options: + input_dim: 50 + loss: + type: SOFTMAX_CE + options: + reduction: SUM + dense_optimizer: # optimizer to use for dense model parameters. In this case these are the DistMult relation (edge-type) embeddings + type: ADAM + options: + learning_rate: 0.1 + sparse_optimizer: # optimizer to use for node embedding table + type: ADAGRAD + options: + learning_rate: 0.1 + storage: + # omit + training: + # omit + evaluation: + # omit + + | Next, we set the **storage** and **dataset**. We begin by setting all required parameters. This includes ``dataset``. Here, the ``dataset_dir`` is set to ``datasets/sakila/``, which is the preprocessing output directory. + + .. code-block:: yaml + + model: + # omit + storage: + device_type: cuda + dataset: + dataset_dir: /marius/datasets/sakila/ + edges: + type: DEVICE_MEMORY + embeddings: + type: DEVICE_MEMORY + save_model: true + training: + # omit + evaluation: + # omit + + | Lastly, we configure **training** and **evaluation**. We begin by setting all required parameters. We begin by setting all required parameters. This includes ``num_epochs`` and ``negative_sampling``. We set ``num_epochs=10`` (10 epochs to train) to demonstrate this example. Note that ``negative_sampling`` is required for link prediction. + + .. code-block:: yaml + + model: + # omit + storage: + # omit + training: + batch_size: 1000 + negative_sampling: + num_chunks: 10 + negatives_per_positive: 500 + degree_fraction: 0.0 + filtered: false + num_epochs: 10 + pipeline: + sync: true + epochs_per_shuffle: 1 + evaluation: + batch_size: 1000 + negative_sampling: + filtered: true + pipeline: + sync: true + + | After defining our configuration file, training is run with ``marius_train ``. + + | We can now train our example using the configuration file we just created by running the following command (assuming we are in the ``marius`` root directory): + + .. 
code-block:: bash + + $ marius_train datasets/sakila/sakila.yaml + [2022-06-19 07:01:39.828] [info] [marius.cpp:44] Start initialization + [06/19/22 07:01:44.287] Initialization Complete: 4.458s + [06/19/22 07:01:44.292] ################ Starting training epoch 1 ################ + [06/19/22 07:01:44.308] Edges processed: [1000/6332], 15.79% + [06/19/22 07:01:44.311] Edges processed: [2000/6332], 31.59% + [06/19/22 07:01:44.313] Edges processed: [3000/6332], 47.38% + [06/19/22 07:01:44.315] Edges processed: [4000/6332], 63.17% + [06/19/22 07:01:44.317] Edges processed: [5000/6332], 78.96% + [06/19/22 07:01:44.320] Edges processed: [6000/6332], 94.76% + [06/19/22 07:01:44.322] Edges processed: [6332/6332], 100.00% + [06/19/22 07:01:44.322] ################ Finished training epoch 1 ################ + [06/19/22 07:01:44.322] Epoch Runtime: 29ms + [06/19/22 07:01:44.322] Edges per Second: 218344.83 + [06/19/22 07:01:44.322] Evaluating validation set + [06/19/22 07:01:44.329] + ================================= + Link Prediction: 1582 edges evaluated + Mean Rank: 548.639697 + MRR: 0.005009 + Hits@1: 0.000632 + Hits@3: 0.001264 + Hits@5: 0.001264 + Hits@10: 0.001896 + Hits@50: 0.034766 + Hits@100: 0.075221 + ================================= + [06/19/22 07:01:44.330] Evaluating test set + [06/19/22 07:01:44.333] + ================================= + Link Prediction: 1584 edges evaluated + Mean Rank: 525.809343 + MRR: 0.006225 + Hits@1: 0.000000 + Hits@3: 0.001263 + Hits@5: 0.004419 + Hits@10: 0.005682 + Hits@50: 0.046086 + Hits@100: 0.107323 + ================================= + + | After running this configuration for 10 epochs, we should see a result similar to below: + + .. code-block:: bash + + [06/19/22 07:01:44.524] ################ Starting training epoch 10 ################ + [06/19/22 07:01:44.527] Edges processed: [1000/6332], 15.79% + [06/19/22 07:01:44.529] Edges processed: [2000/6332], 31.59% + [06/19/22 07:01:44.531] Edges processed: [3000/6332], 47.38% + [06/19/22 07:01:44.533] Edges processed: [4000/6332], 63.17% + [06/19/22 07:01:44.536] Edges processed: [5000/6332], 78.96% + [06/19/22 07:01:44.538] Edges processed: [6000/6332], 94.76% + [06/19/22 07:01:44.540] Edges processed: [6332/6332], 100.00% + [06/19/22 07:01:44.540] ################ Finished training epoch 10 ################ + [06/19/22 07:01:44.540] Epoch Runtime: 16ms + [06/19/22 07:01:44.540] Edges per Second: 395749.97 + [06/19/22 07:01:44.540] Evaluating validation set + [06/19/22 07:01:44.544] + ================================= + Link Prediction: 1582 edges evaluated + Mean Rank: 469.225664 + MRR: 0.047117 + Hits@1: 0.030973 + Hits@3: 0.044880 + Hits@5: 0.051833 + Hits@10: 0.071429 + Hits@50: 0.136536 + Hits@100: 0.197219 + ================================= + [06/19/22 07:01:44.544] Evaluating test set + [06/19/22 07:01:44.547] + ================================= + Link Prediction: 1584 edges evaluated + Mean Rank: 456.828283 + MRR: 0.041465 + Hits@1: 0.023990 + Hits@3: 0.040404 + Hits@5: 0.051768 + Hits@10: 0.068813 + Hits@50: 0.147096 + Hits@100: 0.210227 + ================================= + + | Let's check again what was added in the ``datasets/sakila/`` directory. For clarity, we only list the files that were created in training. Notice that several files have been created, including the trained model, the embedding table, a full configuration file, and output logs: + + .. 
code-block:: bash + + $ ls datasets/sakila/ + model_0/ + embeddings.bin # trained node embeddings of the graph + embeddings_state.bin # node embedding optimizer state + model.pt # contains the dense model parameters, embeddings of the edge-types + model_stlsate.pt # optimizer state of the trained model parameters + node_mapping.txt # mapping of raw node ids to integer uuids + relation_mapping.txt # mapping of relations to integer uuids + full_config.yaml # detailed config generated based on user-defined config + metadata.csv # information about metadata + logs/ # logs containing output, error, debug information, and etc. + nodes/ + ... + edges/ + ... + ... + + .. note:: + ``model.pt`` contains the dense model parameters. For DistMult, this is the embeddings of the edge-types. For GNN encoders, this file will include the GNN parameters. + \ No newline at end of file diff --git a/examples/configuration/sakila.yaml b/examples/configuration/sakila.yaml new file mode 100644 index 00000000..99f37652 --- /dev/null +++ b/examples/configuration/sakila.yaml @@ -0,0 +1,48 @@ +model: + learning_task: LINK_PREDICTION # set the learning task to link prediction + encoder: + layers: + - - type: EMBEDDING # set the encoder to be an embedding table with 50-dimensional embeddings + output_dim: 50 + decoder: + type: DISTMULT # set the decoder to DistMult + options: + input_dim: 50 + loss: + type: SOFTMAX_CE + options: + reduction: SUM + dense_optimizer: # optimizer to use for dense model parameters. In this case these are the DistMult relation (edge-type) embeddings + type: ADAM + options: + learning_rate: 0.1 + sparse_optimizer: # optimizer to use for node embedding table + type: ADAGRAD + options: + learning_rate: 0.1 +storage: + device_type: cuda + dataset: + dataset_dir: /marius/datasets/sakila/ + edges: + type: DEVICE_MEMORY + embeddings: + type: DEVICE_MEMORY + save_model: true +training: + batch_size: 1000 + negative_sampling: + num_chunks: 10 + negatives_per_positive: 500 + degree_fraction: 0.0 + filtered: false + num_epochs: 10 + pipeline: + sync: true + epochs_per_shuffle: 1 +evaluation: + batch_size: 1000 + negative_sampling: + filtered: true + pipeline: + sync: true diff --git a/examples/db2graph/dockerfile b/examples/db2graph/dockerfile new file mode 100644 index 00000000..0384aa47 --- /dev/null +++ b/examples/db2graph/dockerfile @@ -0,0 +1,48 @@ +# setup for Marius +FROM nvidia/cuda:11.4.2-cudnn8-devel-ubuntu20.04 + +ENV TZ=US + +RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone + +RUN apt update + +RUN apt install -y g++ \ + make \ + wget \ + unzip \ + vim \ + git \ + python3-pip \ + build-essential \ + python-dev \ + libpq-dev + +# install cmake 3.20 +RUN wget https://github.com/Kitware/CMake/releases/download/v3.20.0/cmake-3.20.0-linux-x86_64.sh \ + && mkdir /opt/cmake \ + && sh cmake-3.20.0-linux-x86_64.sh --skip-license --prefix=/opt/cmake/ \ + && ln -s /opt/cmake/bin/cmake /usr/local/bin/cmake + +# install pytorch +RUN pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113 && pip3 install docutils==0.17 + +# install Marius +RUN git clone https://github.com/marius-team/marius.git && cd marius && pip3 install . 
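+
+# Note (assumption): the Db2Graph example additionally needs the optional db2graph
+# dependencies (psycopg2-binary, mysql-connector-python) declared in setup.cfg; if the
+# plain install above does not pull them in, they can be installed with:
+#   pip3 install ".[db2graph]"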
+ +# install debconf-set-selections & systemctl +RUN apt-get install debconf + +RUN apt-get install systemctl + +# install mysql-8 +RUN echo "mysql-community-server mysql-community-server/root-pass password password" | debconf-set-selections + +RUN echo "mysql-community-server mysql-community-server/re-root-pass password password" | debconf-set-selections + +RUN DEBIAN_FRONTEND=noninteractive apt-get -y install mysql-server + +# Adding a run.sh script to initialize things +COPY run.sh /usr/local/bin/run.sh + +RUN chmod +x usr/local/bin/run.sh diff --git a/examples/db2graph/run.sh b/examples/db2graph/run.sh new file mode 100644 index 00000000..4a3a39ea --- /dev/null +++ b/examples/db2graph/run.sh @@ -0,0 +1,12 @@ +#!/bin/sh +systemctl start mysql +mkdir /db2graph_eg +wget -O /db2graph_eg/sakila-db.tar.gz https://downloads.mysql.com/docs/sakila-db.tar.gz +tar -xf /db2graph_eg/sakila-db.tar.gz -C /db2graph_eg/ +mysql -u root -p=password < /db2graph_eg/sakila-db/sakila-schema.sql +mysql -u root -p=password < /db2graph_eg/sakila-db/sakila-data.sql +## For creating a new user for accessing the data +mysql -u root -p=password mysql -e "CREATE USER 'sakila_user'@'localhost' IDENTIFIED BY 'sakila_password';" +mysql -u root -p=password mysql -e "GRANT ALL PRIVILEGES ON *.* TO 'sakila_user'@'localhost';" +mysql -u root -p=password mysql -e "FLUSH PRIVILEGES;" +service mysql restart \ No newline at end of file diff --git a/setup.cfg b/setup.cfg index e9cc1079..ea9ecb58 100644 --- a/setup.cfg +++ b/setup.cfg @@ -28,6 +28,9 @@ docs = sphinx-rtd-theme==1.0.0 sphinx-autodoc-typehints==1.17.0 breathe==4.30.0 +db2graph = + psycopg2-binary + mysql-connector-python [options] install_requires = @@ -63,3 +66,4 @@ console_scripts = marius_config_generator = marius.tools.marius_config_generator:main marius_predict = marius.tools.marius_predict:main marius_env_info = marius.distribution.marius_env_info:main + marius_db2graph = marius.tools.db2graph.marius_db2graph:main \ No newline at end of file diff --git a/src/python/tools/db2graph/marius_db2graph.py b/src/python/tools/db2graph/marius_db2graph.py new file mode 100644 index 00000000..c6016ecf --- /dev/null +++ b/src/python/tools/db2graph/marius_db2graph.py @@ -0,0 +1,419 @@ +import argparse +import logging +import re +import sys +import time +from pathlib import Path + +import mysql.connector +import pandas as pd +import psutil +import psycopg2 +from mysql.connector import errorcode +from omegaconf import OmegaConf + +INVALID_ENTRY_LIST = ["0", None, "", 0, "not reported", "None", "none"] +FETCH_SIZE = int(1e4) +MAX_FETCH_SIZE = int(1e9) +OUTPUT_FILE_NAME = "edges.txt" + + +def set_args(): + parser = argparse.ArgumentParser( + description=( + "Db2Graph is tool to generate graphs from relational database using SQL queries. See" + " documentation docs/db2graph for more details." + ), + prog="db2graph", + ) + + parser.add_argument( + "--config_path", + metavar="config_path", + type=str, + default="", + help="Path to the config file. See documentation docs/db2graph for more details.", + ) + + parser.add_argument( + "--output_directory", + metavar="output_directory", + type=str, + default="./", + help="Directory to put output data and log file. See documentation docs/db2graph for more details.", + ) + return parser + + +def config_parser_fn(config_name): + """ + Takes the input yaml config file's name (& relative path). 
Returns all the extracted data + + :param config_name: file name (& relative path) for the YAML config file + :returns: + - db_server: string denoting database server (initial support only for mariadb) + - db_name: name of the database you need to pull from + - db_user: user name used to access the database + - db_password: password used to access the database + - db_host: hostname of the database + - edges_queries_list: list of sql queries to define edges of type entity nodes to entity nodes + & the names of edges + """ + input_cfg = None + input_config_path = Path(config_name).absolute() + + input_cfg = OmegaConf.load(input_config_path) + + # db_server used to distinguish between different databases + db_server = None + if "db_server" in input_cfg.keys(): + db_server = input_cfg["db_server"] + else: + logging.error("ERROR: db_server is not defined") + exit(1) + + # db_name is the name of the database to pull the data from + db_name = None + if "db_name" in input_cfg.keys(): + db_name = input_cfg["db_name"] + else: + logging.error("ERROR: db_name is not defined") + exit(1) + + # db_user is the user name used to access the database + db_user = None + if "db_user" in input_cfg.keys(): + db_user = input_cfg["db_user"] + else: + logging.error("ERROR: db_user is not defined") + + # db_password is the password used to access the database + db_password = None + if "db_password" in input_cfg.keys(): + db_password = input_cfg["db_password"] + else: + logging.error("ERROR: db_password is not defined") + + # db_host is the hostname of the database + db_host = None + if "db_host" in input_cfg.keys(): + db_host = input_cfg["db_host"] + else: + logging.error("ERROR: db_host is not defined") + + # Getting all edge queries for edge type entity node to entity node + edges_queries_list = list() + edge_rel_list = list() + if "edges_queries" in input_cfg.keys(): + query_filepath = input_cfg["edges_queries"] + + if not Path(query_filepath).exists(): + raise ValueError("{} does not exist".format(str(query_filepath))) + + edge_queries_file = open(query_filepath, "r") + read_lines = edge_queries_file.readlines() + for i in range(len(read_lines)): + read_lines[i] = read_lines[i].strip() + if read_lines[i] == "": + logging.error("Error: Empty lines are not allowed in edges_query file. " + "Please remove them") + exit(1) + + # Removing the last '\n' character + if read_lines[i][-1] == "\n": + read_lines[i] = read_lines[i][:-1] + + # Adding the line to rel_list if even else its a query + if i % 2 == 0: + edge_rel_list.append(read_lines[i]) + else: + edges_queries_list.append(read_lines[i]) + else: + logging.error("ERROR: edges_queries is not defined") + exit(1) + + return db_server, db_name, db_user, db_password, db_host, edges_queries_list, edge_rel_list + + +def connect_to_db(db_server, db_name, db_user, db_password, db_host): + """ + Function takes db_server and db_name as the input. Tries to connect to the database and returns an object + which can be used to execute queries. + Assumption: default user: root, host: 127.0.0.1 and password:"". 
You will need to change code if otherwise + + :param db_server: The name of the backend database application used for accessing data + :param db_name: The name of the database where the data resides + :param db_user: The user name used to access the database + :param db_password: The password used to access the database + :param db_host: The hostname of the database + :return cnx: cursor object that can be used to execute the database queries + """ + if db_server == "maria-db" or db_server == "my-sql": + try: + cnx = mysql.connector.connect(user=db_user, password=db_password, host=db_host, database=db_name) + except mysql.connector.Error as err: + if err.errno == errorcode.ER_ACCESS_DENIED_ERROR: + logging.error(f"Incorrect user name or password\n{err}") + elif err.errno == errorcode.ER_BAD_DB_ERROR: + logging.error(f"Non-existing database\n{err}") + else: + logging.error(err) + + elif db_server == "postgre-sql": + try: + cnx = psycopg2.connect(user=db_user, password=db_password, host=db_host, database=db_name) + except psycopg2.Error as err: + logging.error(f"Error\n{err}") + + else: + logging.error("Other databases are currently not supported.") + + return cnx + + +def validation_check_edge_entity_entity_queries(edges_queries_list): + """ + Ensures that the edge_entity_entity_queries are correctly formatted. + + :param edges_queries_list: List of all the queries defining edges from entity nodes to entity nodes + :return new_query_list: These are updated queries with necessary changes if any + """ + # Format: SELECT table1_name.col1_name, table2_name.col2_name FROM ____ WHERE ____ (and so on); + logging.info("\nValidating queries for proper formatting") + new_query_list = list() + for q in range(len(edges_queries_list)): + logging.info(f"Checking query[{q}]") + qry_split = edges_queries_list[q].strip().split() + + if "AS" in qry_split or "as" in qry_split: + logging.error("Error: Cannot use AS keyword in query. 
Please update" + " the query") + exit(1) + + check_var = qry_split[0].lower() + if check_var != "select": + logging.error("Error: Incorrect edge entity node - entity node formatting, " + "not starting with SELECT") + exit(1) + + check_split = qry_split[1].split(".") + if len(check_split) != 2: + logging.error( + "Error: Incorrect edge entity node - entity node formatting, " + + "table1_name.col1_name not correctly formatted" + ) + exit(1) + if check_split[1][-1] != ",": + logging.error( + "Error: Incorrect edge entity node - entity node formatting, " + + "missing ',' at the end of table1_name.col1_name" + ) + exit(1) + + check_split = qry_split[2].split(".") + if len(check_split) != 2: + logging.error( + "Error: Incorrect edge entity node - entity node formatting, " + + "table2_name.col2_name not correctly formatted" + ) + exit(1) + + check_var = qry_split[3].lower() + if check_var != "from": + logging.error( + "Error: Incorrect edge entity node - entity node formatting, " + + "extra elements after table2_name.col2_name" + ) + exit(1) + + new_query_list.append(edges_queries_list[q]) + + return new_query_list + + +def clean_token(token): + """ + Helper to clean a dataframe, can be used by applying this function to a dataframe + + :param token: elements to clean + :return token: cleaned token + """ + token = str(token) + token = token.strip().strip("\t.'\" ") + return token.lower() + + +def get_init_fetch_size(): + """ + In an initial pass, estimates the optimal maximum possible fetch_size + for given query based on memory usage report of virtual_memory() + + :return limit_fetch_size: the optimal maximum possible fetch_size for database engine + """ + mem_copy = psutil.virtual_memory() + mem_copy_used = mem_copy.used + limit_fetch_size = min(mem_copy.available / 2, MAX_FETCH_SIZE) # max fetch_size limited to MAX_FETCH_SIZE + return limit_fetch_size, mem_copy_used + + +def get_fetch_size(fetch_size, limit_fetch_size, mem_copy_used): + """ + Calculates the optimal maximum fetch_size based on the current snapshot of virtual_memory() + Increase fetch_size if the amount of memory used is less than half of machine's total available memory + The size of fetch_size is limited between 10000 and limit_fetch_size bytes + + :param limit_fetch_size: the optimal maximum possible fetch_size + :return fetch_size: updated fetch_size passed into database engine + """ + delta = ( + psutil.virtual_memory().used - mem_copy_used + ) # delta between two virtual_memory(), i.e. 
mem used for curr fetch_size + est_fetch_size = limit_fetch_size / (delta + 1) * fetch_size # estimated optimal fetch_size + if est_fetch_size > limit_fetch_size: + fetch_size = int(limit_fetch_size) + elif FETCH_SIZE < est_fetch_size and est_fetch_size <= limit_fetch_size: + fetch_size = int(est_fetch_size) + else: + fetch_size = FETCH_SIZE + return fetch_size + + +def get_cursor(cnx, db_server, cursor_name): + """ + Gets the cursor for the database connection + + :param cnx: database connection + :param db_server: database server + :param cursor_name: name of the cursor (needed for postgre-sql) + :return cursor: cursor for database connection + """ + cursor = [] + if db_server == "maria-db" or db_server == "my-sql": + cursor = cnx.cursor() + elif db_server == "postgre-sql": + cursor = cnx.cursor(name=cursor_name) + return cursor + + +def post_processing(output_dir, cnx, edges_queries_list, edge_rel_list, db_server): + """ + Executes the given queries_list one by one, cleanses the data by removing duplicates, + then append the entity nodes with tableName_colName which works as Unique Identifier, + and store the final result in a dataframe/.txt file + + :param output_dir: Directory to put output file + :param cnx: Cursor object + :param edges_queries_list: List of all the queries defining edges from entity nodes to entity nodes + :param edge_rel_list: List of all the relationships defining edges from entity nodes to entity nodes + :param db_server: database server name + :return 0: 0 for success, exit code 1 for failure + """ + if len(edges_queries_list) != len(edge_rel_list): + logging.error("Number of queries in edges_queries_list must match number of edges in edge_rel_list") + exit(1) + + open(output_dir / Path(OUTPUT_FILE_NAME), "w").close() # Clearing the output file + logging.info("\nProcessing queries to generate edges") + + fetch_size = FETCH_SIZE + # generating edges entity node to entity nodes + for i in range(len(edges_queries_list)): + start_time2 = time.time() + first_pass = True + + # Executing the query and timing it + query = edges_queries_list[i] + cursor_name = "edge_entity_entity_cursor" + str( + i + ) # Name imp because: https://www.psycopg.org/docs/usage.html#server-side-cursors + cursor = get_cursor(cnx, db_server, cursor_name) + cursor.execute(query) + + # Getting Basic Details + table_name_list = re.split(" ", query) # table name of the query to execute + table_name1 = table_name_list[1].split(".")[0] # src table + col_name1 = table_name_list[1].split(".")[1][:-1] # src column, (note last character ',' is removed) + table_name2 = table_name_list[2].split(".")[0] # dst/target table + col_name2 = table_name_list[2].split(".")[1] # dst/target column + + # Processing each batch of cursor on client + rows_completed = 0 + + # In an initial sample pass, estimates the optimal maximum possible fetch_size for + # given query based on memory usage report of virtual_memory() + # process data with fetch_size=10000, record the amount of memory used, + # increase fetch_size if the amount of memory used is less than half of machine's total available memory, + # Note: all unit size are in bytes, fetch_size limited between 10000 and 100000000 bytes + if first_pass: + limit_fetch_size, mem_copy_used = get_init_fetch_size() + + # Potential issue: There might be duplicates now possible as drop_duplicates over smaller range + # expected that user db does not have dupliacted + while True: # Looping till all rows are completed and processed + result = cursor.fetchmany(fetch_size) + result 
= pd.DataFrame(result) + if result.shape[0] == 0: + break + + # Cleaning Part + result = result.applymap(clean_token) # strip tokens and lower case strings + result = result[~result.iloc[:, 1].isin(INVALID_ENTRY_LIST)] # clean invalid data + result = result[~result.iloc[:, 0].isin(INVALID_ENTRY_LIST)] + result = result.drop_duplicates() # remove invalid row + + result.iloc[:, 0] = table_name1 + "_" + col_name1 + "_" + result.iloc[:, 0] # src + result.iloc[:, 1] = table_name2 + "_" + col_name2 + "_" + result.iloc[:, 1] # dst/target + result.insert(1, "rel", edge_rel_list[i]) # rel + result.columns = ["src", "rel", "dst"] + + # storing the output + result.to_csv( + output_dir / Path(OUTPUT_FILE_NAME), sep="\t", header=False, index=False, mode="a" + ) # Appending the output to disk + del result + rows_completed += fetch_size + + # update fetch_size based on current snapshot of the machine's memory usage + if first_pass: + fetch_size = get_fetch_size(fetch_size, limit_fetch_size, mem_copy_used) + first_pass = False + logging.info(f"Finished processing query[{i}] in {time.time() - start_time2:.3f} seconds") + + +def main(): + total_time = time.time() + parser = set_args() + args = parser.parse_args() + + ret_data = config_parser_fn(args.config_path) + db_server = ret_data[0] + db_name = ret_data[1] + db_user = ret_data[2] + db_password = ret_data[3] + db_host = ret_data[4] + edges_queries_list = ret_data[5] + edge_rel_list = ret_data[6] + + output_dir = Path(args.output_directory) + output_dir.mkdir(parents=True, exist_ok=True) + logging.basicConfig( + filename=output_dir / Path("marius_db2graph.log"), level=logging.INFO, filemode="w" + ) # set filemode='w' if want to start a fresh log file + logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) # add handler to print to console + + try: + logging.info(f"\nStarting marius_db2graph conversion tool for config: {args.config_path}") + + cnx = connect_to_db(db_server, db_name, db_user, db_password, db_host) + + # Generating edges + edges_queries_list = validation_check_edge_entity_entity_queries(edges_queries_list) + post_processing(output_dir, cnx, edges_queries_list, edge_rel_list, db_server) + + cnx.close() + logging.info(f"\nTotal execution time: {time.time()-total_time:.3f} seconds") + logging.info("\nEdge file written to " + str(output_dir / Path(OUTPUT_FILE_NAME))) + except Exception as e: + logging.error(e) + logging.info(f"\nTotal execution time: {time.time()-total_time:.3f} seconds") + + +if __name__ == "__main__": + main() diff --git a/test/db2graph/test_postgres.py b/test/db2graph/test_postgres.py new file mode 100644 index 00000000..c01c8a31 --- /dev/null +++ b/test/db2graph/test_postgres.py @@ -0,0 +1,296 @@ +import os +import random +from pathlib import Path + +import psycopg2 + +from src.python.tools.db2graph.marius_db2graph import connect_to_db, post_processing + + +class TestConnector: + database = "postgres" + user = "postgres" + password = "postgres" + host = os.environ.get("POSTGRES_HOST") + port = os.environ.get("POSTGRES_PORT") + + customer_names = ["sofia", "lukas", "rajesh", "daiyu", "hina", "lorenzo", "donghai", "shuchang", "johnny"] + country_names = ["spain", "germany", "india", "china", "japan", "italy", "china", "china", "usa"] + item_names = [ + "fenugreek", + "soy sauce", + "oregano", + "tomato", + "cumin", + "soy sauce", + "eggs", + "onions", + "onions", + "wasabi", + "rice", + "chicken breast", + "salmon", + "sourdough bread", + "meatballs", + "root beer", + "croissant", + "taco sauce", + ] + + def 
fill_db(self): + """ + Filling the database with data for testing things + """ + conn = psycopg2.connect( + database=self.database, user=self.user, password=self.password, host=self.host, port=self.port + ) + cur = conn.cursor() + + # DROP TABLE IF EXISTS + cur.execute("DROP TABLE IF EXISTS ORDERS;") + cur.execute("DROP TABLE IF EXISTS CUSTOMERS;") + conn.commit() + + # Create two tables - First Customers and second Orders + cur.execute( + """CREATE TABLE CUSTOMERS + (ID INT PRIMARY KEY NOT NULL, + CUSTOMERNAME TEXT NOT NULL, + COUNTRY TEXT NOT NULL, + PHONE VARCHAR(10) NOT NULL);""" + ) + conn.commit() + cur.execute( + """CREATE TABLE ORDERS + (ID INT PRIMARY KEY NOT NULL, + CUSTOMERID INT NOT NULL, + AMOUNT INT NOT NULL, + ITEM TEXT NOT NULL, + CONSTRAINT fk_customer + FOREIGN KEY(CUSTOMERID) + REFERENCES CUSTOMERS(ID));""" + ) + conn.commit() + + # Insert some data + # Inserting Customers + cur.execute( + f"INSERT INTO CUSTOMERS (ID,CUSTOMERNAME,COUNTRY,PHONE) VALUES (1, '{self.customer_names[0]}'," + f" '{self.country_names[0]}', '6081237654')" + ) + cur.execute( + f"INSERT INTO CUSTOMERS (ID,CUSTOMERNAME,COUNTRY,PHONE) VALUES (2, '{self.customer_names[1]}'," + f" '{self.country_names[1]}', '6721576540')" + ) + cur.execute( + f"INSERT INTO CUSTOMERS (ID,CUSTOMERNAME,COUNTRY,PHONE) VALUES (3, '{self.customer_names[2]}'," + f" '{self.country_names[2]}', '5511234567')" + ) + cur.execute( + f"INSERT INTO CUSTOMERS (ID,CUSTOMERNAME,COUNTRY,PHONE) VALUES (4, '{self.customer_names[3]}'," + f" '{self.country_names[3]}', '3211248173')" + ) + cur.execute( + f"INSERT INTO CUSTOMERS (ID,CUSTOMERNAME,COUNTRY,PHONE) VALUES (5, '{self.customer_names[4]}'," + f" '{self.country_names[4]}', '6667890001')" + ) + cur.execute( + f"INSERT INTO CUSTOMERS (ID,CUSTOMERNAME,COUNTRY,PHONE) VALUES (6, '{self.customer_names[5]}'," + f" '{self.country_names[5]}', '6260001111')" + ) + cur.execute( + f"INSERT INTO CUSTOMERS (ID,CUSTOMERNAME,COUNTRY,PHONE) VALUES (7, '{self.customer_names[6]}'," + f" '{self.country_names[6]}', '7874561234')" + ) + cur.execute( + f"INSERT INTO CUSTOMERS (ID,CUSTOMERNAME,COUNTRY,PHONE) VALUES (8, '{self.customer_names[7]}'," + f" '{self.country_names[7]}', '4041015059')" + ) + cur.execute( + f"INSERT INTO CUSTOMERS (ID,CUSTOMERNAME,COUNTRY,PHONE) VALUES (9, '{self.customer_names[8]}'," + f" '{self.country_names[8]}', '5647525398')" + ) + conn.commit() + + # Inserting Orders + cur.execute( + f"INSERT INTO ORDERS (ID,CUSTOMERID,AMOUNT,ITEM) VALUES (1, 3, 5, '{self.item_names[0]}')" + ) + cur.execute( + f"INSERT INTO ORDERS (ID,CUSTOMERID,AMOUNT,ITEM) VALUES (2, 7, 7, '{self.item_names[1]}')" + ) + cur.execute( + f"INSERT INTO ORDERS (ID,CUSTOMERID,AMOUNT,ITEM) VALUES (3, 6, 2, '{self.item_names[2]}')" + ) + cur.execute( + f"INSERT INTO ORDERS (ID,CUSTOMERID,AMOUNT,ITEM) VALUES (4, 1, 3, '{self.item_names[3]}')" + ) + cur.execute( + f"INSERT INTO ORDERS (ID,CUSTOMERID,AMOUNT,ITEM) VALUES (5, 3, 5, '{self.item_names[4]}')" + ) + cur.execute( + f"INSERT INTO ORDERS (ID,CUSTOMERID,AMOUNT,ITEM) VALUES (6, 5, 7, '{self.item_names[5]}')" + ) + cur.execute( + f"INSERT INTO ORDERS (ID,CUSTOMERID,AMOUNT,ITEM) VALUES (7, 2, 1, '{self.item_names[6]}')" + ) + cur.execute( + f"INSERT INTO ORDERS (ID,CUSTOMERID,AMOUNT,ITEM) VALUES (8, 9, 3, '{self.item_names[7]}')" + ) + cur.execute( + f"INSERT INTO ORDERS (ID,CUSTOMERID,AMOUNT,ITEM) VALUES (9, 4, 3, '{self.item_names[8]}')" + ) + cur.execute( + f"INSERT INTO ORDERS (ID,CUSTOMERID,AMOUNT,ITEM) VALUES (10, 5, 15, '{self.item_names[9]}')" + ) + 
cur.execute( + f"INSERT INTO ORDERS (ID,CUSTOMERID,AMOUNT,ITEM) VALUES (11, 8, 9, '{self.item_names[10]}')" + ) + cur.execute( + f"INSERT INTO ORDERS (ID,CUSTOMERID,AMOUNT,ITEM) VALUES (12, 4, 12, '{self.item_names[11]}')" + ) + cur.execute( + f"INSERT INTO ORDERS (ID,CUSTOMERID,AMOUNT,ITEM) VALUES (13, 5, 20, '{self.item_names[12]}')" + ) + cur.execute( + f"INSERT INTO ORDERS (ID,CUSTOMERID,AMOUNT,ITEM) VALUES (14, 6, 11, '{self.item_names[13]}')" + ) + cur.execute( + f"INSERT INTO ORDERS (ID,CUSTOMERID,AMOUNT,ITEM) VALUES (15, 2, 8, '{self.item_names[14]}')" + ) + cur.execute( + f"INSERT INTO ORDERS (ID,CUSTOMERID,AMOUNT,ITEM) VALUES (16, 9, 2, '{self.item_names[15]}')" + ) + cur.execute( + f"INSERT INTO ORDERS (ID,CUSTOMERID,AMOUNT,ITEM) VALUES (17, 2, 6, '{self.item_names[16]}')" + ) + cur.execute( + f"INSERT INTO ORDERS (ID,CUSTOMERID,AMOUNT,ITEM) VALUES (18, 1, 4, '{self.item_names[17]}')" + ) + conn.commit() + + conn.close() + return + + def test_connect_to_db(self): + """ + Basic connecter to db test. Just checking if connection established + and corrected values are fetched + """ + # Filling database with data for testing + conn = psycopg2.connect( + database=self.database, user=self.user, password=self.password, host=self.host, port=self.port + ) + + # Create table + cur = conn.cursor() + cur.execute( + """CREATE TABLE COMPANY + (ID INT PRIMARY KEY NOT NULL, + NAME TEXT NOT NULL, + AGE INT NOT NULL);""" + ) + conn.commit() + + # Insert some data + num_data_to_insert = 5 + self.name = [] + self.age = [] + for i in range(num_data_to_insert): + self.name.append("name" + str(i)) + self.age.append(random.randint(1, 100)) + + for i in range(num_data_to_insert): + cur.execute( + f"INSERT INTO COMPANY (ID,NAME,AGE) VALUES ({i}, '{self.name[i]}', {self.age[i]})" + ) + conn.commit() + conn.close() + + # Setting the connect function to test + conn = connect_to_db( + db_server="postgre-sql", + db_name=self.database, + db_user=self.user, + db_password=self.password, + db_host=self.host, + ) + cur = conn.cursor() + cur.execute("SELECT id, name, age from COMPANY") + rows = cur.fetchall() + index = 0 + for row in rows: + assert row[0] == index + assert row[1] == self.name[index] + assert row[2] == self.age[index] + index += 1 + conn.close() + + def test_edges_entity_entity(self): + """ + Testing edges_entity_entity type of queries which generate edges + """ + self.fill_db() # Filling database with data for testing + + # Getting all the inputs for the function + output_dir = Path("output_dir_edges_entity_entity/") + output_dir.mkdir(parents=True, exist_ok=True) + + db_server = "postgre-sql" + + conn = psycopg2.connect( + database=self.database, user=self.user, password=self.password, host=self.host, port=self.port + ) + + edge_entity_entity_queries_list = [] + edge_entity_entity_queries_list.append( + "SELECT customers.customername, customers.country FROM customers ORDER BY customers.customername ASC;" + ) + edge_entity_entity_queries_list.append( + "SELECT orders.item, customers.country FROM orders, customers WHERE orders.customerid = customers.id ORDER" + " BY orders.item ASC;" + ) + edge_entity_entity_rel_list = ["lives_in", "ordered_by_people_from_country"] + + # Testing the function + post_processing(output_dir, conn, edge_entity_entity_queries_list, edge_entity_entity_rel_list, db_server) + + # Asserting the correctionness of the output + # Predefined correct output for the input queries + correct_output = [] + + # expected outputs for query 1 + 
correct_output.append("customers_customername_daiyu\tlives_in\tcustomers_country_china\n") + correct_output.append("customers_customername_donghai\tlives_in\tcustomers_country_china\n") + correct_output.append("customers_customername_hina\tlives_in\tcustomers_country_japan\n") + correct_output.append("customers_customername_johnny\tlives_in\tcustomers_country_usa\n") + correct_output.append("customers_customername_lorenzo\tlives_in\tcustomers_country_italy\n") + correct_output.append("customers_customername_lukas\tlives_in\tcustomers_country_germany\n") + correct_output.append("customers_customername_rajesh\tlives_in\tcustomers_country_india\n") + correct_output.append("customers_customername_shuchang\tlives_in\tcustomers_country_china\n") + correct_output.append("customers_customername_sofia\tlives_in\tcustomers_country_spain\n") + + # expected outputs for query 2 + correct_output.append("orders_item_chicken breast\tordered_by_people_from_country\tcustomers_country_china\n") + correct_output.append("orders_item_croissant\tordered_by_people_from_country\tcustomers_country_germany\n") + correct_output.append("orders_item_cumin\tordered_by_people_from_country\tcustomers_country_india\n") + correct_output.append("orders_item_eggs\tordered_by_people_from_country\tcustomers_country_germany\n") + correct_output.append("orders_item_fenugreek\tordered_by_people_from_country\tcustomers_country_india\n") + correct_output.append("orders_item_meatballs\tordered_by_people_from_country\tcustomers_country_germany\n") + correct_output.append("orders_item_onions\tordered_by_people_from_country\tcustomers_country_usa\n") + correct_output.append("orders_item_onions\tordered_by_people_from_country\tcustomers_country_china\n") + correct_output.append("orders_item_oregano\tordered_by_people_from_country\tcustomers_country_italy\n") + correct_output.append("orders_item_rice\tordered_by_people_from_country\tcustomers_country_china\n") + correct_output.append("orders_item_root beer\tordered_by_people_from_country\tcustomers_country_usa\n") + correct_output.append("orders_item_salmon\tordered_by_people_from_country\tcustomers_country_japan\n") + correct_output.append("orders_item_sourdough bread\tordered_by_people_from_country\tcustomers_country_italy\n") + correct_output.append("orders_item_soy sauce\tordered_by_people_from_country\tcustomers_country_japan\n") + correct_output.append("orders_item_soy sauce\tordered_by_people_from_country\tcustomers_country_china\n") + correct_output.append("orders_item_taco sauce\tordered_by_people_from_country\tcustomers_country_spain\n") + correct_output.append("orders_item_tomato\tordered_by_people_from_country\tcustomers_country_spain\n") + correct_output.append("orders_item_wasabi\tordered_by_people_from_country\tcustomers_country_japan\n") + with open(output_dir / "edges.txt", "r") as file: + for line in file: + assert line in correct_output + + return
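+
+# A minimal sketch (assumption) of running this test locally against a throwaway
+# PostgreSQL container, mirroring the GitHub Actions workflow above; the container
+# name "db2graph_pg" is arbitrary:
+#
+#   docker run -d --name db2graph_pg -e POSTGRES_PASSWORD=postgres -p 5432:5432 postgres
+#   POSTGRES_HOST=localhost POSTGRES_PORT=5432 MARIUS_NO_BINDINGS=1 pytest -s test/db2graph/test_postgres.py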