{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Imports: standard library first, then third-party packages\n",
    "import json\n",
    "import pickle\n",
    "import re\n",
    "import string\n",
    "from collections import Counter\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import nltk\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from nltk.corpus import stopwords\n",
    "from nltk.tokenize import word_tokenize\n",
    "from tensorflow.keras.layers import Dense, Embedding, Input, LSTM\n",
    "from tensorflow.keras.models import Model\n",
    "from tensorflow.keras.preprocessing.sequence import pad_sequences\n",
    "from tensorflow.keras.preprocessing.text import Tokenizer\n",
    "\n",
    "from Sastrawi.Stemmer.StemmerFactory import StemmerFactory\n",
    "from sklearn.metrics import accuracy_score, classification_report, precision_score, recall_score\n",
    "from sklearn.model_selection import train_test_split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configuration\n",
    "SEED = 42\n",
    "\n",
    "# Seeds Python's `random`, NumPy and TensorFlow so the data split, weight\n",
    "# initialisation and batch shuffling are reproducible on Restart & Run All.\n",
    "tf.keras.utils.set_random_seed(SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 163,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[nltk_data] Downloading package stopwords to /home/akeon/nltk_data...\n",
      "[nltk_data] Package stopwords is already up-to-date!\n",
      "[nltk_data] Downloading package punkt to /home/akeon/nltk_data...\n",
      "[nltk_data] Package punkt is already up-to-date!\n",
      "[nltk_data] Downloading package punkt_tab to /home/akeon/nltk_data...\n",
      "[nltk_data] Package punkt_tab is already up-to-date!\n",
      "[nltk_data] Downloading package wordnet to /home/akeon/nltk_data...\n",
      "[nltk_data] Package wordnet is already up-to-date!\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 163,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# download NLTK assets (stopword lists and the punkt tokenizer used by word_tokenize)\n",
    "nltk.download(\"stopwords\")\n",
    "nltk.download(\"punkt\")\n",
    "nltk.download(\"punkt_tab\")\n",
    "nltk.download(\"wordnet\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 164,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total Context: 54\n",
      "Total Possibility Questions: 97\n",
      "Total Fill in the Blank Questions: 24\n",
      "Total Multiple Choice Questions: 29\n",
      "Total True/False Questions: 44\n"
     ]
    }
   ],
   "source": [
    "# Load the training dataset and summarise it\n",
    "file_path = \"dataset/training_dataset.json\"\n",
    "\n",
    "with open(file_path, \"r\", encoding=\"utf-8\") as file:\n",
    "    dataset = json.load(file)\n",
    "\n",
    "total_context = len(dataset)\n",
    "# \"question_posibility\" (sic) is the key name used by the dataset file\n",
    "total_question_posibility = sum(len(entry[\"question_posibility\"]) for entry in dataset)\n",
    "question_type_counts = Counter(\n",
    "    question[\"type\"] for entry in dataset for question in entry[\"question_posibility\"]\n",
    ")\n",
    "\n",
    "print(f\"Total Context: {total_context}\")\n",
    "print(f\"Total Possibility Questions: {total_question_posibility}\")\n",
    "print(f\"Total Fill in the Blank Questions: {question_type_counts.get('fill_in_the_blank', 0)}\")\n",
    "print(f\"Total Multiple Choice Questions: {question_type_counts.get('multiple_choice', 0)}\")\n",
    "print(f\"Total True/False Questions: {question_type_counts.get('true_false', 0)}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Data processing complete!\n",
      "Samples: 97\n"
     ]
    }
   ],
   "source": [
    "# Text preprocessing (Indonesian)\n",
    "stop_words = set(stopwords.words(\"indonesian\"))\n",
    "stemmer = StemmerFactory().create_stemmer()\n",
    "\n",
    "# Word -> normalised-form lookup table applied to every token\n",
    "with open(\"normalize_text/normalize.json\", \"r\", encoding=\"utf-8\") as file:\n",
    "    normalization_dict = json.load(file)\n",
    "\n",
    "\n",
    "def text_preprocessing(text):\n",
    "    \"\"\"Turn raw Indonesian text into a list of cleaned, stemmed tokens.\n",
    "\n",
    "    Steps: lower-case, strip ASCII punctuation, collapse whitespace, tokenize,\n",
    "    map each word through normalization_dict, drop stopwords, then stem with Sastrawi.\n",
    "    \"\"\"\n",
    "    text = text.lower()\n",
    "    text = text.translate(str.maketrans(\"\", \"\", string.punctuation))\n",
    "    text = re.sub(r\"\\s+\", \" \", text).strip()\n",
    "    tokens = word_tokenize(text)\n",
    "    tokens = [normalization_dict.get(word, word) for word in tokens]\n",
    "    # Remove stopwords *before* stemming: the NLTK list contains inflected surface\n",
    "    # forms, which can stop matching the list once they have been stemmed.\n",
    "    tokens = [word for word in tokens if word not in stop_words]\n",
    "    return [stemmer.stem(word) for word in tokens]\n",
    "\n",
    "\n",
    "# One sample per (context, question) pair; all lists below stay index-aligned\n",
    "contexts = []\n",
    "questions = []\n",
    "correct_answers = []\n",
    "wrong_answers = []\n",
    "question_types = []\n",
    "\n",
    "for entry in dataset:\n",
    "    # The context is shared by all of its questions, so preprocess it once\n",
    "    processed_context = text_preprocessing(entry[\"context\"])\n",
    "\n",
    "    for qa in entry[\"question_posibility\"]:\n",
    "        contexts.append(processed_context)\n",
    "        questions.append(text_preprocessing(qa[\"question\"]))\n",
    "        correct_answers.append(text_preprocessing(qa[\"answer\"]))\n",
    "        question_types.append(qa[\"type\"])\n",
    "\n",
    "        # Only multiple-choice items carry distractors; other types get an empty list\n",
    "        if qa[\"type\"] == \"multiple_choice\":\n",
    "            wrong_answers.append([opt for opt in qa[\"options\"] if opt != qa[\"answer\"]])\n",
    "        else:\n",
    "            wrong_answers.append([])\n",
    "\n",
    "# A visible OOV placeholder; an empty-string OOV token disappears when ids are decoded to text\n",
    "tokenizer = Tokenizer(oov_token=\"<OOV>\")\n",
    "tokenizer.fit_on_texts(contexts + questions + correct_answers)\n",
    "\n",
    "context_sequences = tokenizer.texts_to_sequences(contexts)\n",
    "question_sequences = tokenizer.texts_to_sequences(questions)\n",
    "answer_sequences = tokenizer.texts_to_sequences(correct_answers)\n",
    "\n",
    "# Every sequence is post-padded / post-truncated to MAX_LENGTH ids (0 = padding)\n",
    "MAX_LENGTH = 100\n",
    "context_padded = pad_sequences(context_sequences, maxlen=MAX_LENGTH, padding=\"post\", truncating=\"post\")\n",
    "question_padded = pad_sequences(question_sequences, maxlen=MAX_LENGTH, padding=\"post\", truncating=\"post\")\n",
    "answer_padded = pad_sequences(answer_sequences, maxlen=MAX_LENGTH, padding=\"post\", truncating=\"post\")\n",
    "\n",
    "question_type_dict = {\"fill_in_the_blank\": 0, \"true_false\": 1, \"multiple_choice\": 2}\n",
    "question_type_labels = np.array([question_type_dict[q_type] for q_type in question_types])\n",
    "\n",
    "assert len(question_type_labels) == context_padded.shape[0] == question_padded.shape[0] == answer_padded.shape[0]\n",
    "print(\"Data processing complete!\")\n",
    "print(\"Samples:\", context_padded.shape[0])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split roughly 80/10/10 into train/test/validation *by context*.\n",
    "# Each context has several questions, so a per-question split would put the same\n",
    "# context in both train and test and inflate the evaluation scores.\n",
    "\n",
    "# Identical padded context rows come from the same source context -> same group id\n",
    "_, context_group_ids = np.unique(context_padded, axis=0, return_inverse=True)\n",
    "context_group_ids = context_group_ids.ravel()\n",
    "\n",
    "train_groups, temp_groups = train_test_split(np.unique(context_group_ids), test_size=0.2, random_state=SEED)\n",
    "test_groups, val_groups = train_test_split(temp_groups, test_size=0.5, random_state=SEED)\n",
    "\n",
    "\n",
    "def take_groups(groups):\n",
    "    \"\"\"Return the (context, question, answer, question_type) rows whose context group is in `groups`.\"\"\"\n",
    "    rows = np.isin(context_group_ids, groups)\n",
    "    return context_padded[rows], question_padded[rows], answer_padded[rows], question_type_labels[rows]\n",
    "\n",
    "\n",
    "context_train, question_train, answer_train, qtype_train = take_groups(train_groups)\n",
    "context_test, question_test, answer_test, qtype_test = take_groups(test_groups)\n",
    "context_val, question_val, answer_val, qtype_val = take_groups(val_groups)\n",
    "\n",
    "assert len(qtype_train) + len(qtype_test) + len(qtype_val) == len(question_type_labels)\n",
    "print(\"Training samples:\", context_train.shape[0])\n",
    "print(\"Testing samples:\", context_test.shape[0])\n",
    "print(\"Validation samples:\", context_val.shape[0])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "VOCAB_SIZE = len(tokenizer.word_index) + 1  # +1 because id 0 is reserved for padding\n",
    "# Decoder start-token id: one past the tokenizer vocabulary, so it can never\n",
    "# collide with a real word id or with padding (0).\n",
    "START_ID = VOCAB_SIZE\n",
    "EMBEDDING_DIM = 300\n",
    "LSTM_UNITS = 256\n",
    "BATCH_SIZE = 32\n",
    "EPOCHS = 20\n",
    "\n",
    "\n",
    "def make_decoder_input(target_sequences):\n",
    "    \"\"\"Build teacher-forcing decoder inputs: START_ID followed by the targets shifted right by one.\n",
    "\n",
    "    At step t the decoder sees token t-1 and must predict token t. Feeding the\n",
    "    unshifted targets would let the model copy its input instead of learning to generate.\n",
    "    \"\"\"\n",
    "    start_column = np.full((target_sequences.shape[0], 1), START_ID, dtype=target_sequences.dtype)\n",
    "    return np.concatenate([start_column, target_sequences[:, :-1]], axis=1)\n",
    "\n",
    "\n",
    "question_decoder_train = make_decoder_input(question_train)\n",
    "question_decoder_val = make_decoder_input(question_val)\n",
    "\n",
    "# Encoder: reads the context; its final state initialises both decoders\n",
    "context_input = Input(shape=(MAX_LENGTH,), name=\"context_input\")\n",
    "context_embedding = Embedding(input_dim=VOCAB_SIZE, output_dim=EMBEDDING_DIM, mask_zero=True, name=\"context_embedding\")(context_input)\n",
    "encoder_lstm = LSTM(LSTM_UNITS, return_state=True, name=\"encoder_lstm\")\n",
    "encoder_output, state_h, state_c = encoder_lstm(context_embedding)\n",
    "\n",
    "# Question decoder, fed make_decoder_input(question); the extra embedding row holds START_ID\n",
    "question_decoder_input = Input(shape=(MAX_LENGTH,), name=\"question_decoder_input\")\n",
    "question_embedding = Embedding(input_dim=VOCAB_SIZE + 1, output_dim=EMBEDDING_DIM, mask_zero=True, name=\"question_embedding\")(question_decoder_input)\n",
    "question_lstm = LSTM(LSTM_UNITS, return_sequences=True, return_state=True, name=\"question_lstm\")\n",
    "question_output, _, _ = question_lstm(question_embedding, initial_state=[state_h, state_c])\n",
    "question_dense = Dense(VOCAB_SIZE, activation=\"softmax\", name=\"question_output\")(question_output)\n",
    "\n",
    "# Answer decoder.\n",
    "# NOTE(review): this decoder reads the *context* embedding, so answer token t is predicted\n",
    "# from context position t, and Keras propagates the context's padding mask to this output.\n",
    "# Positions past the answer's end (target 0) therefore dominate answer_output accuracy.\n",
    "# TODO: give it its own teacher-forced answer input, as done for the question decoder.\n",
    "answer_lstm = LSTM(LSTM_UNITS, return_sequences=True, return_state=True, name=\"answer_lstm\")\n",
    "answer_output, _, _ = answer_lstm(context_embedding, initial_state=[state_h, state_c])\n",
    "answer_dense = Dense(VOCAB_SIZE, activation=\"softmax\", name=\"answer_output\")(answer_output)\n",
    "\n",
    "# Question-type classifier on the final encoder output\n",
    "type_dense = Dense(128, activation=\"relu\")(encoder_output)\n",
    "question_type_output = Dense(3, activation=\"softmax\", name=\"question_type_output\")(type_dense)\n",
    "\n",
    "model = Model(\n",
    "    inputs=[context_input, question_decoder_input],\n",
    "    outputs=[question_dense, answer_dense, question_type_output]\n",
    ")\n",
    "\n",
    "model.compile(\n",
    "    optimizer=\"adam\",\n",
    "    loss={\"question_output\": \"sparse_categorical_crossentropy\",\n",
    "          \"answer_output\": \"sparse_categorical_crossentropy\",\n",
    "          \"question_type_output\": \"sparse_categorical_crossentropy\"},\n",
    "    metrics={\"question_output\": [\"accuracy\"],\n",
    "             \"answer_output\": [\"accuracy\"],\n",
    "             \"question_type_output\": [\"accuracy\"]}\n",
    ")\n",
    "\n",
    "# Validate on the dedicated validation split built above (not a slice of the training data)\n",
    "data_model = model.fit(\n",
    "    [context_train, question_decoder_train],\n",
    "    {\"question_output\": question_train, \"answer_output\": answer_train, \"question_type_output\": qtype_train},\n",
    "    batch_size=BATCH_SIZE,\n",
    "    epochs=EPOCHS,\n",
    "    validation_data=(\n",
    "        [context_val, question_decoder_val],\n",
    "        {\"question_output\": question_val, \"answer_output\": answer_val, \"question_type_output\": qtype_val},\n",
    "    ),\n",
    ")\n",
    "\n",
    "model.save(\"lstm_multi_output_model.keras\")\n",
    "# Consumers of the saved model must build decoder inputs with START_ID = len(tokenizer.word_index) + 1\n",
    "with open(\"tokenizer.pkl\", \"wb\") as handle:\n",
    "    pickle.dump(tokenizer, handle, protocol=pickle.HIGHEST_PROTOCOL)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training curves: accuracy (left) and loss (right) for every output head\n",
    "OUTPUT_HEADS = {\"question_output\": \"Question\", \"answer_output\": \"Answer\", \"question_type_output\": \"Question Type\"}\n",
    "\n",
    "fig, (acc_ax, loss_ax) = plt.subplots(1, 2, figsize=(12, 6))\n",
    "for head, label in OUTPUT_HEADS.items():\n",
    "    acc_ax.plot(data_model.history[f\"{head}_accuracy\"], label=f\"{label} Train Accuracy\")\n",
    "    acc_ax.plot(data_model.history[f\"val_{head}_accuracy\"], label=f\"{label} Val Accuracy\")\n",
    "    loss_ax.plot(data_model.history[f\"{head}_loss\"], label=f\"{label} Train Loss\")\n",
    "    loss_ax.plot(data_model.history[f\"val_{head}_loss\"], label=f\"{label} Val Loss\")\n",
    "\n",
    "acc_ax.set(title=\"Model Accuracy\", xlabel=\"Epoch\", ylabel=\"Accuracy\")\n",
    "loss_ax.set(title=\"Model Loss\", xlabel=\"Epoch\", ylabel=\"Loss\")\n",
    "acc_ax.legend()\n",
    "loss_ax.legend()\n",
    "fig.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate on the held-out test split and on the validation split.\n",
    "# NOTE: decoder predictions are teacher-forced (the ground-truth previous token is fed in),\n",
    "# so these BLEU scores are an upper bound on free-running generation quality.\n",
    "BLEU_SMOOTHING = nltk.translate.bleu_score.SmoothingFunction().method1\n",
    "\n",
    "\n",
    "def sequence_to_text(sequence, tokenizer):\n",
    "    \"\"\"Convert a sequence of token ids back into space-separated text.\n",
    "\n",
    "    Ids with no entry in tokenizer.index_word (padding 0, START_ID) are skipped.\n",
    "    \"\"\"\n",
    "    return \" \".join(tokenizer.index_word[idx] for idx in sequence if idx in tokenizer.index_word)\n",
    "\n",
    "\n",
    "def corpus_bleu_from_ids(reference_ids, predicted_probs):\n",
    "    \"\"\"Corpus BLEU between reference id sequences and argmax-decoded predictions.\n",
    "\n",
    "    NLTK's BLEU expects token lists; passing raw strings would score individual\n",
    "    characters, so both sides are split into words first.\n",
    "    \"\"\"\n",
    "    references = [[sequence_to_text(ids, tokenizer).split()] for ids in reference_ids]\n",
    "    hypotheses = [sequence_to_text(ids, tokenizer).split() for ids in np.argmax(predicted_probs, axis=-1)]\n",
    "    return nltk.translate.bleu_score.corpus_bleu(references, hypotheses, smoothing_function=BLEU_SMOOTHING)\n",
    "\n",
    "\n",
    "def evaluate_split(split_name, context, question, answer, qtype):\n",
    "    \"\"\"Print question-type classification metrics and question/answer BLEU for one split.\"\"\"\n",
    "    pred_question, pred_answer, pred_qtype = model.predict([context, make_decoder_input(question)], verbose=0)\n",
    "    pred_qtype_labels = np.argmax(pred_qtype, axis=1)\n",
    "\n",
    "    print(f\"\\n=== Evaluation on {split_name} Data ===\")\n",
    "    print(f\"Classification Report for Question Type ({split_name} Set):\")\n",
    "    print(classification_report(\n",
    "        qtype, pred_qtype_labels,\n",
    "        labels=list(question_type_dict.values()),\n",
    "        target_names=list(question_type_dict),\n",
    "        zero_division=0,\n",
    "    ))\n",
    "    print(f\"{split_name} Accuracy:\", accuracy_score(qtype, pred_qtype_labels))\n",
    "    print(f\"{split_name} Precision:\", precision_score(qtype, pred_qtype_labels, average=\"weighted\", zero_division=0))\n",
    "    print(f\"{split_name} Recall:\", recall_score(qtype, pred_qtype_labels, average=\"weighted\", zero_division=0))\n",
    "    print(f\"Corpus BLEU ({split_name}, question generation):\", corpus_bleu_from_ids(question, pred_question))\n",
    "    print(f\"Corpus BLEU ({split_name}, answer generation):\", corpus_bleu_from_ids(answer, pred_answer))\n",
    "\n",
    "\n",
    "evaluate_split(\"Test\", context_test, question_test, answer_test, qtype_test)\n",
    "evaluate_split(\"Validation\", context_val, question_val, answer_val, qtype_val)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "myenv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}