522 lines
25 KiB
Plaintext
522 lines
25 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 38,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# import library\n",
|
|
"\n",
|
|
"# Data manipulation and visualization\n",
|
|
"import pandas as pd\n",
|
|
"import numpy as np\n",
|
|
"import matplotlib.pyplot as plt\n",
|
|
"import json\n",
|
|
"\n",
|
|
"# Natural language processing\n",
|
|
"import re\n",
|
|
"import string\n",
|
|
"import nltk\n",
|
|
"from nltk.corpus import stopwords\n",
|
|
"from nltk.tokenize import word_tokenize\n",
|
|
"from nltk.stem import WordNetLemmatizer\n",
|
|
"\n",
|
|
"# Deep learning\n",
|
|
"from tensorflow.keras.preprocessing.text import Tokenizer\n",
|
|
"from tensorflow.keras.preprocessing.sequence import pad_sequences\n",
|
|
"from tensorflow.keras.models import Model\n",
|
|
"from tensorflow.keras.layers import Input, Embedding, LSTM, Dense, Concatenate\n",
|
|
"from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint\n",
|
|
"\n",
|
|
"\n",
|
|
"from sklearn.model_selection import train_test_split\n",
|
|
"\n",
|
|
"# Metrics for model evaluation\n",
|
|
"from sklearn.metrics import classification_report, precision_score, recall_score, accuracy_score\n",
|
|
"\n",
|
|
"# Utility for serialization\n",
|
|
"import pickle\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 39,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"[nltk_data] Downloading package stopwords to /home/akeon/nltk_data...\n",
|
|
"[nltk_data] Package stopwords is already up-to-date!\n",
|
|
"[nltk_data] Downloading package punkt to /home/akeon/nltk_data...\n",
|
|
"[nltk_data] Package punkt is already up-to-date!\n",
|
|
"[nltk_data] Downloading package punkt_tab to /home/akeon/nltk_data...\n",
|
|
"[nltk_data] Package punkt_tab is already up-to-date!\n",
|
|
"[nltk_data] Downloading package wordnet to /home/akeon/nltk_data...\n",
|
|
"[nltk_data] Package wordnet is already up-to-date!\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"True"
|
|
]
|
|
},
|
|
"execution_count": 39,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# download assets\n",
|
|
"nltk.download(\"stopwords\")\n",
|
|
"nltk.download(\"punkt\")\n",
|
|
"nltk.download(\"punkt_tab\")\n",
|
|
"nltk.download(\"wordnet\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 40,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
" context \\\n",
|
|
"0 Albert Einstein adalah fisikawan teoretis kela... \n",
|
|
"1 Samudra Pasifik adalah yang terbesar dan terda... \n",
|
|
"2 Proklamasi Kemerdekaan Indonesia dibacakan pad... \n",
|
|
"3 Hukum Newton adalah tiga hukum fisika yang men... \n",
|
|
"4 Budi Utomo adalah organisasi pemuda yang didir... \n",
|
|
"\n",
|
|
" qa_pairs \n",
|
|
"0 [{'type': 'fill_in_the_blank', 'question': 'Si... \n",
|
|
"1 [{'type': 'fill_in_the_blank', 'question': 'Sa... \n",
|
|
"2 [{'type': 'fill_in_the_blank', 'question': 'Pr... \n",
|
|
"3 [{'type': 'fill_in_the_blank', 'question': 'Hu... \n",
|
|
"4 [{'type': 'fill_in_the_blank', 'question': 'Bu... \n",
|
|
"\n",
|
|
"Total Context: 25\n",
|
|
"Total QA Pairs: 57\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# load dataset\n",
|
|
"df = pd.read_json(\"dataset/training_dataset.json\")\n",
|
|
"print(df.head())\n",
|
|
"with open(\"dataset/training_dataset.json\", \"r\", encoding=\"utf-8\") as file:\n",
|
|
" dataset = json.load(file)\n",
|
|
" \n",
|
|
" \n",
|
|
"# Menghitung total context\n",
|
|
"total_context = len(dataset)\n",
|
|
"\n",
|
|
"# Menghitung total qa_pairs\n",
|
|
"total_qa_pairs = sum(len(entry[\"qa_pairs\"]) for entry in dataset)\n",
|
|
"\n",
|
|
"# Menampilkan hasil\n",
|
|
"print(f\"\\nTotal Context: {total_context}\")\n",
|
|
"print(f\"Total QA Pairs: {total_qa_pairs}\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 41,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Text Preprocessing\n",
|
|
"stop_words = set(stopwords.words(\"indonesian\")) \n",
|
|
"lemmatizer = WordNetLemmatizer()\n",
|
|
"\n",
|
|
"normalization_dict = {\n",
|
|
" \"yg\": \"yang\",\n",
|
|
" \"gokil\": \"kocak\",\n",
|
|
" \"kalo\": \"kalau\",\n",
|
|
" \"gue\": \"saya\",\n",
|
|
" \"elo\": \"kamu\",\n",
|
|
" \"nih\": \"ini\",\n",
|
|
" \"trs\": \"terus\",\n",
|
|
" \"tdk\": \"tidak\",\n",
|
|
" \"gmna\": \"bagaimana\",\n",
|
|
" \"tp\": \"tapi\",\n",
|
|
" \"jd\": \"jadi\",\n",
|
|
" \"aja\": \"saja\",\n",
|
|
" \"krn\": \"karena\",\n",
|
|
" \"blm\": \"belum\",\n",
|
|
" \"dgn\": \"dengan\",\n",
|
|
" \"skrg\": \"sekarang\",\n",
|
|
" \"msh\": \"masih\",\n",
|
|
" \"lg\": \"lagi\",\n",
|
|
" \"sy\": \"saya\",\n",
|
|
" \"sm\": \"sama\",\n",
|
|
" \"bgt\": \"banget\",\n",
|
|
" \"dr\": \"dari\",\n",
|
|
" \"kpn\": \"kapan\",\n",
|
|
" \"hrs\": \"harus\",\n",
|
|
" \"cm\": \"cuma\",\n",
|
|
" \"sbnrnya\": \"sebenarnya\",\n",
|
|
" \"tdr\": \"tidur\",\n",
|
|
" \"tdk\": \"tidak\",\n",
|
|
" \"kl\": \"kalau\",\n",
|
|
" \"org\": \"orang\",\n",
|
|
" \"pke\": \"pakai\",\n",
|
|
" \"prnh\": \"pernah\",\n",
|
|
" \"brgkt\": \"berangkat\",\n",
|
|
" \"pdhl\": \"padahal\",\n",
|
|
" \"btw\": \"ngomong-ngomong\",\n",
|
|
" \"dmn\": \"di mana\",\n",
|
|
" \"bsk\": \"besok\",\n",
|
|
" \"td\": \"tadi\",\n",
|
|
" \"dlm\": \"dalam\",\n",
|
|
" \"utk\": \"untuk\",\n",
|
|
" \"spt\": \"seperti\",\n",
|
|
" \"gpp\": \"tidak apa-apa\",\n",
|
|
" \"bs\": \"bisa\",\n",
|
|
" \"jg\": \"juga\",\n",
|
|
" \"tp\": \"tapi\",\n",
|
|
" \"dg\": \"dengan\",\n",
|
|
" \"klw\": \"kalau\",\n",
|
|
" \"wkwk\": \"haha\",\n",
|
|
" \"cpt\": \"cepat\",\n",
|
|
" \"knp\": \"kenapa\",\n",
|
|
" \"jgk\": \"juga\",\n",
|
|
" \"plg\": \"pulang\",\n",
|
|
" \"brp\": \"berapa\",\n",
|
|
" \"bkn\": \"bukan\",\n",
|
|
" \"mnt\": \"minta\",\n",
|
|
" \"udh\": \"sudah\",\n",
|
|
" \"sdh\": \"sudah\",\n",
|
|
" \"brkt\": \"berangkat\",\n",
|
|
" \"btw\": \"by the way\",\n",
|
|
" \"tdk\": \"tidak\",\n",
|
|
" \"sprt\": \"seperti\",\n",
|
|
" \"jgn\": \"jangan\",\n",
|
|
" \"mlm\": \"malam\",\n",
|
|
" \"sblm\": \"sebelum\",\n",
|
|
" \"stlh\": \"setelah\",\n",
|
|
" \"tdr\": \"tidur\",\n",
|
|
" \"mlh\": \"malah\",\n",
|
|
" \"tmn\": \"teman\",\n",
|
|
"}\n",
|
|
"\n",
|
|
"\n",
|
|
"def text_preprocessing(text):\n",
|
|
" #doing lower case \n",
|
|
" text = text.lower()\n",
|
|
" \n",
|
|
" # remove symbol and read mark\n",
|
|
" text = text.translate(str.maketrans(\"\", \"\", string.punctuation))\n",
|
|
" \n",
|
|
" # remove blank space\n",
|
|
" text = re.sub(r\"\\s+\", \" \", text).strip()\n",
|
|
" \n",
|
|
" # word tokenize \n",
|
|
" tokens = word_tokenize(text)\n",
|
|
" \n",
|
|
" # normalassi kata\n",
|
|
" tokens = [normalization_dict[word] if word in normalization_dict else word for word in tokens] \n",
|
|
" \n",
|
|
" \n",
|
|
" # lemmatization\n",
|
|
" tokens = [lemmatizer.lemmatize(word) for word in tokens] \n",
|
|
" \n",
|
|
" # stopword removal\n",
|
|
" tokens = [word for word in tokens if word not in stop_words] \n",
|
|
" \n",
|
|
" return tokens\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 42,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"✅ Data processing complete!\n",
|
|
"Samples: 57\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# with open(\"dataset/training_dataset.json\", \"r\", encoding=\"utf-8\") as file:\n",
|
|
"# dataset = json.load(file)\n",
|
|
"\n",
|
|
"# === Extract Data so that each QA pair has its own context === #\n",
|
|
"contexts = []\n",
|
|
"questions = []\n",
|
|
"answers = []\n",
|
|
"question_types = []\n",
|
|
"\n",
|
|
"for entry in dataset:\n",
|
|
" processed_context = text_preprocessing(entry[\"context\"])\n",
|
|
" for qa in entry[\"qa_pairs\"]:\n",
|
|
" contexts.append(processed_context)\n",
|
|
" questions.append(text_preprocessing(qa[\"question\"]))\n",
|
|
" answers.append(text_preprocessing(qa[\"answer\"]))\n",
|
|
" question_types.append(qa[\"type\"])\n",
|
|
"\n",
|
|
"# === Initialize Tokenizer and fit on all text === #\n",
|
|
"tokenizer = Tokenizer(oov_token=\"<OOV>\")\n",
|
|
"tokenizer.fit_on_texts(contexts + questions + answers)\n",
|
|
"\n",
|
|
"# === Convert Text to Sequences === #\n",
|
|
"context_sequences = tokenizer.texts_to_sequences(contexts)\n",
|
|
"question_sequences = tokenizer.texts_to_sequences(questions)\n",
|
|
"answer_sequences = tokenizer.texts_to_sequences(answers)\n",
|
|
"\n",
|
|
"# === Define Max Length for Padding === #\n",
|
|
"MAX_LENGTH = 100\n",
|
|
"context_padded = pad_sequences(context_sequences, maxlen=MAX_LENGTH, padding=\"post\", truncating=\"post\")\n",
|
|
"question_padded = pad_sequences(question_sequences, maxlen=MAX_LENGTH, padding=\"post\", truncating=\"post\")\n",
|
|
"answer_padded = pad_sequences(answer_sequences, maxlen=MAX_LENGTH, padding=\"post\", truncating=\"post\")\n",
|
|
"\n",
|
|
"# Encode Question Types\n",
|
|
"question_type_dict = {\"fill_in_the_blank\": 0, \"true_false\": 1, \"multiple_choice\": 2}\n",
|
|
"question_type_labels = np.array([question_type_dict[q_type] for q_type in question_types])\n",
|
|
"\n",
|
|
"# Save the processed data (optional)\n",
|
|
"np.save(\"context_padded.npy\", context_padded)\n",
|
|
"np.save(\"question_padded.npy\", question_padded)\n",
|
|
"np.save(\"answer_padded.npy\", answer_padded)\n",
|
|
"np.save(\"question_type_labels.npy\", question_type_labels)\n",
|
|
"with open(\"tokenizer.pkl\", \"wb\") as handle:\n",
|
|
" pickle.dump(tokenizer, handle, protocol=pickle.HIGHEST_PROTOCOL)\n",
|
|
"\n",
|
|
"print(\"✅ Data processing complete!\")\n",
|
|
"print(\"Samples:\", context_padded.shape[0]) # This should now match the number of QA pairs\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 43,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Training samples: 45\n",
|
|
"Testing samples: 12\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# === Split Data into Training and Testing Sets === #\n",
|
|
"(context_train, context_test,\n",
|
|
" question_train, question_test,\n",
|
|
" answer_train, answer_test,\n",
|
|
" qtype_train, qtype_test) = train_test_split(\n",
|
|
" context_padded,\n",
|
|
" question_padded,\n",
|
|
" answer_padded,\n",
|
|
" question_type_labels,\n",
|
|
" test_size=0.2,\n",
|
|
" random_state=42\n",
|
|
")\n",
|
|
"\n",
|
|
"print(\"Training samples:\", context_train.shape[0])\n",
|
|
"print(\"Testing samples:\", context_test.shape[0])\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 48,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Epoch 1/10\n",
|
|
"\u001b[1m2/2\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 536ms/step - answer_output_accuracy: 0.0143 - answer_output_loss: 5.8964 - loss: 12.9070 - question_output_accuracy: 3.7037e-04 - question_output_loss: 5.9030 - question_type_output_accuracy: 0.3843 - question_type_output_loss: 1.0984 - val_answer_output_accuracy: 0.2689 - val_answer_output_loss: 5.8541 - val_loss: 12.8370 - val_question_output_accuracy: 0.0100 - val_question_output_loss: 5.8942 - val_question_type_output_accuracy: 0.4444 - val_question_type_output_loss: 1.0888\n",
|
|
"Epoch 2/10\n",
|
|
"\u001b[1m2/2\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 171ms/step - answer_output_accuracy: 0.3408 - answer_output_loss: 5.8273 - loss: 12.8054 - question_output_accuracy: 0.0138 - question_output_loss: 5.8815 - question_type_output_accuracy: 0.5868 - question_type_output_loss: 1.0860 - val_answer_output_accuracy: 0.9856 - val_answer_output_loss: 5.7489 - val_loss: 12.7114 - val_question_output_accuracy: 0.0100 - val_question_output_loss: 5.8837 - val_question_type_output_accuracy: 0.4444 - val_question_type_output_loss: 1.0788\n",
|
|
"Epoch 3/10\n",
|
|
"\u001b[1m2/2\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 172ms/step - answer_output_accuracy: 0.9826 - answer_output_loss: 5.6819 - loss: 12.6314 - question_output_accuracy: 0.0214 - question_output_loss: 5.8659 - question_type_output_accuracy: 0.5579 - question_type_output_loss: 1.0659 - val_answer_output_accuracy: 0.9856 - val_answer_output_loss: 5.4070 - val_loss: 12.3515 - val_question_output_accuracy: 0.0122 - val_question_output_loss: 5.8714 - val_question_type_output_accuracy: 0.4444 - val_question_type_output_loss: 1.0730\n",
|
|
"Epoch 4/10\n",
|
|
"\u001b[1m2/2\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 170ms/step - answer_output_accuracy: 0.9825 - answer_output_loss: 4.9562 - loss: 12.0453 - question_output_accuracy: 0.0214 - question_output_loss: 5.8386 - question_type_output_accuracy: 0.5972 - question_type_output_loss: 1.0674 - val_answer_output_accuracy: 0.9856 - val_answer_output_loss: 4.0363 - val_loss: 10.9596 - val_question_output_accuracy: 0.0078 - val_question_output_loss: 5.8545 - val_question_type_output_accuracy: 0.4444 - val_question_type_output_loss: 1.0688\n",
|
|
"Epoch 5/10\n",
|
|
"\u001b[1m2/2\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 180ms/step - answer_output_accuracy: 0.9827 - answer_output_loss: 3.0621 - loss: 10.1366 - question_output_accuracy: 0.0117 - question_output_loss: 5.8038 - question_type_output_accuracy: 0.5868 - question_type_output_loss: 1.0336 - val_answer_output_accuracy: 0.9856 - val_answer_output_loss: 1.7133 - val_loss: 8.5937 - val_question_output_accuracy: 0.0078 - val_question_output_loss: 5.8148 - val_question_type_output_accuracy: 0.5556 - val_question_type_output_loss: 1.0657\n",
|
|
"Epoch 6/10\n",
|
|
"\u001b[1m2/2\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 167ms/step - answer_output_accuracy: 0.9824 - answer_output_loss: 1.1229 - loss: 8.0884 - question_output_accuracy: 0.0054 - question_output_loss: 5.7361 - question_type_output_accuracy: 0.5394 - question_type_output_loss: 1.1969 - val_answer_output_accuracy: 0.9856 - val_answer_output_loss: 0.6490 - val_loss: 7.4129 - val_question_output_accuracy: 0.0078 - val_question_output_loss: 5.6864 - val_question_type_output_accuracy: 0.5556 - val_question_type_output_loss: 1.0775\n",
|
|
"Epoch 7/10\n",
|
|
"\u001b[1m2/2\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 171ms/step - answer_output_accuracy: 0.9824 - answer_output_loss: 0.5961 - loss: 7.2279 - question_output_accuracy: 0.0045 - question_output_loss: 5.4419 - question_type_output_accuracy: 0.4340 - question_type_output_loss: 1.1878 - val_answer_output_accuracy: 0.9856 - val_answer_output_loss: 0.5096 - val_loss: 7.2387 - val_question_output_accuracy: 0.0078 - val_question_output_loss: 5.5038 - val_question_type_output_accuracy: 0.2222 - val_question_type_output_loss: 1.2253\n",
|
|
"Epoch 8/10\n",
|
|
"\u001b[1m2/2\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 183ms/step - answer_output_accuracy: 0.9824 - answer_output_loss: 0.6082 - loss: 6.8731 - question_output_accuracy: 0.0045 - question_output_loss: 5.1820 - question_type_output_accuracy: 0.4132 - question_type_output_loss: 1.0811 - val_answer_output_accuracy: 0.9856 - val_answer_output_loss: 0.5227 - val_loss: 7.0044 - val_question_output_accuracy: 0.0067 - val_question_output_loss: 5.4302 - val_question_type_output_accuracy: 0.5556 - val_question_type_output_loss: 1.0514\n",
|
|
"Epoch 9/10\n",
|
|
"\u001b[1m2/2\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 164ms/step - answer_output_accuracy: 0.9824 - answer_output_loss: 0.6137 - loss: 6.7673 - question_output_accuracy: 0.0045 - question_output_loss: 4.9309 - question_type_output_accuracy: 0.4815 - question_type_output_loss: 1.3263 - val_answer_output_accuracy: 0.9856 - val_answer_output_loss: 0.5185 - val_loss: 6.9885 - val_question_output_accuracy: 0.0044 - val_question_output_loss: 5.4844 - val_question_type_output_accuracy: 0.6667 - val_question_type_output_loss: 0.9857\n",
|
|
"Epoch 10/10\n",
|
|
"\u001b[1m2/2\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 176ms/step - answer_output_accuracy: 0.9824 - answer_output_loss: 0.6061 - loss: 6.6094 - question_output_accuracy: 0.0045 - question_output_loss: 4.8379 - question_type_output_accuracy: 0.4132 - question_type_output_loss: 1.1360 - val_answer_output_accuracy: 0.9856 - val_answer_output_loss: 0.4949 - val_loss: 7.1082 - val_question_output_accuracy: 0.0044 - val_question_output_loss: 5.5818 - val_question_type_output_accuracy: 0.5556 - val_question_type_output_loss: 1.0315\n",
|
|
"✅ Model training complete and saved!\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# === Model Hyperparameters === #\n",
|
|
"VOCAB_SIZE = len(tokenizer.word_index) + 1\n",
|
|
"EMBEDDING_DIM = 300\n",
|
|
"LSTM_UNITS = 256\n",
|
|
"BATCH_SIZE = 32\n",
|
|
"EPOCHS = 10\n",
|
|
"\n",
|
|
"# === Build Model === #\n",
|
|
"# Encoder for Context\n",
|
|
"context_input = Input(shape=(MAX_LENGTH,), name=\"context_input\")\n",
|
|
"context_embedding = Embedding(input_dim=VOCAB_SIZE, output_dim=EMBEDDING_DIM, mask_zero=True, name=\"context_embedding\")(context_input)\n",
|
|
"encoder_lstm = LSTM(LSTM_UNITS, return_state=True, name=\"encoder_lstm\")\n",
|
|
"encoder_output, state_h, state_c = encoder_lstm(context_embedding)\n",
|
|
"\n",
|
|
"# Decoder for Question (Teacher Forcing)\n",
|
|
"question_decoder_input = Input(shape=(MAX_LENGTH,), name=\"question_decoder_input\")\n",
|
|
"question_embedding = Embedding(input_dim=VOCAB_SIZE, output_dim=EMBEDDING_DIM, mask_zero=True, name=\"question_embedding\")(question_decoder_input)\n",
|
|
"question_lstm = LSTM(LSTM_UNITS, return_sequences=True, return_state=True, name=\"question_lstm\")\n",
|
|
"question_output, _, _ = question_lstm(question_embedding, initial_state=[state_h, state_c])\n",
|
|
"question_dense = Dense(VOCAB_SIZE, activation=\"softmax\", name=\"question_output\")(question_output)\n",
|
|
"\n",
|
|
"# Decoder for Answer\n",
|
|
"answer_lstm = LSTM(LSTM_UNITS, return_sequences=True, return_state=True, name=\"answer_lstm\")\n",
|
|
"answer_output, _, _ = answer_lstm(context_embedding, initial_state=[state_h, state_c])\n",
|
|
"answer_dense = Dense(VOCAB_SIZE, activation=\"softmax\", name=\"answer_output\")(answer_output)\n",
|
|
"\n",
|
|
"# Classification Output for Question Type\n",
|
|
"type_dense = Dense(128, activation=\"relu\")(encoder_output)\n",
|
|
"question_type_output = Dense(3, activation=\"softmax\", name=\"question_type_output\")(type_dense)\n",
|
|
"\n",
|
|
"# Construct the Model\n",
|
|
"model = Model(\n",
|
|
" inputs=[context_input, question_decoder_input],\n",
|
|
" outputs=[question_dense, answer_dense, question_type_output],\n",
|
|
")\n",
|
|
"\n",
|
|
"# === Compile the Model === #\n",
|
|
"model.compile(\n",
|
|
" optimizer=\"adam\",\n",
|
|
" loss={\n",
|
|
" \"question_output\": \"sparse_categorical_crossentropy\",\n",
|
|
" \"answer_output\": \"sparse_categorical_crossentropy\",\n",
|
|
" \"question_type_output\": \"sparse_categorical_crossentropy\",\n",
|
|
" },\n",
|
|
" metrics={\n",
|
|
" \"question_output\": [\"accuracy\"],\n",
|
|
" \"answer_output\": [\"accuracy\"],\n",
|
|
" \"question_type_output\": [\"accuracy\"],\n",
|
|
" },\n",
|
|
")\n",
|
|
"\n",
|
|
"# === Train the Model === #\n",
|
|
"model.fit(\n",
|
|
" [context_train, question_train],\n",
|
|
" {\n",
|
|
" \"question_output\": question_train,\n",
|
|
" \"answer_output\": answer_train,\n",
|
|
" \"question_type_output\": qtype_train,\n",
|
|
" },\n",
|
|
" batch_size=BATCH_SIZE,\n",
|
|
" epochs=EPOCHS,\n",
|
|
" validation_split=0.2,\n",
|
|
")\n",
|
|
"\n",
|
|
"# Save the Model\n",
|
|
"model.save(\"lstm_multi_output_model.keras\")\n",
|
|
"print(\"✅ Model training complete and saved!\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 49,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 255ms/step\n",
|
|
"=== Evaluation on Test Data ===\n",
|
|
"Classification Report for Question Type:\n",
|
|
" precision recall f1-score support\n",
|
|
"\n",
|
|
" 0 0.33 0.40 0.36 5\n",
|
|
" 1 0.33 0.50 0.40 4\n",
|
|
" 2 0.00 0.00 0.00 3\n",
|
|
"\n",
|
|
" accuracy 0.33 12\n",
|
|
" macro avg 0.22 0.30 0.25 12\n",
|
|
"weighted avg 0.25 0.33 0.28 12\n",
|
|
"\n",
|
|
"Accuracy: 0.3333333333333333\n",
|
|
"Precision: 0.25\n",
|
|
"Recall: 0.3333333333333333\n",
|
|
"BLEU score for first test sample (question generation): 0\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"/mnt/disc1/code/thesis_quiz_project/lstm-quiz/myenv/lib64/python3.10/site-packages/sklearn/metrics/_classification.py:1565: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
|
|
" _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
|
|
"/mnt/disc1/code/thesis_quiz_project/lstm-quiz/myenv/lib64/python3.10/site-packages/sklearn/metrics/_classification.py:1565: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
|
|
" _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
|
|
"/mnt/disc1/code/thesis_quiz_project/lstm-quiz/myenv/lib64/python3.10/site-packages/sklearn/metrics/_classification.py:1565: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
|
|
" _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
|
|
"/mnt/disc1/code/thesis_quiz_project/lstm-quiz/myenv/lib64/python3.10/site-packages/sklearn/metrics/_classification.py:1565: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
|
|
" _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# === Evaluate on Test Set === #\n",
|
|
"pred_question, pred_answer, pred_qtype = model.predict([context_test, question_test])\n",
|
|
"pred_qtype_labels = np.argmax(pred_qtype, axis=1)\n",
|
|
"\n",
|
|
"print(\"=== Evaluation on Test Data ===\")\n",
|
|
"print(\"Classification Report for Question Type:\")\n",
|
|
"print(classification_report(qtype_test, pred_qtype_labels))\n",
|
|
"print(\"Accuracy:\", accuracy_score(qtype_test, pred_qtype_labels))\n",
|
|
"print(\"Precision:\", precision_score(qtype_test, pred_qtype_labels, average='weighted'))\n",
|
|
"print(\"Recall:\", recall_score(qtype_test, pred_qtype_labels, average='weighted'))\n",
|
|
"\n",
|
|
"# Optional: Evaluate sequence generation using BLEU score for the first sample\n",
|
|
"import nltk\n",
|
|
"def sequence_to_text(sequence, tokenizer):\n",
|
|
" return [tokenizer.index_word.get(idx, \"<OOV>\") for idx in sequence if idx != 0]\n",
|
|
"\n",
|
|
"reference = [sequence_to_text(question_test[0], tokenizer)]\n",
|
|
"candidate = sequence_to_text(np.argmax(pred_question[0], axis=-1), tokenizer)\n",
|
|
"bleu_score = nltk.translate.bleu_score.sentence_bleu(reference, candidate)\n",
|
|
"print(\"BLEU score for first test sample (question generation):\", bleu_score)\n"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "myenv",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.10.16"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
}
|