diff --git a/STP_Machathon_2_0_Chatbot_Message_Classification.ipynb b/STP_Machathon_2_0_Chatbot_Message_Classification.ipynb
new file mode 100644
index 0000000..8a91283
--- /dev/null
+++ b/STP_Machathon_2_0_Chatbot_Message_Classification.ipynb
@@ -0,0 +1,757 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "name": "STP-Machathon-2.0-Chatbot_Message_Classification",
+ "provenance": []
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# Chatbot Message Classification Project\n",
+ "\n",
+ "## Table of Content\n",
+ "
"
+ ],
+ "metadata": {
+ "id": "DXj_r-k7OIJh"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "\n",
+ "## Installing the nescessery Libraries."
+ ],
+ "metadata": {
+ "id": "wCh8rZyEPmhO"
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "e-X-__pXLpzP",
+ "outputId": "d107cf0d-9667-407b-a7b2-ce0157cc84a6"
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Collecting pyarabic\n",
+ " Downloading PyArabic-0.6.14-py3-none-any.whl (126 kB)\n",
+ "\u001b[?25l\r\u001b[K |██▋ | 10 kB 26.8 MB/s eta 0:00:01\r\u001b[K |█████▏ | 20 kB 12.3 MB/s eta 0:00:01\r\u001b[K |███████▉ | 30 kB 9.7 MB/s eta 0:00:01\r\u001b[K |██████████▍ | 40 kB 8.7 MB/s eta 0:00:01\r\u001b[K |█████████████ | 51 kB 5.0 MB/s eta 0:00:01\r\u001b[K |███████████████▋ | 61 kB 5.2 MB/s eta 0:00:01\r\u001b[K |██████████████████▏ | 71 kB 5.3 MB/s eta 0:00:01\r\u001b[K |████████████████████▊ | 81 kB 5.9 MB/s eta 0:00:01\r\u001b[K |███████████████████████▍ | 92 kB 4.6 MB/s eta 0:00:01\r\u001b[K |██████████████████████████ | 102 kB 5.0 MB/s eta 0:00:01\r\u001b[K |████████████████████████████▌ | 112 kB 5.0 MB/s eta 0:00:01\r\u001b[K |███████████████████████████████▏| 122 kB 5.0 MB/s eta 0:00:01\r\u001b[K |████████████████████████████████| 126 kB 5.0 MB/s \n",
+ "\u001b[?25hRequirement already satisfied: six>=1.14.0 in /usr/local/lib/python3.7/dist-packages (from pyarabic) (1.15.0)\n",
+ "Installing collected packages: pyarabic\n",
+ "Successfully installed pyarabic-0.6.14\n"
+ ]
+ }
+ ],
+ "source": [
+ "pip install pyarabic"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "pip install qalsadi"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "IBExcMY0LvBs",
+ "outputId": "52ca28ad-8872-4d2d-835b-59f6857dfe6b"
+ },
+ "execution_count": 2,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Collecting qalsadi\n",
+ " Downloading qalsadi-0.4.4-py3-none-any.whl (257 kB)\n",
+ "\u001b[?25l\r\u001b[K |█▎ | 10 kB 21.7 MB/s eta 0:00:01\r\u001b[K |██▌ | 20 kB 14.7 MB/s eta 0:00:01\r\u001b[K |███▉ | 30 kB 10.7 MB/s eta 0:00:01\r\u001b[K |█████ | 40 kB 9.2 MB/s eta 0:00:01\r\u001b[K |██████▍ | 51 kB 5.1 MB/s eta 0:00:01\r\u001b[K |███████▋ | 61 kB 5.3 MB/s eta 0:00:01\r\u001b[K |█████████ | 71 kB 5.6 MB/s eta 0:00:01\r\u001b[K |██████████▏ | 81 kB 6.2 MB/s eta 0:00:01\r\u001b[K |███████████▌ | 92 kB 4.8 MB/s eta 0:00:01\r\u001b[K |████████████▊ | 102 kB 4.8 MB/s eta 0:00:01\r\u001b[K |██████████████ | 112 kB 4.8 MB/s eta 0:00:01\r\u001b[K |███████████████▎ | 122 kB 4.8 MB/s eta 0:00:01\r\u001b[K |████████████████▋ | 133 kB 4.8 MB/s eta 0:00:01\r\u001b[K |█████████████████▉ | 143 kB 4.8 MB/s eta 0:00:01\r\u001b[K |███████████████████▏ | 153 kB 4.8 MB/s eta 0:00:01\r\u001b[K |████████████████████▍ | 163 kB 4.8 MB/s eta 0:00:01\r\u001b[K |█████████████████████▋ | 174 kB 4.8 MB/s eta 0:00:01\r\u001b[K |███████████████████████ | 184 kB 4.8 MB/s eta 0:00:01\r\u001b[K |████████████████████████▏ | 194 kB 4.8 MB/s eta 0:00:01\r\u001b[K |█████████████████████████▌ | 204 kB 4.8 MB/s eta 0:00:01\r\u001b[K |██████████████████████████▊ | 215 kB 4.8 MB/s eta 0:00:01\r\u001b[K |████████████████████████████ | 225 kB 4.8 MB/s eta 0:00:01\r\u001b[K |█████████████████████████████▎ | 235 kB 4.8 MB/s eta 0:00:01\r\u001b[K |██████████████████████████████▋ | 245 kB 4.8 MB/s eta 0:00:01\r\u001b[K |███████████████████████████████▉| 256 kB 4.8 MB/s eta 0:00:01\r\u001b[K |████████████████████████████████| 257 kB 4.8 MB/s \n",
+ "\u001b[?25hRequirement already satisfied: pyarabic>=0.6.7 in /usr/local/lib/python3.7/dist-packages (from qalsadi) (0.6.14)\n",
+ "Collecting pickledb>=0.9.2\n",
+ " Downloading pickleDB-0.9.2.tar.gz (3.7 kB)\n",
+ "Requirement already satisfied: six>=1.10.0 in /usr/local/lib/python3.7/dist-packages (from qalsadi) (1.15.0)\n",
+ "Collecting naftawayh>=0.3\n",
+ " Downloading Naftawayh-0.4-py3-none-any.whl (332 kB)\n",
+ "\u001b[K |████████████████████████████████| 332 kB 40.0 MB/s \n",
+ "\u001b[?25hCollecting arramooz-pysqlite>=0.3\n",
+ " Downloading arramooz_pysqlite-0.3-py3-none-any.whl (9.2 MB)\n",
+ "\u001b[K |████████████████████████████████| 9.2 MB 31.7 MB/s \n",
+ "\u001b[?25hCollecting alyahmor>=0.1\n",
+ " Downloading alyahmor-0.1.2-py3-none-any.whl (50 kB)\n",
+ "\u001b[K |████████████████████████████████| 50 kB 6.4 MB/s \n",
+ "\u001b[?25hCollecting libqutrub>=1.2.3\n",
+ " Downloading libqutrub-1.2.4.1-py3-none-any.whl (138 kB)\n",
+ "\u001b[K |████████████████████████████████| 138 kB 67.5 MB/s \n",
+ "\u001b[?25hRequirement already satisfied: future>=0.16.0 in /usr/local/lib/python3.7/dist-packages (from qalsadi) (0.16.0)\n",
+ "Collecting tashaphyne>=0.3.4.1\n",
+ " Downloading Tashaphyne-0.3.6-py3-none-any.whl (251 kB)\n",
+ "\u001b[K |████████████████████████████████| 251 kB 68.9 MB/s \n",
+ "\u001b[?25hCollecting Arabic-Stopwords>=0.3\n",
+ " Downloading Arabic_Stopwords-0.3-py3-none-any.whl (353 kB)\n",
+ "\u001b[K |████████████████████████████████| 353 kB 71.3 MB/s \n",
+ "\u001b[?25hBuilding wheels for collected packages: pickledb\n",
+ " Building wheel for pickledb (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
+ " Created wheel for pickledb: filename=pickleDB-0.9.2-py3-none-any.whl size=4271 sha256=ef7ab5a748ae0aa8aa5b221243f16bb5ee3632ead4de0e192193d91db5e389cf\n",
+ " Stored in directory: /root/.cache/pip/wheels/08/34/42/9a7f94099208ce3d32638d98586a5a50f821db2fc75a3bdaae\n",
+ "Successfully built pickledb\n",
+ "Installing collected packages: tashaphyne, libqutrub, pickledb, naftawayh, arramooz-pysqlite, Arabic-Stopwords, alyahmor, qalsadi\n",
+ "Successfully installed Arabic-Stopwords-0.3 alyahmor-0.1.2 arramooz-pysqlite-0.3 libqutrub-1.2.4.1 naftawayh-0.4 pickledb-0.9.2 qalsadi-0.4.4 tashaphyne-0.3.6\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "\n",
+ "## Data Wrangling"
+ ],
+ "metadata": {
+ "id": "ha_0uxzuPxGe"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import pyarabic.araby as araby\n",
+ "\n",
+ "import random\n",
+ "import json\n",
+ "import pickle\n",
+ "import numpy as np\n",
+ "import os\n",
+ "\n",
+ "import nltk\n",
+ "from nltk.stem import WordNetLemmatizer\n",
+ "\n",
+ "from sklearn.model_selection import train_test_split\n",
+ "from tensorflow.keras.models import Sequential\n",
+ "from tensorflow.keras.layers import Dense, Dropout, Input, Bidirectional, GRU\n",
+ "from tensorflow.keras.optimizers import SGD\n",
+ "from tensorflow.keras.models import load_model"
+ ],
+ "metadata": {
+ "id": "Ifld2HONLwcV"
+ },
+ "execution_count": 3,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import pandas as pd\n",
+ "import qalsadi.lemmatizer\n",
+ "\n",
+ "lemmer = qalsadi.lemmatizer.Lemmatizer()"
+ ],
+ "metadata": {
+ "id": "ckdxFF5pLymk"
+ },
+ "execution_count": 4,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "df = pd.read_csv(r\"train_ara.csv\")\n",
+ "df.head()"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 206
+ },
+ "id": "aAi8Kon3L65e",
+ "outputId": "584aaab6-dcee-4b8b-ceeb-e0d2ebdab0f7"
+ },
+ "execution_count": 6,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ "
\n",
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " text | \n",
+ " intent | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " كم عدد مستشفيات العزل فى مصر وما هى اماكنها | \n",
+ " business location | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " ومخاصمك | \n",
+ " nothing identified | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " متي ينتهي كورونا؟ | \n",
+ " the evolution of the virus | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " انا اسف | \n",
+ " nothing identified | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " كام عدد الوفيات النهارده | \n",
+ " infected cases | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
\n",
+ "
\n",
+ " \n",
+ " \n",
+ "\n",
+ " \n",
+ "
\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ " text intent\n",
+ "0 كم عدد مستشفيات العزل فى مصر وما هى اماكنها business location\n",
+ "1 ومخاصمك nothing identified\n",
+ "2 متي ينتهي كورونا؟ the evolution of the virus\n",
+ "3 انا اسف nothing identified\n",
+ "4 كام عدد الوفيات النهارده infected cases"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 6
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "path = r\"list.txt\"\n",
+ "stop_words = []\n",
+ "with open(path, \"r\", encoding=\"utf-8\", errors=\"ignore\") as myfile:\n",
+ " stop_words = myfile.readlines()\n",
+ "stopWords = [word.strip() for word in stop_words]"
+ ],
+ "metadata": {
+ "id": "lSxRQc2BL8Hj"
+ },
+ "execution_count": 8,
+ "outputs": []
+ },
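+ {
+ "cell_type": "markdown",
+ "source": [
+ "A quick sanity check, not part of the original run, that the stop-word list loaded as expected."
+ ],
+ "metadata": {}
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Illustrative check: number of stop words loaded and a small sample.\n",
+ "print(len(stopWords))\n",
+ "print(stopWords[:5])"
+ ],
+ "metadata": {},
+ "execution_count": null,
+ "outputs": []
+ },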
+ {
+ "cell_type": "markdown",
+ "source": [
+ "\n",
+ "## Pre-processing"
+ ],
+ "metadata": {
+ "id": "bQk0aYxAP7tZ"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "words = []\n",
+ "classes = []\n",
+ "documents = []\n",
+ "ignore_letters = ['!', '?', ',', '.']\n",
+ "\n",
+ "for i in range (df['text'].count()):\n",
+ " word = araby.tokenize(df['text'][i])\n",
+ " words.extend(word)\n",
+ " documents.append((word, df['intent'][i]))\n",
+ " if df['intent'][i] not in classes:\n",
+ " classes.append(df['intent'][i])"
+ ],
+ "metadata": {
+ "id": "BN5qfv0fMVTc"
+ },
+ "execution_count": 9,
+ "outputs": []
+ },
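+ {
+ "cell_type": "markdown",
+ "source": [
+ "To see what the tokenizer and lemmatizer actually return, the next cell is a minimal, illustrative check on one of the training messages shown in the preview above; it is not part of the original run."
+ ],
+ "metadata": {}
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Illustrative check (not part of the original run) on a message from the training preview above.\n",
+ "sample = 'كام عدد الوفيات النهارده'\n",
+ "print(araby.tokenize(sample))\n",
+ "print([lemmer.lemmatize(t) for t in araby.tokenize(sample)])"
+ ],
+ "metadata": {},
+ "execution_count": null,
+ "outputs": []
+ },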
+ {
+ "cell_type": "code",
+ "source": [
+ "words = [lemmer.lemmatize(w.lower()) for w in words if w not in ignore_letters and w not in stopWords]\n",
+ "words = sorted(list(set(words)))\n",
+ "\n",
+ "classes = sorted(list(set(classes)))"
+ ],
+ "metadata": {
+ "id": "ouWVaKiUMjlm"
+ },
+ "execution_count": 10,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "training = []\n",
+ "output_empty = [0] * len(classes)\n",
+ "\n",
+ "for doc in documents:\n",
+ " bag = []\n",
+ " word_patterns = doc[0]\n",
+ " word_patterns = [lemmer.lemmatize(word.lower()) for word in word_patterns]\n",
+ " for word in words:\n",
+ " bag.append(1) if word in word_patterns else bag.append(0)\n",
+ "\n",
+ " output_row = list(output_empty)\n",
+ " output_row[classes.index(doc[1])] = 1\n",
+ " training.append([bag, output_row])\n",
+ "\n",
+ "random.shuffle(training)\n",
+ "training = np.array(training)\n",
+ "\n",
+ "train_x = list(training[:, 0])\n",
+ "train_y = list(training[:, 1])"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "GKh8VYK3Mshy",
+ "outputId": "292f9072-0347-48bf-b39c-20894a51c604"
+ },
+ "execution_count": 11,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:16: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray\n",
+ " app.launch_new_instance()\n"
+ ]
+ }
+ ]
+ },
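+ {
+ "cell_type": "markdown",
+ "source": [
+ "Each training row is a bag-of-words vector over the vocabulary paired with a one-hot intent vector. The next cell is an illustrative shape check, not part of the original run, to confirm the dimensions line up."
+ ],
+ "metadata": {}
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Illustrative shape check: rows = messages, columns = vocabulary size / number of intents.\n",
+ "print(len(train_x), len(train_x[0]))\n",
+ "print(len(train_y), len(train_y[0]))\n",
+ "print(len(classes), 'intent classes')"
+ ],
+ "metadata": {},
+ "execution_count": null,
+ "outputs": []
+ },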
+ {
+ "cell_type": "code",
+ "source": [
+ "X_train, X_test, y_train, y_test = train_test_split(train_x, train_y, test_size = 0.2, random_state = 42, shuffle=True)"
+ ],
+ "metadata": {
+ "id": "vF-EfiW2MxwQ"
+ },
+ "execution_count": 12,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "\n",
+ "## Deep Learning Model"
+ ],
+ "metadata": {
+ "id": "OFqFbv4oQBG9"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "model = Sequential()\n",
+ "model.add(Dense(128, input_shape=(len(train_x[0]),), activation='relu'))\n",
+ "model.add(Dropout(0.5))\n",
+ "model.add(Dense(64, activation='relu'))\n",
+ "model.add(Dropout(0.5))\n",
+ "model.add(Dense(len(train_y[0]), activation='softmax'))\n",
+ "\n",
+ "sgd = SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)\n",
+ "model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])"
+ ],
+ "metadata": {
+ "id": "ZBPU_wSYNN3a",
+ "outputId": "28fca174-6f84-42d6-8177-0c2de34769b1",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ }
+ },
+ "execution_count": 13,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "/usr/local/lib/python3.7/dist-packages/keras/optimizer_v2/gradient_descent.py:102: UserWarning: The `lr` argument is deprecated, use `learning_rate` instead.\n",
+ " super(SGD, self).__init__(name, **kwargs)\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "model.summary()"
+ ],
+ "metadata": {
+ "id": "WcyBotJFNRTR",
+ "outputId": "3b5c3fda-152e-487e-9dae-d4d5a0bbd8ff",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ }
+ },
+ "execution_count": 14,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Model: \"sequential\"\n",
+ "_________________________________________________________________\n",
+ " Layer (type) Output Shape Param # \n",
+ "=================================================================\n",
+ " dense (Dense) (None, 128) 125440 \n",
+ " \n",
+ " dropout (Dropout) (None, 128) 0 \n",
+ " \n",
+ " dense_1 (Dense) (None, 64) 8256 \n",
+ " \n",
+ " dropout_1 (Dropout) (None, 64) 0 \n",
+ " \n",
+ " dense_2 (Dense) (None, 40) 2600 \n",
+ " \n",
+ "=================================================================\n",
+ "Total params: 136,296\n",
+ "Trainable params: 136,296\n",
+ "Non-trainable params: 0\n",
+ "_________________________________________________________________\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "\n",
+ "## Model Training "
+ ],
+ "metadata": {
+ "id": "jqTQVaeFQGIS"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "model.fit(np.array(X_train), np.array(y_train), epochs=3, batch_size=5, verbose=2)"
+ ],
+ "metadata": {
+ "id": "5HIehtvmNTe8",
+ "outputId": "e4e319f4-dc82-475c-a451-ceb3adba85c1",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ }
+ },
+ "execution_count": 15,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Epoch 1/3\n",
+ "123903/123903 - 263s - loss: 0.1467 - accuracy: 0.9607 - 263s/epoch - 2ms/step\n",
+ "Epoch 2/3\n",
+ "123903/123903 - 236s - loss: 0.1875 - accuracy: 0.9638 - 236s/epoch - 2ms/step\n",
+ "Epoch 3/3\n",
+ "123903/123903 - 251s - loss: 0.2750 - accuracy: 0.9506 - 251s/epoch - 2ms/step\n"
+ ]
+ },
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "execution_count": 15
+ }
+ ]
+ },
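+ {
+ "cell_type": "markdown",
+ "source": [
+ "An optional variant, not used in the run above: training with the held-out split as validation data and early stopping. The callback settings and hyperparameters below are illustrative assumptions, not tuned values."
+ ],
+ "metadata": {}
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from tensorflow.keras.callbacks import EarlyStopping\n",
+ "\n",
+ "# Optional sketch (assumed settings): validate on the held-out split and stop\n",
+ "# once val_loss stops improving, keeping the best weights.\n",
+ "early_stop = EarlyStopping(monitor='val_loss', patience=2, restore_best_weights=True)\n",
+ "history = model.fit(np.array(X_train), np.array(y_train),\n",
+ "                    validation_data=(np.array(X_test), np.array(y_test)),\n",
+ "                    epochs=10, batch_size=32, verbose=2,\n",
+ "                    callbacks=[early_stop])"
+ ],
+ "metadata": {},
+ "execution_count": null,
+ "outputs": []
+ },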
+ {
+ "cell_type": "markdown",
+ "source": [
+ "\n",
+ "## Pre-processing for the test data"
+ ],
+ "metadata": {
+ "id": "4H6nn054NpiJ"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "def _clean_up_sentence(sentence):\n",
+ " sentence_words = araby.tokenize(sentence)\n",
+ " sentence_words = [lemmer.lemmatize(word.lower()) for word in sentence_words]\n",
+ " return sentence_words"
+ ],
+ "metadata": {
+ "id": "337vHiBdNYHw"
+ },
+ "execution_count": 16,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "def _bag_of_words(sentence, words):\n",
+ " sentence_words = _clean_up_sentence(sentence)\n",
+ " bag = [0] * len(words)\n",
+ " for s in sentence_words:\n",
+ " for i, word in enumerate(words):\n",
+ " if word == s:\n",
+ " bag[i] = 1\n",
+ " return np.array(bag)"
+ ],
+ "metadata": {
+ "id": "Mm25vId9Njzv"
+ },
+ "execution_count": 17,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "def _predict_class(sentence):\n",
+ " p = _bag_of_words(sentence, words)\n",
+ " res = model.predict(np.array([p]))[0]\n",
+ " ERROR_THRESHOLD = 0.1\n",
+ " results = [[i, r] for i, r in enumerate(res) if r > ERROR_THRESHOLD]\n",
+ " \n",
+ " results.sort(key=lambda x: x[1], reverse=True)\n",
+ " return_list = \"\"\n",
+ " for r in results:\n",
+ " return_list = classes[r[0]]\n",
+ " return return_list"
+ ],
+ "metadata": {
+ "id": "ZEBa_--2Nk3s"
+ },
+ "execution_count": 18,
+ "outputs": []
+ },
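+ {
+ "cell_type": "markdown",
+ "source": [
+ "A quick usage example of the helper on a single message; the sentence is taken from the training preview above and the cell is illustrative only."
+ ],
+ "metadata": {}
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Illustrative single-message prediction using the helpers defined above.\n",
+ "print(_predict_class('متي ينتهي كورونا؟'))"
+ ],
+ "metadata": {},
+ "execution_count": null,
+ "outputs": []
+ },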
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### importing and predecting on the test data"
+ ],
+ "metadata": {
+ "id": "8hsd2pBHNuLV"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "dfTest = pd.read_csv(r\"test.csv\")\n",
+ "output = []\n",
+ "for i in range(dfTest['text'].count()):\n",
+ " x = _predict_class(dfTest['text'][i])\n",
+ " output.append(x)\n",
+ "output[:10]"
+ ],
+ "metadata": {
+ "id": "NhSrh5VENmB5",
+ "outputId": "132fc50c-aa43-4ea5-d7a8-3f3f81ab4ce3",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ }
+ },
+ "execution_count": 19,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "['infected cases',\n",
+ " 'infected cases',\n",
+ " 'yes',\n",
+ " 'treatment',\n",
+ " 'infected cases',\n",
+ " 'infected cases',\n",
+ " 'infected cases',\n",
+ " 'infected cases',\n",
+ " 'the evolution of the virus',\n",
+ " 'infected cases']"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 19
+ }
+ ]
+ },
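+ {
+ "cell_type": "markdown",
+ "source": [
+ "A sketch of writing the predictions out next to the test messages. The file name and column layout below are assumptions; the competition's required submission format is not shown in this notebook."
+ ],
+ "metadata": {}
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Hypothetical export of the predictions; file name and columns are assumptions.\n",
+ "submission = pd.DataFrame({'text': dfTest['text'], 'intent': output})\n",
+ "submission.to_csv('submission.csv', index=False)"
+ ],
+ "metadata": {},
+ "execution_count": null,
+ "outputs": []
+ },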
+ {
+ "cell_type": "markdown",
+ "source": [
+ "\n",
+ "## Evaluate on the splitted test data"
+ ],
+ "metadata": {
+ "id": "NU0EAHJRRmgK"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "model.evaluate(np.array(X_test), np.array(y_test))"
+ ],
+ "metadata": {
+ "id": "M3kgcaFfOBJF",
+ "outputId": "fc7f51a5-1244-4fa4-aef4-7399155f869f",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ }
+ },
+ "execution_count": 20,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "4840/4840 [==============================] - 16s 3ms/step - loss: 0.0494 - accuracy: 0.9862\n"
+ ]
+ },
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "[0.04943576455116272, 0.9862343072891235]"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 20
+ }
+ ]
+ },
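+ {
+ "cell_type": "markdown",
+ "source": [
+ "Since `pickle` and `load_model` are imported earlier, a natural follow-up is persisting the trained artifacts so the chatbot can reload them without retraining. The file names below are assumptions, and this cell is a sketch rather than part of the original run."
+ ],
+ "metadata": {}
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Sketch of persisting the trained artifacts; file names are assumptions.\n",
+ "model.save('chatbot_model.h5')\n",
+ "with open('words.pkl', 'wb') as f:\n",
+ "    pickle.dump(words, f)\n",
+ "with open('classes.pkl', 'wb') as f:\n",
+ "    pickle.dump(classes, f)\n",
+ "\n",
+ "# The artifacts could later be reloaded with load_model('chatbot_model.h5') and pickle.load(...)."
+ ],
+ "metadata": {},
+ "execution_count": null,
+ "outputs": []
+ },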
+ {
+ "cell_type": "code",
+ "source": [
+ ""
+ ],
+ "metadata": {
+ "id": "McfHdmOeRZYu"
+ },
+ "execution_count": null,
+ "outputs": []
+ }
+ ]
+}
\ No newline at end of file