TIF_E41211115_lstm-quiz-gen.../training_model.ipynb

522 lines
25 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [],
"source": [
"# import library\n",
"\n",
"# Data manipulation and visualization\n",
"import pandas as pd\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import json\n",
"\n",
"# Natural language processing\n",
"import re\n",
"import string\n",
"import nltk\n",
"from nltk.corpus import stopwords\n",
"from nltk.tokenize import word_tokenize\n",
"from nltk.stem import WordNetLemmatizer\n",
"\n",
"# Deep learning\n",
"from tensorflow.keras.preprocessing.text import Tokenizer\n",
"from tensorflow.keras.preprocessing.sequence import pad_sequences\n",
"from tensorflow.keras.models import Model\n",
"from tensorflow.keras.layers import Input, Embedding, LSTM, Dense, Concatenate\n",
"from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint\n",
"\n",
"\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"# Metrics for model evaluation\n",
"from sklearn.metrics import classification_report, precision_score, recall_score, accuracy_score\n",
"\n",
"# Utility for serialization\n",
"import pickle\n"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[nltk_data] Downloading package stopwords to /home/akeon/nltk_data...\n",
"[nltk_data] Package stopwords is already up-to-date!\n",
"[nltk_data] Downloading package punkt to /home/akeon/nltk_data...\n",
"[nltk_data] Package punkt is already up-to-date!\n",
"[nltk_data] Downloading package punkt_tab to /home/akeon/nltk_data...\n",
"[nltk_data] Package punkt_tab is already up-to-date!\n",
"[nltk_data] Downloading package wordnet to /home/akeon/nltk_data...\n",
"[nltk_data] Package wordnet is already up-to-date!\n"
]
},
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 39,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# download assets\n",
"nltk.download(\"stopwords\")\n",
"nltk.download(\"punkt\")\n",
"nltk.download(\"punkt_tab\")\n",
"nltk.download(\"wordnet\")"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" context \\\n",
"0 Albert Einstein adalah fisikawan teoretis kela... \n",
"1 Samudra Pasifik adalah yang terbesar dan terda... \n",
"2 Proklamasi Kemerdekaan Indonesia dibacakan pad... \n",
"3 Hukum Newton adalah tiga hukum fisika yang men... \n",
"4 Budi Utomo adalah organisasi pemuda yang didir... \n",
"\n",
" qa_pairs \n",
"0 [{'type': 'fill_in_the_blank', 'question': 'Si... \n",
"1 [{'type': 'fill_in_the_blank', 'question': 'Sa... \n",
"2 [{'type': 'fill_in_the_blank', 'question': 'Pr... \n",
"3 [{'type': 'fill_in_the_blank', 'question': 'Hu... \n",
"4 [{'type': 'fill_in_the_blank', 'question': 'Bu... \n",
"\n",
"Total Context: 25\n",
"Total QA Pairs: 57\n"
]
}
],
"source": [
"# load dataset\n",
"df = pd.read_json(\"dataset/training_dataset.json\")\n",
"print(df.head())\n",
"with open(\"dataset/training_dataset.json\", \"r\", encoding=\"utf-8\") as file:\n",
" dataset = json.load(file)\n",
" \n",
" \n",
"# Menghitung total context\n",
"total_context = len(dataset)\n",
"\n",
"# Menghitung total qa_pairs\n",
"total_qa_pairs = sum(len(entry[\"qa_pairs\"]) for entry in dataset)\n",
"\n",
"# Menampilkan hasil\n",
"print(f\"\\nTotal Context: {total_context}\")\n",
"print(f\"Total QA Pairs: {total_qa_pairs}\")"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [],
"source": [
"# Text Preprocessing\n",
"stop_words = set(stopwords.words(\"indonesian\")) \n",
"lemmatizer = WordNetLemmatizer()\n",
"\n",
"normalization_dict = {\n",
" \"yg\": \"yang\",\n",
" \"gokil\": \"kocak\",\n",
" \"kalo\": \"kalau\",\n",
" \"gue\": \"saya\",\n",
" \"elo\": \"kamu\",\n",
" \"nih\": \"ini\",\n",
" \"trs\": \"terus\",\n",
" \"tdk\": \"tidak\",\n",
" \"gmna\": \"bagaimana\",\n",
" \"tp\": \"tapi\",\n",
" \"jd\": \"jadi\",\n",
" \"aja\": \"saja\",\n",
" \"krn\": \"karena\",\n",
" \"blm\": \"belum\",\n",
" \"dgn\": \"dengan\",\n",
" \"skrg\": \"sekarang\",\n",
" \"msh\": \"masih\",\n",
" \"lg\": \"lagi\",\n",
" \"sy\": \"saya\",\n",
" \"sm\": \"sama\",\n",
" \"bgt\": \"banget\",\n",
" \"dr\": \"dari\",\n",
" \"kpn\": \"kapan\",\n",
" \"hrs\": \"harus\",\n",
" \"cm\": \"cuma\",\n",
" \"sbnrnya\": \"sebenarnya\",\n",
" \"tdr\": \"tidur\",\n",
" \"tdk\": \"tidak\",\n",
" \"kl\": \"kalau\",\n",
" \"org\": \"orang\",\n",
" \"pke\": \"pakai\",\n",
" \"prnh\": \"pernah\",\n",
" \"brgkt\": \"berangkat\",\n",
" \"pdhl\": \"padahal\",\n",
" \"btw\": \"ngomong-ngomong\",\n",
" \"dmn\": \"di mana\",\n",
" \"bsk\": \"besok\",\n",
" \"td\": \"tadi\",\n",
" \"dlm\": \"dalam\",\n",
" \"utk\": \"untuk\",\n",
" \"spt\": \"seperti\",\n",
" \"gpp\": \"tidak apa-apa\",\n",
" \"bs\": \"bisa\",\n",
" \"jg\": \"juga\",\n",
" \"tp\": \"tapi\",\n",
" \"dg\": \"dengan\",\n",
" \"klw\": \"kalau\",\n",
" \"wkwk\": \"haha\",\n",
" \"cpt\": \"cepat\",\n",
" \"knp\": \"kenapa\",\n",
" \"jgk\": \"juga\",\n",
" \"plg\": \"pulang\",\n",
" \"brp\": \"berapa\",\n",
" \"bkn\": \"bukan\",\n",
" \"mnt\": \"minta\",\n",
" \"udh\": \"sudah\",\n",
" \"sdh\": \"sudah\",\n",
" \"brkt\": \"berangkat\",\n",
" \"btw\": \"by the way\",\n",
" \"tdk\": \"tidak\",\n",
" \"sprt\": \"seperti\",\n",
" \"jgn\": \"jangan\",\n",
" \"mlm\": \"malam\",\n",
" \"sblm\": \"sebelum\",\n",
" \"stlh\": \"setelah\",\n",
" \"tdr\": \"tidur\",\n",
" \"mlh\": \"malah\",\n",
" \"tmn\": \"teman\",\n",
"}\n",
"\n",
"\n",
"def text_preprocessing(text):\n",
" #doing lower case \n",
" text = text.lower()\n",
" \n",
" # remove symbol and read mark\n",
" text = text.translate(str.maketrans(\"\", \"\", string.punctuation))\n",
" \n",
" # remove blank space\n",
" text = re.sub(r\"\\s+\", \" \", text).strip()\n",
" \n",
" # word tokenize \n",
" tokens = word_tokenize(text)\n",
" \n",
" # normalassi kata\n",
" tokens = [normalization_dict[word] if word in normalization_dict else word for word in tokens] \n",
" \n",
" \n",
" # lemmatization\n",
" tokens = [lemmatizer.lemmatize(word) for word in tokens] \n",
" \n",
" # stopword removal\n",
" tokens = [word for word in tokens if word not in stop_words] \n",
" \n",
" return tokens\n"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"✅ Data processing complete!\n",
"Samples: 57\n"
]
}
],
"source": [
"# with open(\"dataset/training_dataset.json\", \"r\", encoding=\"utf-8\") as file:\n",
"# dataset = json.load(file)\n",
"\n",
"# === Extract Data so that each QA pair has its own context === #\n",
"contexts = []\n",
"questions = []\n",
"answers = []\n",
"question_types = []\n",
"\n",
"for entry in dataset:\n",
" processed_context = text_preprocessing(entry[\"context\"])\n",
" for qa in entry[\"qa_pairs\"]:\n",
" contexts.append(processed_context)\n",
" questions.append(text_preprocessing(qa[\"question\"]))\n",
" answers.append(text_preprocessing(qa[\"answer\"]))\n",
" question_types.append(qa[\"type\"])\n",
"\n",
"# === Initialize Tokenizer and fit on all text === #\n",
"tokenizer = Tokenizer(oov_token=\"<OOV>\")\n",
"tokenizer.fit_on_texts(contexts + questions + answers)\n",
"\n",
"# === Convert Text to Sequences === #\n",
"context_sequences = tokenizer.texts_to_sequences(contexts)\n",
"question_sequences = tokenizer.texts_to_sequences(questions)\n",
"answer_sequences = tokenizer.texts_to_sequences(answers)\n",
"\n",
"# === Define Max Length for Padding === #\n",
"MAX_LENGTH = 100\n",
"context_padded = pad_sequences(context_sequences, maxlen=MAX_LENGTH, padding=\"post\", truncating=\"post\")\n",
"question_padded = pad_sequences(question_sequences, maxlen=MAX_LENGTH, padding=\"post\", truncating=\"post\")\n",
"answer_padded = pad_sequences(answer_sequences, maxlen=MAX_LENGTH, padding=\"post\", truncating=\"post\")\n",
"\n",
"# Encode Question Types\n",
"question_type_dict = {\"fill_in_the_blank\": 0, \"true_false\": 1, \"multiple_choice\": 2}\n",
"question_type_labels = np.array([question_type_dict[q_type] for q_type in question_types])\n",
"\n",
"# Save the processed data (optional)\n",
"np.save(\"context_padded.npy\", context_padded)\n",
"np.save(\"question_padded.npy\", question_padded)\n",
"np.save(\"answer_padded.npy\", answer_padded)\n",
"np.save(\"question_type_labels.npy\", question_type_labels)\n",
"with open(\"tokenizer.pkl\", \"wb\") as handle:\n",
" pickle.dump(tokenizer, handle, protocol=pickle.HIGHEST_PROTOCOL)\n",
"\n",
"print(\"✅ Data processing complete!\")\n",
"print(\"Samples:\", context_padded.shape[0]) # This should now match the number of QA pairs\n"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training samples: 45\n",
"Testing samples: 12\n"
]
}
],
"source": [
"# === Split Data into Training and Testing Sets === #\n",
"(context_train, context_test,\n",
" question_train, question_test,\n",
" answer_train, answer_test,\n",
" qtype_train, qtype_test) = train_test_split(\n",
" context_padded,\n",
" question_padded,\n",
" answer_padded,\n",
" question_type_labels,\n",
" test_size=0.2,\n",
" random_state=42\n",
")\n",
"\n",
"print(\"Training samples:\", context_train.shape[0])\n",
"print(\"Testing samples:\", context_test.shape[0])\n"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n",
"\u001b[1m2/2\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 536ms/step - answer_output_accuracy: 0.0143 - answer_output_loss: 5.8964 - loss: 12.9070 - question_output_accuracy: 3.7037e-04 - question_output_loss: 5.9030 - question_type_output_accuracy: 0.3843 - question_type_output_loss: 1.0984 - val_answer_output_accuracy: 0.2689 - val_answer_output_loss: 5.8541 - val_loss: 12.8370 - val_question_output_accuracy: 0.0100 - val_question_output_loss: 5.8942 - val_question_type_output_accuracy: 0.4444 - val_question_type_output_loss: 1.0888\n",
"Epoch 2/10\n",
"\u001b[1m2/2\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 171ms/step - answer_output_accuracy: 0.3408 - answer_output_loss: 5.8273 - loss: 12.8054 - question_output_accuracy: 0.0138 - question_output_loss: 5.8815 - question_type_output_accuracy: 0.5868 - question_type_output_loss: 1.0860 - val_answer_output_accuracy: 0.9856 - val_answer_output_loss: 5.7489 - val_loss: 12.7114 - val_question_output_accuracy: 0.0100 - val_question_output_loss: 5.8837 - val_question_type_output_accuracy: 0.4444 - val_question_type_output_loss: 1.0788\n",
"Epoch 3/10\n",
"\u001b[1m2/2\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 172ms/step - answer_output_accuracy: 0.9826 - answer_output_loss: 5.6819 - loss: 12.6314 - question_output_accuracy: 0.0214 - question_output_loss: 5.8659 - question_type_output_accuracy: 0.5579 - question_type_output_loss: 1.0659 - val_answer_output_accuracy: 0.9856 - val_answer_output_loss: 5.4070 - val_loss: 12.3515 - val_question_output_accuracy: 0.0122 - val_question_output_loss: 5.8714 - val_question_type_output_accuracy: 0.4444 - val_question_type_output_loss: 1.0730\n",
"Epoch 4/10\n",
"\u001b[1m2/2\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 170ms/step - answer_output_accuracy: 0.9825 - answer_output_loss: 4.9562 - loss: 12.0453 - question_output_accuracy: 0.0214 - question_output_loss: 5.8386 - question_type_output_accuracy: 0.5972 - question_type_output_loss: 1.0674 - val_answer_output_accuracy: 0.9856 - val_answer_output_loss: 4.0363 - val_loss: 10.9596 - val_question_output_accuracy: 0.0078 - val_question_output_loss: 5.8545 - val_question_type_output_accuracy: 0.4444 - val_question_type_output_loss: 1.0688\n",
"Epoch 5/10\n",
"\u001b[1m2/2\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 180ms/step - answer_output_accuracy: 0.9827 - answer_output_loss: 3.0621 - loss: 10.1366 - question_output_accuracy: 0.0117 - question_output_loss: 5.8038 - question_type_output_accuracy: 0.5868 - question_type_output_loss: 1.0336 - val_answer_output_accuracy: 0.9856 - val_answer_output_loss: 1.7133 - val_loss: 8.5937 - val_question_output_accuracy: 0.0078 - val_question_output_loss: 5.8148 - val_question_type_output_accuracy: 0.5556 - val_question_type_output_loss: 1.0657\n",
"Epoch 6/10\n",
"\u001b[1m2/2\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 167ms/step - answer_output_accuracy: 0.9824 - answer_output_loss: 1.1229 - loss: 8.0884 - question_output_accuracy: 0.0054 - question_output_loss: 5.7361 - question_type_output_accuracy: 0.5394 - question_type_output_loss: 1.1969 - val_answer_output_accuracy: 0.9856 - val_answer_output_loss: 0.6490 - val_loss: 7.4129 - val_question_output_accuracy: 0.0078 - val_question_output_loss: 5.6864 - val_question_type_output_accuracy: 0.5556 - val_question_type_output_loss: 1.0775\n",
"Epoch 7/10\n",
"\u001b[1m2/2\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 171ms/step - answer_output_accuracy: 0.9824 - answer_output_loss: 0.5961 - loss: 7.2279 - question_output_accuracy: 0.0045 - question_output_loss: 5.4419 - question_type_output_accuracy: 0.4340 - question_type_output_loss: 1.1878 - val_answer_output_accuracy: 0.9856 - val_answer_output_loss: 0.5096 - val_loss: 7.2387 - val_question_output_accuracy: 0.0078 - val_question_output_loss: 5.5038 - val_question_type_output_accuracy: 0.2222 - val_question_type_output_loss: 1.2253\n",
"Epoch 8/10\n",
"\u001b[1m2/2\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 183ms/step - answer_output_accuracy: 0.9824 - answer_output_loss: 0.6082 - loss: 6.8731 - question_output_accuracy: 0.0045 - question_output_loss: 5.1820 - question_type_output_accuracy: 0.4132 - question_type_output_loss: 1.0811 - val_answer_output_accuracy: 0.9856 - val_answer_output_loss: 0.5227 - val_loss: 7.0044 - val_question_output_accuracy: 0.0067 - val_question_output_loss: 5.4302 - val_question_type_output_accuracy: 0.5556 - val_question_type_output_loss: 1.0514\n",
"Epoch 9/10\n",
"\u001b[1m2/2\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 164ms/step - answer_output_accuracy: 0.9824 - answer_output_loss: 0.6137 - loss: 6.7673 - question_output_accuracy: 0.0045 - question_output_loss: 4.9309 - question_type_output_accuracy: 0.4815 - question_type_output_loss: 1.3263 - val_answer_output_accuracy: 0.9856 - val_answer_output_loss: 0.5185 - val_loss: 6.9885 - val_question_output_accuracy: 0.0044 - val_question_output_loss: 5.4844 - val_question_type_output_accuracy: 0.6667 - val_question_type_output_loss: 0.9857\n",
"Epoch 10/10\n",
"\u001b[1m2/2\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 176ms/step - answer_output_accuracy: 0.9824 - answer_output_loss: 0.6061 - loss: 6.6094 - question_output_accuracy: 0.0045 - question_output_loss: 4.8379 - question_type_output_accuracy: 0.4132 - question_type_output_loss: 1.1360 - val_answer_output_accuracy: 0.9856 - val_answer_output_loss: 0.4949 - val_loss: 7.1082 - val_question_output_accuracy: 0.0044 - val_question_output_loss: 5.5818 - val_question_type_output_accuracy: 0.5556 - val_question_type_output_loss: 1.0315\n",
"✅ Model training complete and saved!\n"
]
}
],
"source": [
"# === Model Hyperparameters === #\n",
"VOCAB_SIZE = len(tokenizer.word_index) + 1\n",
"EMBEDDING_DIM = 300\n",
"LSTM_UNITS = 256\n",
"BATCH_SIZE = 32\n",
"EPOCHS = 10\n",
"\n",
"# === Build Model === #\n",
"# Encoder for Context\n",
"context_input = Input(shape=(MAX_LENGTH,), name=\"context_input\")\n",
"context_embedding = Embedding(input_dim=VOCAB_SIZE, output_dim=EMBEDDING_DIM, mask_zero=True, name=\"context_embedding\")(context_input)\n",
"encoder_lstm = LSTM(LSTM_UNITS, return_state=True, name=\"encoder_lstm\")\n",
"encoder_output, state_h, state_c = encoder_lstm(context_embedding)\n",
"\n",
"# Decoder for Question (Teacher Forcing)\n",
"question_decoder_input = Input(shape=(MAX_LENGTH,), name=\"question_decoder_input\")\n",
"question_embedding = Embedding(input_dim=VOCAB_SIZE, output_dim=EMBEDDING_DIM, mask_zero=True, name=\"question_embedding\")(question_decoder_input)\n",
"question_lstm = LSTM(LSTM_UNITS, return_sequences=True, return_state=True, name=\"question_lstm\")\n",
"question_output, _, _ = question_lstm(question_embedding, initial_state=[state_h, state_c])\n",
"question_dense = Dense(VOCAB_SIZE, activation=\"softmax\", name=\"question_output\")(question_output)\n",
"\n",
"# Decoder for Answer\n",
"answer_lstm = LSTM(LSTM_UNITS, return_sequences=True, return_state=True, name=\"answer_lstm\")\n",
"answer_output, _, _ = answer_lstm(context_embedding, initial_state=[state_h, state_c])\n",
"answer_dense = Dense(VOCAB_SIZE, activation=\"softmax\", name=\"answer_output\")(answer_output)\n",
"\n",
"# Classification Output for Question Type\n",
"type_dense = Dense(128, activation=\"relu\")(encoder_output)\n",
"question_type_output = Dense(3, activation=\"softmax\", name=\"question_type_output\")(type_dense)\n",
"\n",
"# Construct the Model\n",
"model = Model(\n",
" inputs=[context_input, question_decoder_input],\n",
" outputs=[question_dense, answer_dense, question_type_output],\n",
")\n",
"\n",
"# === Compile the Model === #\n",
"model.compile(\n",
" optimizer=\"adam\",\n",
" loss={\n",
" \"question_output\": \"sparse_categorical_crossentropy\",\n",
" \"answer_output\": \"sparse_categorical_crossentropy\",\n",
" \"question_type_output\": \"sparse_categorical_crossentropy\",\n",
" },\n",
" metrics={\n",
" \"question_output\": [\"accuracy\"],\n",
" \"answer_output\": [\"accuracy\"],\n",
" \"question_type_output\": [\"accuracy\"],\n",
" },\n",
")\n",
"\n",
"# === Train the Model === #\n",
"model.fit(\n",
" [context_train, question_train],\n",
" {\n",
" \"question_output\": question_train,\n",
" \"answer_output\": answer_train,\n",
" \"question_type_output\": qtype_train,\n",
" },\n",
" batch_size=BATCH_SIZE,\n",
" epochs=EPOCHS,\n",
" validation_split=0.2,\n",
")\n",
"\n",
"# Save the Model\n",
"model.save(\"lstm_multi_output_model.keras\")\n",
"print(\"✅ Model training complete and saved!\")"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 255ms/step\n",
"=== Evaluation on Test Data ===\n",
"Classification Report for Question Type:\n",
" precision recall f1-score support\n",
"\n",
" 0 0.33 0.40 0.36 5\n",
" 1 0.33 0.50 0.40 4\n",
" 2 0.00 0.00 0.00 3\n",
"\n",
" accuracy 0.33 12\n",
" macro avg 0.22 0.30 0.25 12\n",
"weighted avg 0.25 0.33 0.28 12\n",
"\n",
"Accuracy: 0.3333333333333333\n",
"Precision: 0.25\n",
"Recall: 0.3333333333333333\n",
"BLEU score for first test sample (question generation): 0\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/mnt/disc1/code/thesis_quiz_project/lstm-quiz/myenv/lib64/python3.10/site-packages/sklearn/metrics/_classification.py:1565: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
" _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
"/mnt/disc1/code/thesis_quiz_project/lstm-quiz/myenv/lib64/python3.10/site-packages/sklearn/metrics/_classification.py:1565: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
" _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
"/mnt/disc1/code/thesis_quiz_project/lstm-quiz/myenv/lib64/python3.10/site-packages/sklearn/metrics/_classification.py:1565: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
" _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
"/mnt/disc1/code/thesis_quiz_project/lstm-quiz/myenv/lib64/python3.10/site-packages/sklearn/metrics/_classification.py:1565: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
" _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n"
]
}
],
"source": [
"# === Evaluate on Test Set === #\n",
"pred_question, pred_answer, pred_qtype = model.predict([context_test, question_test])\n",
"pred_qtype_labels = np.argmax(pred_qtype, axis=1)\n",
"\n",
"print(\"=== Evaluation on Test Data ===\")\n",
"print(\"Classification Report for Question Type:\")\n",
"print(classification_report(qtype_test, pred_qtype_labels))\n",
"print(\"Accuracy:\", accuracy_score(qtype_test, pred_qtype_labels))\n",
"print(\"Precision:\", precision_score(qtype_test, pred_qtype_labels, average='weighted'))\n",
"print(\"Recall:\", recall_score(qtype_test, pred_qtype_labels, average='weighted'))\n",
"\n",
"# Optional: Evaluate sequence generation using BLEU score for the first sample\n",
"import nltk\n",
"def sequence_to_text(sequence, tokenizer):\n",
" return [tokenizer.index_word.get(idx, \"<OOV>\") for idx in sequence if idx != 0]\n",
"\n",
"reference = [sequence_to_text(question_test[0], tokenizer)]\n",
"candidate = sequence_to_text(np.argmax(pred_question[0], axis=-1), tokenizer)\n",
"bleu_score = nltk.translate.bleu_score.sentence_bleu(reference, candidate)\n",
"print(\"BLEU score for first test sample (question generation):\", bleu_score)\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "myenv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.16"
}
},
"nbformat": 4,
"nbformat_minor": 2
}