TIF_E41211115_lstm-quiz-gen.../training_model.ipynb

604 lines
30 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": 47,
"metadata": {},
"outputs": [],
"source": [
"# import library\n",
"\n",
"# standard library\n",
"import json\n",
"import pickle\n",
"import re\n",
"import string\n",
"\n",
"# third-party\n",
"import matplotlib.pyplot as plt\n",
"import nltk\n",
"import numpy as np\n",
"import pandas as pd\n",
"from nltk.corpus import stopwords\n",
"from nltk.stem import WordNetLemmatizer\n",
"from nltk.tokenize import word_tokenize\n",
"from sklearn.metrics import classification_report, precision_score, recall_score, accuracy_score\n",
"from tensorflow.keras.layers import Input, Embedding, LSTM, Dense, Concatenate\n",
"from tensorflow.keras.models import Model\n",
"from tensorflow.keras.preprocessing.sequence import pad_sequences\n",
"from tensorflow.keras.preprocessing.text import Tokenizer\n",
"from tensorflow.keras.utils import set_random_seed\n"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[nltk_data] Downloading package stopwords to /home/akeon/nltk_data...\n",
"[nltk_data] Package stopwords is already up-to-date!\n",
"[nltk_data] Downloading package punkt to /home/akeon/nltk_data...\n",
"[nltk_data] Package punkt is already up-to-date!\n",
"[nltk_data] Downloading package punkt_tab to /home/akeon/nltk_data...\n",
"[nltk_data] Package punkt_tab is already up-to-date!\n",
"[nltk_data] Downloading package wordnet to /home/akeon/nltk_data...\n",
"[nltk_data] Package wordnet is already up-to-date!\n"
]
},
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 48,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Fetch the NLTK resources used by text_preprocessing: stop-word lists, the punkt tokenizer\n",
"# models (word_tokenize) and WordNet (WordNetLemmatizer).\n",
"for nltk_resource in (\"stopwords\", \"punkt\", \"punkt_tab\", \"wordnet\"):\n",
"    last_download_ok = nltk.download(nltk_resource)\n",
"# Show whether the final download reported success.\n",
"last_download_ok"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>context</th>\n",
" <th>qa_pairs</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Albert Einstein adalah fisikawan teoretis kela...</td>\n",
" <td>[{'type': 'fill_in_the_blank', 'question': '__...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>Samudra Pasifik adalah yang terbesar dan terda...</td>\n",
" <td>[{'type': 'fill_in_the_blank', 'question': 'Sa...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>Proklamasi Kemerdekaan Indonesia dibacakan pad...</td>\n",
" <td>[{'type': 'fill_in_the_blank', 'question': 'Pr...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>Hukum Newton adalah tiga hukum fisika yang men...</td>\n",
" <td>[{'type': 'fill_in_the_blank', 'question': 'Hu...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>Budi Utomo adalah organisasi pemuda yang didir...</td>\n",
" <td>[{'type': 'fill_in_the_blank', 'question': 'Bu...</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" context \\\n",
"0 Albert Einstein adalah fisikawan teoretis kela... \n",
"1 Samudra Pasifik adalah yang terbesar dan terda... \n",
"2 Proklamasi Kemerdekaan Indonesia dibacakan pad... \n",
"3 Hukum Newton adalah tiga hukum fisika yang men... \n",
"4 Budi Utomo adalah organisasi pemuda yang didir... \n",
"\n",
" qa_pairs \n",
"0 [{'type': 'fill_in_the_blank', 'question': '__... \n",
"1 [{'type': 'fill_in_the_blank', 'question': 'Sa... \n",
"2 [{'type': 'fill_in_the_blank', 'question': 'Pr... \n",
"3 [{'type': 'fill_in_the_blank', 'question': 'Hu... \n",
"4 [{'type': 'fill_in_the_blank', 'question': 'Bu... "
]
},
"execution_count": 49,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Load the raw dataset into a DataFrame and preview it to sanity-check its columns.\n",
"# NOTE(review): `df` is not used by later cells, which re-read the same file with json.load.\n",
"df = pd.read_json(path_or_buf=\"independent_dataset.json\")\n",
"df.head(n=5)"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {},
"outputs": [],
"source": [
"# === Text preprocessing === #\n",
"stop_words = set(stopwords.words(\"indonesian\"))\n",
"# NOTE(review): WordNetLemmatizer is an English lemmatizer. On Indonesian tokens it presumably\n",
"# leaves most words unchanged but may rewrite words that collide with English WordNet entries;\n",
"# confirm this step is wanted.\n",
"lemmatizer = WordNetLemmatizer()\n",
"\n",
"# Maps informal Indonesian abbreviations/slang to their standard spelling.\n",
"# Every key appears exactly once. The previous literal repeated \"tdk\", \"tp\", \"tdr\" and \"btw\";\n",
"# the second \"btw\" entry (\"by the way\") silently overrode the Indonesian \"ngomong-ngomong\".\n",
"normalization_dict = {\n",
"    \"yg\": \"yang\",\n",
"    \"gokil\": \"kocak\",\n",
"    \"kalo\": \"kalau\",\n",
"    \"gue\": \"saya\",\n",
"    \"elo\": \"kamu\",\n",
"    \"nih\": \"ini\",\n",
"    \"trs\": \"terus\",\n",
"    \"tdk\": \"tidak\",\n",
"    \"gmna\": \"bagaimana\",\n",
"    \"tp\": \"tapi\",\n",
"    \"jd\": \"jadi\",\n",
"    \"aja\": \"saja\",\n",
"    \"krn\": \"karena\",\n",
"    \"blm\": \"belum\",\n",
"    \"dgn\": \"dengan\",\n",
"    \"skrg\": \"sekarang\",\n",
"    \"msh\": \"masih\",\n",
"    \"lg\": \"lagi\",\n",
"    \"sy\": \"saya\",\n",
"    \"sm\": \"sama\",\n",
"    \"bgt\": \"banget\",\n",
"    \"dr\": \"dari\",\n",
"    \"kpn\": \"kapan\",\n",
"    \"hrs\": \"harus\",\n",
"    \"cm\": \"cuma\",\n",
"    \"sbnrnya\": \"sebenarnya\",\n",
"    \"tdr\": \"tidur\",\n",
"    \"kl\": \"kalau\",\n",
"    \"org\": \"orang\",\n",
"    \"pke\": \"pakai\",\n",
"    \"prnh\": \"pernah\",\n",
"    \"brgkt\": \"berangkat\",\n",
"    \"pdhl\": \"padahal\",\n",
"    \"btw\": \"ngomong-ngomong\",\n",
"    \"dmn\": \"di mana\",\n",
"    \"bsk\": \"besok\",\n",
"    \"td\": \"tadi\",\n",
"    \"dlm\": \"dalam\",\n",
"    \"utk\": \"untuk\",\n",
"    \"spt\": \"seperti\",\n",
"    \"gpp\": \"tidak apa-apa\",\n",
"    \"bs\": \"bisa\",\n",
"    \"jg\": \"juga\",\n",
"    \"dg\": \"dengan\",\n",
"    \"klw\": \"kalau\",\n",
"    \"wkwk\": \"haha\",\n",
"    \"cpt\": \"cepat\",\n",
"    \"knp\": \"kenapa\",\n",
"    \"jgk\": \"juga\",\n",
"    \"plg\": \"pulang\",\n",
"    \"brp\": \"berapa\",\n",
"    \"bkn\": \"bukan\",\n",
"    \"mnt\": \"minta\",\n",
"    \"udh\": \"sudah\",\n",
"    \"sdh\": \"sudah\",\n",
"    \"brkt\": \"berangkat\",\n",
"    \"sprt\": \"seperti\",\n",
"    \"jgn\": \"jangan\",\n",
"    \"mlm\": \"malam\",\n",
"    \"sblm\": \"sebelum\",\n",
"    \"stlh\": \"setelah\",\n",
"    \"mlh\": \"malah\",\n",
"    \"tmn\": \"teman\",\n",
"}\n",
"\n",
"# Translation table that deletes every ASCII punctuation character; built once, reused per call.\n",
"_PUNCTUATION_TABLE = str.maketrans(\"\", \"\", string.punctuation)\n",
"\n",
"\n",
"def text_preprocessing(text):\n",
"    \"\"\"Clean raw text and return its list of normalized tokens.\n",
"\n",
"    Pipeline: lower-case -> strip ASCII punctuation -> collapse whitespace -> NLTK\n",
"    word_tokenize -> slang normalization via ``normalization_dict`` -> WordNet\n",
"    lemmatization -> Indonesian stop-word removal.\n",
"\n",
"    Args:\n",
"        text: The text to process. ``None`` yields ``[]``; any other non-string value\n",
"            (for example a boolean answer) is converted with ``str()`` first.\n",
"\n",
"    Returns:\n",
"        list[str]: The surviving tokens in their original order.\n",
"    \"\"\"\n",
"    if text is None:\n",
"        return []\n",
"    text = str(text).lower()\n",
"\n",
"    # Drop punctuation, then collapse runs of whitespace into single spaces.\n",
"    text = text.translate(_PUNCTUATION_TABLE)\n",
"    text = re.sub(r\"\\s+\", \" \", text).strip()\n",
"\n",
"    tokens = word_tokenize(text)\n",
"\n",
"    # Expand slang. Multi-word expansions such as \"di mana\" are split so each word becomes its\n",
"    # own token (and can be dropped as a stop word) instead of one token containing a space.\n",
"    tokens = [part for word in tokens for part in normalization_dict.get(word, word).split()]\n",
"\n",
"    tokens = [lemmatizer.lemmatize(word) for word in tokens]\n",
"\n",
"    # Stop-word filtering runs last so it also sees the normalized/lemmatized forms.\n",
"    return [word for word in tokens if word not in stop_words]\n"
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"✅ Vocabulary Size: 182\n",
"✅ Sample Tokenized Context: [ 9 10 91 38 92 93 39 5 19 94 95 11 96 97 40 98 99 100\n",
" 101 20 21 22 11 41 102 11 38 0 0 0 0 0 0 0 0 0\n",
" 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
" 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
" 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
" 0 0 0 0 0 0 0 0 0 0]\n",
"✅ Sample Tokenized Question: [39 5 19 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
" 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
" 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
" 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
" 0 0 0 0]\n",
"✅ Sample Tokenized Answer: [ 9 10 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
" 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
" 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
" 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
" 0 0 0 0]\n",
"✅ Sample Question Type Label: 0\n"
]
}
],
"source": [
"# Preprocess the dataset and build aligned, padded training arrays (one row per QA pair).\n",
"with open(\"independent_dataset.json\", \"r\", encoding=\"utf-8\") as file:\n",
"    dataset = json.load(file)\n",
"\n",
"for entry in dataset:\n",
"    entry[\"context\"] = text_preprocessing(entry[\"context\"])\n",
"    for qa in entry[\"qa_pairs\"]:\n",
"        qa[\"question\"] = text_preprocessing(qa[\"question\"])\n",
"        qa[\"answer\"] = text_preprocessing(qa[\"answer\"])\n",
"\n",
"# === Extract Contexts, Questions, Answers, and Question Types === #\n",
"# The context is repeated once per QA pair so that row i of every array describes the same\n",
"# example. Previously `contexts` held one row per entry while the other lists held one row per\n",
"# QA pair, which misaligns the arrays (model.fit rejects mismatched lengths) as soon as an\n",
"# entry has anything other than exactly one QA pair.\n",
"entry_contexts = [entry[\"context\"] for entry in dataset]\n",
"contexts = [entry[\"context\"] for entry in dataset for _ in entry[\"qa_pairs\"]]\n",
"questions = [qa[\"question\"] for entry in dataset for qa in entry[\"qa_pairs\"]]\n",
"answers = [qa[\"answer\"] for entry in dataset for qa in entry[\"qa_pairs\"]]\n",
"question_types = [qa[\"type\"] for entry in dataset for qa in entry[\"qa_pairs\"]]\n",
"\n",
"# === Initialize Tokenizer === #\n",
"# Fit on each entry's context only once, so repeated contexts do not inflate word counts and\n",
"# reorder the frequency-ranked word index.\n",
"tokenizer = Tokenizer(oov_token=\"<OOV>\")\n",
"tokenizer.fit_on_texts(entry_contexts + questions + answers)\n",
"\n",
"# === Convert Text to Sequences === #\n",
"context_sequences = tokenizer.texts_to_sequences(contexts)\n",
"question_sequences = tokenizer.texts_to_sequences(questions)\n",
"answer_sequences = tokenizer.texts_to_sequences(answers)\n",
"\n",
"# === Define Max Length for Padding === #\n",
"# Sequences are padded (with pad_sequences' default value 0) and truncated at the end.\n",
"MAX_LENGTH = 100  # Adjust based on dataset analysis\n",
"context_padded = pad_sequences(context_sequences, maxlen=MAX_LENGTH, padding=\"post\", truncating=\"post\")\n",
"question_padded = pad_sequences(question_sequences, maxlen=MAX_LENGTH, padding=\"post\", truncating=\"post\")\n",
"answer_padded = pad_sequences(answer_sequences, maxlen=MAX_LENGTH, padding=\"post\", truncating=\"post\")\n",
"\n",
"# === Encode Question Types (Convert Categorical Labels to Numeric) === #\n",
"question_type_dict = {\"fill_in_the_blank\": 0, \"true_false\": 1, \"multiple_choice\": 2}\n",
"question_type_labels = np.array([question_type_dict[q_type] for q_type in question_types])\n",
"\n",
"# Every model input/target array must have exactly one row per QA pair.\n",
"assert len(context_padded) == len(question_padded) == len(answer_padded) == len(question_type_labels), \"misaligned rows\"\n",
"\n",
"# === Save Processed Data as .npy Files === #\n",
"np.save(\"context_padded.npy\", context_padded)\n",
"np.save(\"question_padded.npy\", question_padded)\n",
"np.save(\"answer_padded.npy\", answer_padded)\n",
"np.save(\"question_type_labels.npy\", question_type_labels)\n",
"\n",
"# Save Tokenizer for Future Use\n",
"with open(\"tokenizer.pkl\", \"wb\") as handle:\n",
"    pickle.dump(tokenizer, handle, protocol=pickle.HIGHEST_PROTOCOL)\n",
"\n",
"# === Check Results === #\n",
"print(f\"✅ Vocabulary Size: {len(tokenizer.word_index) + 1}\")\n",
"print(f\"✅ Sample Tokenized Context: {context_padded[0]}\")\n",
"print(f\"✅ Sample Tokenized Question: {question_padded[0]}\")\n",
"print(f\"✅ Sample Tokenized Answer: {answer_padded[0]}\")\n",
"print(f\"✅ Sample Question Type Label: {question_type_labels[0]}\")"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n",
"\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 3s/step - answer_output_accuracy: 0.0000e+00 - answer_output_loss: 5.2070 - loss: 11.5152 - question_output_accuracy: 0.0000e+00 - question_output_loss: 5.2081 - question_type_output_accuracy: 0.3333 - question_type_output_loss: 1.1002 - val_answer_output_accuracy: 0.1250 - val_answer_output_loss: 5.1854 - val_loss: 11.4804 - val_question_output_accuracy: 0.0000e+00 - val_question_output_loss: 5.2043 - val_question_type_output_accuracy: 1.0000 - val_question_type_output_loss: 1.0907\n",
"Epoch 2/10\n",
"\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 184ms/step - answer_output_accuracy: 0.2100 - answer_output_loss: 5.1680 - loss: 11.4156 - question_output_accuracy: 0.0167 - question_output_loss: 5.1820 - question_type_output_accuracy: 1.0000 - question_type_output_loss: 1.0656 - val_answer_output_accuracy: 0.2450 - val_answer_output_loss: 5.1625 - val_loss: 11.4545 - val_question_output_accuracy: 0.0000e+00 - val_question_output_loss: 5.2056 - val_question_type_output_accuracy: 1.0000 - val_question_type_output_loss: 1.0864\n",
"Epoch 3/10\n",
"\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 188ms/step - answer_output_accuracy: 0.2717 - answer_output_loss: 5.1203 - loss: 11.3080 - question_output_accuracy: 0.0250 - question_output_loss: 5.1552 - question_type_output_accuracy: 1.0000 - question_type_output_loss: 1.0326 - val_answer_output_accuracy: 0.3350 - val_answer_output_loss: 5.1270 - val_loss: 11.4122 - val_question_output_accuracy: 0.0000e+00 - val_question_output_loss: 5.2071 - val_question_type_output_accuracy: 1.0000 - val_question_type_output_loss: 1.0781\n",
"Epoch 4/10\n",
"\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 174ms/step - answer_output_accuracy: 0.3417 - answer_output_loss: 5.0458 - loss: 11.1663 - question_output_accuracy: 0.0333 - question_output_loss: 5.1257 - question_type_output_accuracy: 1.0000 - question_type_output_loss: 0.9948 - val_answer_output_accuracy: 0.9700 - val_answer_output_loss: 5.0661 - val_loss: 11.3417 - val_question_output_accuracy: 0.0000e+00 - val_question_output_loss: 5.2090 - val_question_type_output_accuracy: 1.0000 - val_question_type_output_loss: 1.0665\n",
"Epoch 5/10\n",
"\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 166ms/step - answer_output_accuracy: 0.9883 - answer_output_loss: 4.9145 - loss: 10.9538 - question_output_accuracy: 0.0333 - question_output_loss: 5.0917 - question_type_output_accuracy: 1.0000 - question_type_output_loss: 0.9476 - val_answer_output_accuracy: 0.9700 - val_answer_output_loss: 4.9497 - val_loss: 11.2111 - val_question_output_accuracy: 0.0000e+00 - val_question_output_loss: 5.2115 - val_question_type_output_accuracy: 1.0000 - val_question_type_output_loss: 1.0498\n",
"Epoch 6/10\n",
"\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 176ms/step - answer_output_accuracy: 0.9883 - answer_output_loss: 4.6563 - loss: 10.5922 - question_output_accuracy: 0.0333 - question_output_loss: 5.0504 - question_type_output_accuracy: 1.0000 - question_type_output_loss: 0.8855 - val_answer_output_accuracy: 0.9700 - val_answer_output_loss: 4.6939 - val_loss: 10.9339 - val_question_output_accuracy: 0.0000e+00 - val_question_output_loss: 5.2154 - val_question_type_output_accuracy: 1.0000 - val_question_type_output_loss: 1.0247\n",
"Epoch 7/10\n",
"\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 189ms/step - answer_output_accuracy: 0.9867 - answer_output_loss: 4.0798 - loss: 9.8799 - question_output_accuracy: 0.0333 - question_output_loss: 4.9974 - question_type_output_accuracy: 1.0000 - question_type_output_loss: 0.8027 - val_answer_output_accuracy: 0.9700 - val_answer_output_loss: 4.0274 - val_loss: 10.2376 - val_question_output_accuracy: 0.0000e+00 - val_question_output_loss: 5.2215 - val_question_type_output_accuracy: 1.0000 - val_question_type_output_loss: 0.9888\n",
"Epoch 8/10\n",
"\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 172ms/step - answer_output_accuracy: 0.9867 - answer_output_loss: 2.9405 - loss: 8.5655 - question_output_accuracy: 0.0333 - question_output_loss: 4.9250 - question_type_output_accuracy: 1.0000 - question_type_output_loss: 0.7000 - val_answer_output_accuracy: 0.9700 - val_answer_output_loss: 2.6667 - val_loss: 8.8361 - val_question_output_accuracy: 0.0000e+00 - val_question_output_loss: 5.2325 - val_question_type_output_accuracy: 1.0000 - val_question_type_output_loss: 0.9369\n",
"Epoch 9/10\n",
"\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 186ms/step - answer_output_accuracy: 0.9867 - answer_output_loss: 1.7502 - loss: 7.1458 - question_output_accuracy: 0.0317 - question_output_loss: 4.8152 - question_type_output_accuracy: 1.0000 - question_type_output_loss: 0.5805 - val_answer_output_accuracy: 0.9700 - val_answer_output_loss: 1.5605 - val_loss: 7.6711 - val_question_output_accuracy: 0.0000e+00 - val_question_output_loss: 5.2582 - val_question_type_output_accuracy: 0.5000 - val_question_type_output_loss: 0.8525\n",
"Epoch 10/10\n",
"\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 196ms/step - answer_output_accuracy: 0.9867 - answer_output_loss: 0.9620 - loss: 6.0314 - question_output_accuracy: 0.0250 - question_output_loss: 4.6111 - question_type_output_accuracy: 1.0000 - question_type_output_loss: 0.4584 - val_answer_output_accuracy: 0.9700 - val_answer_output_loss: 0.8570 - val_loss: 6.9187 - val_question_output_accuracy: 0.0000e+00 - val_question_output_loss: 5.3395 - val_question_type_output_accuracy: 0.5000 - val_question_type_output_loss: 0.7221\n",
"✅ Model LSTM Multi-Output berhasil dilatih dan disimpan!\n"
]
}
],
"source": [
"# LSTM multi-output model: a context encoder feeding a question decoder, an answer decoder\n",
"# and a question-type classifier.\n",
"# === Load Tokenizer === #\n",
"# NOTE(review): pickle.load can execute arbitrary code; only load the tokenizer.pkl written above.\n",
"with open(\"tokenizer.pkl\", \"rb\") as handle:\n",
"    tokenizer = pickle.load(handle)\n",
"\n",
"# === Load preprocessed data === #\n",
"MAX_LENGTH = 100  # must equal the padding length used when the .npy files were written\n",
"VOCAB_SIZE = len(tokenizer.word_index) + 1  # +1: id 0 is reserved for padding\n",
"\n",
"context_padded = np.load(\"context_padded.npy\")\n",
"question_padded = np.load(\"question_padded.npy\")\n",
"answer_padded = np.load(\"answer_padded.npy\")\n",
"question_type_labels = np.load(\n",
"    \"question_type_labels.npy\"\n",
")  # Question-type labels (0 = Fill, 1 = True/False, 2 = Multiple Choice)\n",
"\n",
"# === Hyperparameters === #\n",
"EMBEDDING_DIM = 300\n",
"LSTM_UNITS = 256\n",
"BATCH_SIZE = 32\n",
"EPOCHS = 10\n",
"VALIDATION_SPLIT = 0.2  # Keras holds out the last 20% of rows (taken before shuffling)\n",
"SEED = 42\n",
"\n",
"# Seed Python, NumPy and the backend so weight initialisation and batch shuffling are repeatable.\n",
"set_random_seed(SEED)\n",
"\n",
"\n",
"# === Encoder (context) === #\n",
"context_input = Input(shape=(MAX_LENGTH,), name=\"context_input\")\n",
"context_embedding = Embedding(\n",
"    input_dim=VOCAB_SIZE,\n",
"    output_dim=EMBEDDING_DIM,\n",
"    mask_zero=True,  # id 0 (padding) is treated as masked\n",
"    name=\"context_embedding\",\n",
")(context_input)\n",
"encoder_lstm = LSTM(LSTM_UNITS, return_state=True, name=\"encoder_lstm\")\n",
"encoder_output, state_h, state_c = encoder_lstm(context_embedding)\n",
"\n",
"# === Question decoder === #\n",
"# NOTE(review): the decoder is fed question_padded and trained to predict question_padded at the\n",
"# same positions, so it can learn to copy its input. Standard teacher forcing shifts the target one\n",
"# step ahead of the input (with a start token); that also changes how inference must call the\n",
"# model, so it is only flagged here.\n",
"question_decoder_input = Input(shape=(MAX_LENGTH,), name=\"question_decoder_input\")\n",
"question_embedding = Embedding(\n",
"    input_dim=VOCAB_SIZE,\n",
"    output_dim=EMBEDDING_DIM,\n",
"    mask_zero=True,\n",
"    name=\"question_embedding\",\n",
")(question_decoder_input)\n",
"question_lstm = LSTM(\n",
"    LSTM_UNITS, return_sequences=True, return_state=True, name=\"question_lstm\"\n",
")\n",
"question_output, _, _ = question_lstm(\n",
"    question_embedding, initial_state=[state_h, state_c]\n",
")\n",
"question_dense = Dense(VOCAB_SIZE, activation=\"softmax\", name=\"question_output\")(\n",
"    question_output\n",
")\n",
"\n",
"# === Answer decoder === #\n",
"# NOTE(review): this decoder reads the *context* embedding, and its per-position outputs are\n",
"# compared with answer_padded, i.e. answer token t is predicted from context position t. Where\n",
"# the context is longer than the answer the targets are padding (0), which presumably inflates\n",
"# answer_output_accuracy. Confirm the intended design.\n",
"answer_lstm = LSTM(\n",
"    LSTM_UNITS, return_sequences=True, return_state=True, name=\"answer_lstm\"\n",
")\n",
"answer_output, _, _ = answer_lstm(context_embedding, initial_state=[state_h, state_c])\n",
"answer_dense = Dense(VOCAB_SIZE, activation=\"softmax\", name=\"answer_output\")(\n",
"    answer_output\n",
")\n",
"\n",
"# === Question-type classifier (Fill, True/False, Multiple Choice) === #\n",
"type_dense = Dense(128, activation=\"relu\")(encoder_output)\n",
"question_type_output = Dense(3, activation=\"softmax\", name=\"question_type_output\")(\n",
"    type_dense\n",
")  # 3 question categories\n",
"\n",
"# === Multi-output model === #\n",
"model = Model(\n",
"    inputs=[context_input, question_decoder_input],\n",
"    outputs=[question_dense, answer_dense, question_type_output],\n",
")\n",
"\n",
"# === Compile === #\n",
"# Losses and metrics are keyed by output-layer name, one entry per head.\n",
"model.compile(\n",
"    optimizer=\"adam\",\n",
"    loss={\n",
"        \"question_output\": \"sparse_categorical_crossentropy\",\n",
"        \"answer_output\": \"sparse_categorical_crossentropy\",\n",
"        \"question_type_output\": \"sparse_categorical_crossentropy\",\n",
"    },\n",
"    metrics={\n",
"        \"question_output\": [\"accuracy\"],\n",
"        \"answer_output\": [\"accuracy\"],\n",
"        \"question_type_output\": [\"accuracy\"],\n",
"    },\n",
")\n",
"\n",
"# === Training === #\n",
"model.fit(\n",
"    [context_padded, question_padded],\n",
"    {\n",
"        \"question_output\": question_padded,\n",
"        \"answer_output\": answer_padded,\n",
"        \"question_type_output\": question_type_labels,\n",
"    },\n",
"    batch_size=BATCH_SIZE,\n",
"    epochs=EPOCHS,\n",
"    validation_split=VALIDATION_SPLIT,\n",
")\n",
"\n",
"# === Save model === #\n",
"model.save(\"lstm_multi_output_model.keras\")\n",
"\n",
"print(\"✅ Model LSTM Multi-Output berhasil dilatih dan disimpan!\")"
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {},
"outputs": [
{
"ename": "NameError",
"evalue": "name 'context_padded_test' is not defined",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[53], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m predictions \u001b[38;5;241m=\u001b[39m model\u001b[38;5;241m.\u001b[39mpredict([\u001b[43mcontext_padded_test\u001b[49m, question_padded_test])\n\u001b[1;32m 3\u001b[0m \u001b[38;5;66;03m# predictions[0] corresponds to question_output (shape: [batch_size, MAX_LENGTH, VOCAB_SIZE])\u001b[39;00m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;66;03m# predictions[1] corresponds to answer_output (shape: [batch_size, MAX_LENGTH, VOCAB_SIZE])\u001b[39;00m\n\u001b[1;32m 5\u001b[0m \u001b[38;5;66;03m# predictions[2] corresponds to question_type_output (shape: [batch_size, 3])\u001b[39;00m\n\u001b[1;32m 6\u001b[0m \n\u001b[1;32m 7\u001b[0m \u001b[38;5;66;03m# Convert probabilities to predicted class indices\u001b[39;00m\n\u001b[1;32m 8\u001b[0m question_output_pred \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39margmax(predictions[\u001b[38;5;241m0\u001b[39m], axis\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m) \u001b[38;5;66;03m# shape: (batch_size, MAX_LENGTH)\u001b[39;00m\n",
"\u001b[0;31mNameError\u001b[0m: name 'context_padded_test' is not defined"
]
}
],
"source": [
"# === 1. Build the held-out split === #\n",
"# model.fit(validation_split=0.2) holds out the *last* 20% of rows (taken before shuffling), so\n",
"# those rows are the only ones the model never trained on. They serve as the test set here; keep\n",
"# VALIDATION_SPLIT equal to the value passed to model.fit.\n",
"VALIDATION_SPLIT = 0.2\n",
"split_at = int(len(context_padded) * (1.0 - VALIDATION_SPLIT))\n",
"\n",
"context_padded_test = context_padded[split_at:]\n",
"question_padded_test = question_padded[split_at:]\n",
"answer_padded_test = answer_padded[split_at:]\n",
"question_type_test = question_type_labels[split_at:]\n",
"assert len(context_padded_test) > 0, \"held-out split is empty; dataset too small\"\n",
"\n",
"# === 2. Predict === #\n",
"# NOTE(review): the ground-truth question is also the decoder input (teacher forcing), so the\n",
"# question-token scores below measure reconstruction, not free-running generation.\n",
"predictions = model.predict([context_padded_test, question_padded_test])\n",
"\n",
"# predictions[0] corresponds to question_output (shape: [batch_size, MAX_LENGTH, VOCAB_SIZE])\n",
"# predictions[1] corresponds to answer_output (shape: [batch_size, MAX_LENGTH, VOCAB_SIZE])\n",
"# predictions[2] corresponds to question_type_output (shape: [batch_size, 3])\n",
"\n",
"# Convert probabilities to predicted class indices\n",
"question_output_pred = np.argmax(predictions[0], axis=-1)  # shape: (batch_size, MAX_LENGTH)\n",
"answer_output_pred = np.argmax(predictions[1], axis=-1)  # shape: (batch_size, MAX_LENGTH)\n",
"question_type_pred = np.argmax(predictions[2], axis=-1)  # shape: (batch_size,)\n",
"\n",
"# === 3. Evaluate QUESTION TYPE (single-label classification) === #\n",
"print(\"=== Evaluation for Question Type ===\")\n",
"print(classification_report(\n",
"    question_type_test,  # True labels\n",
"    question_type_pred,  # Predicted labels\n",
"    # Pin all three classes: without `labels`, sklearn raises ValueError whenever fewer distinct\n",
"    # classes occur in this (small) split than there are target_names.\n",
"    labels=[0, 1, 2],\n",
"    target_names=[\"Fill\", \"True/False\", \"Multiple Choice\"],\n",
"    zero_division=0  # Avoids warning if a class is absent\n",
"))\n",
"\n",
"# Separate summary metrics (macro-average for multi-class):\n",
"acc_qtype = accuracy_score(question_type_test, question_type_pred)\n",
"prec_qtype = precision_score(question_type_test, question_type_pred, average='macro', zero_division=0)\n",
"rec_qtype = recall_score(question_type_test, question_type_pred, average='macro', zero_division=0)\n",
"\n",
"print(f\"Question Type -> Accuracy: {acc_qtype:.4f}, Precision(macro): {prec_qtype:.4f}, Recall(macro): {rec_qtype:.4f}\")\n",
"print(\"\")\n",
"\n",
"# === 4. Evaluate QUESTION OUTPUT & ANSWER OUTPUT (sequence predictions) === #\n",
"# Token-level comparison; padded positions are excluded to get a fair score.\n",
"\n",
"def flatten_and_mask(true_seq, pred_seq, pad_token=0):\n",
"    \"\"\"Return 1-D arrays of true and predicted token ids at non-padding positions.\n",
"\n",
"    Args:\n",
"        true_seq: Ground-truth ids, shape (batch_size, MAX_LENGTH).\n",
"        pred_seq: Predicted ids, same shape as ``true_seq``.\n",
"        pad_token: Id treated as padding; positions where ``true_seq`` equals it are dropped.\n",
"\n",
"    Returns:\n",
"        tuple[np.ndarray, np.ndarray]: (true_flat, pred_flat), aligned element-wise.\n",
"    \"\"\"\n",
"    mask = true_seq != pad_token\n",
"    # Boolean-mask indexing already yields a flat 1-D array in row-major order.\n",
"    return true_seq[mask], pred_seq[mask]\n",
"\n",
"# --- 4a. Question Output ---\n",
"q_true_flat, q_pred_flat = flatten_and_mask(question_padded_test, question_output_pred, pad_token=0)\n",
"\n",
"print(\"=== Evaluation for Question Tokens ===\")\n",
"print(classification_report(\n",
"    q_true_flat,\n",
"    q_pred_flat,\n",
"    zero_division=0  # Avoid warnings if a class is absent\n",
"))\n",
"\n",
"acc_q = accuracy_score(q_true_flat, q_pred_flat)\n",
"prec_q = precision_score(q_true_flat, q_pred_flat, average='macro', zero_division=0)\n",
"rec_q = recall_score(q_true_flat, q_pred_flat, average='macro', zero_division=0)\n",
"print(f\"Question Tokens -> Accuracy: {acc_q:.4f}, Precision(macro): {prec_q:.4f}, Recall(macro): {rec_q:.4f}\")\n",
"print(\"\")\n",
"\n",
"# --- 4b. Answer Output ---\n",
"a_true_flat, a_pred_flat = flatten_and_mask(answer_padded_test, answer_output_pred, pad_token=0)\n",
"\n",
"print(\"=== Evaluation for Answer Tokens ===\")\n",
"print(classification_report(\n",
"    a_true_flat,\n",
"    a_pred_flat,\n",
"    zero_division=0\n",
"))\n",
"\n",
"acc_a = accuracy_score(a_true_flat, a_pred_flat)\n",
"prec_a = precision_score(a_true_flat, a_pred_flat, average='macro', zero_division=0)\n",
"rec_a = recall_score(a_true_flat, a_pred_flat, average='macro', zero_division=0)\n",
"print(f\"Answer Tokens -> Accuracy: {acc_a:.4f}, Precision(macro): {prec_a:.4f}, Recall(macro): {rec_a:.4f}\")\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "myenv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.16"
}
},
"nbformat": 4,
"nbformat_minor": 2
}