TIF_E41211115_lstm-quiz-gen.../training_model.ipynb

520 lines
157 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": 112,
"metadata": {},
"outputs": [],
"source": [
"\n",
"import numpy as np\n",
"import re\n",
"import string\n",
"import nltk\n",
"from nltk.corpus import stopwords\n",
"from nltk.tokenize import word_tokenize\n",
"import matplotlib.pyplot as plt\n",
"\n",
"from tensorflow.keras.preprocessing.text import Tokenizer\n",
"from tensorflow.keras.preprocessing.sequence import pad_sequences\n",
"from tensorflow.keras.models import Model\n",
"from tensorflow.keras.layers import Input, Embedding, LSTM, Dense\n",
"\n",
"from Sastrawi.Stemmer.StemmerFactory import StemmerFactory\n",
"from sklearn.model_selection import train_test_split\n",
"import pickle\n",
"\n",
"from sklearn.metrics import classification_report, accuracy_score, precision_score, recall_score\n",
"import nltk"
]
},
{
"cell_type": "code",
"execution_count": 113,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[nltk_data] Downloading package stopwords to /home/akeon/nltk_data...\n",
"[nltk_data] Package stopwords is already up-to-date!\n",
"[nltk_data] Downloading package punkt to /home/akeon/nltk_data...\n",
"[nltk_data] Package punkt is already up-to-date!\n",
"[nltk_data] Downloading package punkt_tab to /home/akeon/nltk_data...\n",
"[nltk_data] Package punkt_tab is already up-to-date!\n",
"[nltk_data] Downloading package wordnet to /home/akeon/nltk_data...\n",
"[nltk_data] Package wordnet is already up-to-date!\n"
]
},
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 113,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# download assets\n",
"nltk.download(\"stopwords\")\n",
"nltk.download(\"punkt\")\n",
"nltk.download(\"punkt_tab\")\n",
"nltk.download(\"wordnet\")"
]
},
{
"cell_type": "code",
"execution_count": 114,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Total Context: 54\n",
"Total Possibility Questions: 97\n",
"Total Fill in the Blank Questions: 24\n",
"Total Multiple Choice Questions: 29\n",
"Total True/False Questions: 44\n"
]
}
],
"source": [
"import json\n",
"from collections import defaultdict\n",
"\n",
"# path dataset\n",
"file_path = \"dataset/training_dataset.json\"\n",
"\n",
"with open(file_path, \"r\", encoding=\"utf-8\") as file:\n",
" dataset = json.load(file)\n",
"\n",
"total_context = len(dataset)\n",
"\n",
"total_question_posibility = sum(len(entry[\"question_posibility\"]) for entry in dataset)\n",
"\n",
"question_type_counts = defaultdict(int)\n",
"for entry in dataset:\n",
" for question in entry[\"question_posibility\"]:\n",
" question_type_counts[question[\"type\"]] += 1\n",
"\n",
"print(f\"Total Context: {total_context}\")\n",
"print(f\"Total Possibility Questions: {total_question_posibility}\")\n",
"print(f\"Total Fill in the Blank Questions: {question_type_counts.get('fill_in_the_blank', 0)}\")\n",
"print(f\"Total Multiple Choice Questions: {question_type_counts.get('multiple_choice', 0)}\")\n",
"print(f\"Total True/False Questions: {question_type_counts.get('true_false', 0)}\")\n"
]
},
{
"cell_type": "code",
"execution_count": 115,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Data processing complete!\n",
"Samples: 97\n"
]
}
],
"source": [
"# Text Preprocessing\n",
"stop_words = set(stopwords.words(\"indonesian\")) \n",
"factory = StemmerFactory()\n",
"stemmer = factory.create_stemmer()\n",
"\n",
"with open(\"normalize_text/normalize.json\", \"r\", encoding=\"utf-8\") as file:\n",
" normalization_dict = json.load(file)\n",
"\n",
"\n",
"def text_preprocessing(text):\n",
" #doing lower case \n",
" text = text.lower()\n",
" \n",
" # remove symbol and read mark\n",
" text = text.translate(str.maketrans(\"\", \"\", string.punctuation))\n",
" \n",
" # remove blank space\n",
" text = re.sub(r\"\\s+\", \" \", text).strip()\n",
" \n",
" # word tokenize \n",
" tokens = word_tokenize(text)\n",
" \n",
" \n",
" # normalassi kata\n",
" tokens = [normalization_dict[word] if word in normalization_dict else word for word in tokens] \n",
" \n",
" \n",
" # Lemmatization using Sastrawi (stemming in Indonesian)\n",
" tokens = [stemmer.stem(word) for word in tokens]\n",
" \n",
" # stopword removal\n",
" tokens = [word for word in tokens if word not in stop_words] \n",
" \n",
" return tokens\n",
"\n",
"# text processing all data training\n",
"contexts = []\n",
"questions = []\n",
"correct_answers = []\n",
"wrong_answers = []\n",
"question_types = []\n",
"\n",
"for entry in dataset:\n",
" processed_context = text_preprocessing(entry[\"context\"])\n",
" \n",
" for qa in entry[\"question_posibility\"]:\n",
" processed_question = text_preprocessing(qa[\"question\"])\n",
" processed_answer = text_preprocessing(qa[\"answer\"])\n",
" \n",
" contexts.append(processed_context)\n",
" questions.append(processed_question)\n",
" correct_answers.append(processed_answer)\n",
" question_types.append(qa[\"type\"])\n",
"\n",
" if qa[\"type\"] == \"multiple_choice\":\n",
" incorrect_options = [opt for opt in qa[\"options\"] if opt != qa[\"answer\"]]\n",
" wrong_answers.append(incorrect_options)\n",
" else:\n",
" wrong_answers.append([])\n",
"\n",
"tokenizer = Tokenizer(oov_token=\"<OOV>\")\n",
"tokenizer.fit_on_texts(contexts + questions + correct_answers)\n",
"\n",
"\n",
"context_sequences = tokenizer.texts_to_sequences(contexts)\n",
"question_sequences = tokenizer.texts_to_sequences(questions)\n",
"answer_sequences = tokenizer.texts_to_sequences(correct_answers)\n",
"\n",
"MAX_LENGTH = 100\n",
"context_padded = pad_sequences(context_sequences, maxlen=MAX_LENGTH, padding=\"post\", truncating=\"post\")\n",
"question_padded = pad_sequences(question_sequences, maxlen=MAX_LENGTH, padding=\"post\", truncating=\"post\")\n",
"answer_padded = pad_sequences(answer_sequences, maxlen=MAX_LENGTH, padding=\"post\", truncating=\"post\")\n",
"\n",
"question_type_dict = {\"fill_in_the_blank\": 0, \"true_false\": 1, \"multiple_choice\": 2}\n",
"question_type_labels = np.array([question_type_dict[q_type] for q_type in question_types])\n",
"\n",
"print(\"Data processing complete!\")\n",
"print(\"Samples:\", context_padded.shape[0]) \n"
]
},
{
"cell_type": "code",
"execution_count": 116,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training samples: 77\n",
"Testing samples: 10\n",
"Validation samples: 10\n"
]
}
],
"source": [
"# split text for testing 8:2\n",
"context_train, context_temp, question_train, question_temp, answer_train, answer_temp, qtype_train, qtype_temp = train_test_split(\n",
" context_padded, question_padded, answer_padded, question_type_labels, test_size=0.2, random_state=42\n",
")\n",
"\n",
"# split dataset testing and validation 5:5\n",
"context_test, context_val, question_test, question_val, answer_test, answer_val, qtype_test, qtype_val = train_test_split(\n",
" context_temp, question_temp, answer_temp, qtype_temp, test_size=0.5, random_state=42\n",
")\n",
"\n",
"print(\"Training samples:\", context_train.shape[0])\n",
"print(\"Testing samples:\", context_test.shape[0])\n",
"print(\"Validation samples:\", context_val.shape[0])\n"
]
},
{
"cell_type": "code",
"execution_count": 120,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n",
"\u001b[1m2/2\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 1s/step - answer_output_accuracy: 0.0344 - answer_output_loss: 6.2090 - loss: 13.5239 - question_output_accuracy: 0.0000e+00 - question_output_loss: 6.2154 - question_type_output_accuracy: 0.3004 - question_type_output_loss: 1.0991 - val_answer_output_accuracy: 0.2287 - val_answer_output_loss: 6.1669 - val_loss: 13.4815 - val_question_output_accuracy: 0.0050 - val_question_output_loss: 6.2101 - val_question_type_output_accuracy: 0.3125 - val_question_type_output_loss: 1.1046\n",
"Epoch 2/10\n",
"\u001b[1m2/2\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 526ms/step - answer_output_accuracy: 0.2277 - answer_output_loss: 6.1421 - loss: 13.4196 - question_output_accuracy: 0.0113 - question_output_loss: 6.1984 - question_type_output_accuracy: 0.6445 - question_type_output_loss: 1.0780 - val_answer_output_accuracy: 0.9856 - val_answer_output_loss: 6.0462 - val_loss: 13.3570 - val_question_output_accuracy: 0.0081 - val_question_output_loss: 6.2031 - val_question_type_output_accuracy: 0.3125 - val_question_type_output_loss: 1.1076\n",
"Epoch 3/10\n",
"\u001b[1m2/2\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 528ms/step - answer_output_accuracy: 0.9837 - answer_output_loss: 5.9539 - loss: 13.1879 - question_output_accuracy: 0.0171 - question_output_loss: 6.1802 - question_type_output_accuracy: 0.5799 - question_type_output_loss: 1.0503 - val_answer_output_accuracy: 0.9856 - val_answer_output_loss: 5.5439 - val_loss: 12.8565 - val_question_output_accuracy: 0.0087 - val_question_output_loss: 6.1941 - val_question_type_output_accuracy: 0.5000 - val_question_type_output_loss: 1.1185\n",
"Epoch 4/10\n",
"\u001b[1m2/2\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 533ms/step - answer_output_accuracy: 0.9839 - answer_output_loss: 5.1228 - loss: 12.2985 - question_output_accuracy: 0.0137 - question_output_loss: 6.1532 - question_type_output_accuracy: 0.5164 - question_type_output_loss: 1.0060 - val_answer_output_accuracy: 0.9856 - val_answer_output_loss: 3.2875 - val_loss: 10.6708 - val_question_output_accuracy: 0.0050 - val_question_output_loss: 6.1772 - val_question_type_output_accuracy: 0.5000 - val_question_type_output_loss: 1.2060\n",
"Epoch 5/10\n",
"\u001b[1m2/2\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 520ms/step - answer_output_accuracy: 0.9835 - answer_output_loss: 2.7939 - loss: 9.9397 - question_output_accuracy: 0.0056 - question_output_loss: 6.0862 - question_type_output_accuracy: 0.5263 - question_type_output_loss: 1.0473 - val_answer_output_accuracy: 0.9856 - val_answer_output_loss: 1.1028 - val_loss: 9.0601 - val_question_output_accuracy: 0.0012 - val_question_output_loss: 6.1277 - val_question_type_output_accuracy: 0.5000 - val_question_type_output_loss: 1.8296\n",
"Epoch 6/10\n",
"\u001b[1m2/2\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 541ms/step - answer_output_accuracy: 0.9828 - answer_output_loss: 1.2315 - loss: 8.3718 - question_output_accuracy: 0.0016 - question_output_loss: 5.8773 - question_type_output_accuracy: 0.5055 - question_type_output_loss: 1.2478 - val_answer_output_accuracy: 0.9856 - val_answer_output_loss: 0.6227 - val_loss: 8.6339 - val_question_output_accuracy: 0.0012 - val_question_output_loss: 6.0831 - val_question_type_output_accuracy: 0.1250 - val_question_type_output_loss: 1.9281\n",
"Epoch 7/10\n",
"\u001b[1m2/2\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 492ms/step - answer_output_accuracy: 0.9842 - answer_output_loss: 0.7375 - loss: 7.4714 - question_output_accuracy: 9.6824e-04 - question_output_loss: 5.5770 - question_type_output_accuracy: 0.4612 - question_type_output_loss: 1.1578 - val_answer_output_accuracy: 0.9856 - val_answer_output_loss: 0.5788 - val_loss: 7.9850 - val_question_output_accuracy: 0.0012 - val_question_output_loss: 6.1148 - val_question_type_output_accuracy: 0.1250 - val_question_type_output_loss: 1.2913\n",
"Epoch 8/10\n",
"\u001b[1m2/2\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 455ms/step - answer_output_accuracy: 0.9847 - answer_output_loss: 0.6731 - loss: 6.9870 - question_output_accuracy: 0.0011 - question_output_loss: 5.3263 - question_type_output_accuracy: 0.5596 - question_type_output_loss: 0.9895 - val_answer_output_accuracy: 0.9856 - val_answer_output_loss: 0.6030 - val_loss: 7.8753 - val_question_output_accuracy: 0.0012 - val_question_output_loss: 6.2693 - val_question_type_output_accuracy: 0.5000 - val_question_type_output_loss: 1.0031\n",
"Epoch 9/10\n",
"\u001b[1m2/2\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 458ms/step - answer_output_accuracy: 0.9836 - answer_output_loss: 0.7391 - loss: 6.9393 - question_output_accuracy: 0.0017 - question_output_loss: 5.0887 - question_type_output_accuracy: 0.4841 - question_type_output_loss: 1.1123 - val_answer_output_accuracy: 0.9856 - val_answer_output_loss: 0.6056 - val_loss: 8.1353 - val_question_output_accuracy: 0.0019 - val_question_output_loss: 6.4616 - val_question_type_output_accuracy: 0.5000 - val_question_type_output_loss: 1.0682\n",
"Epoch 10/10\n",
"\u001b[1m2/2\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 454ms/step - answer_output_accuracy: 0.9847 - answer_output_loss: 0.6727 - loss: 6.6312 - question_output_accuracy: 0.0018 - question_output_loss: 4.9620 - question_type_output_accuracy: 0.5258 - question_type_output_loss: 1.0078 - val_answer_output_accuracy: 0.9856 - val_answer_output_loss: 0.5869 - val_loss: 8.5074 - val_question_output_accuracy: 0.0037 - val_question_output_loss: 6.6207 - val_question_type_output_accuracy: 0.3750 - val_question_type_output_loss: 1.2998\n"
]
}
],
"source": [
"\n",
"VOCAB_SIZE = len(tokenizer.word_index) + 1\n",
"EMBEDDING_DIM = 300\n",
"LSTM_UNITS = 256\n",
"BATCH_SIZE = 32\n",
"EPOCHS = 10\n",
"\n",
"context_input = Input(shape=(MAX_LENGTH,), name=\"context_input\")\n",
"context_embedding = Embedding(input_dim=VOCAB_SIZE, output_dim=EMBEDDING_DIM, mask_zero=True, name=\"context_embedding\")(context_input)\n",
"encoder_lstm = LSTM(LSTM_UNITS, return_state=True, name=\"encoder_lstm\")\n",
"encoder_output, state_h, state_c = encoder_lstm(context_embedding)\n",
"\n",
"# Question Decoder\n",
"question_decoder_input = Input(shape=(MAX_LENGTH,), name=\"question_decoder_input\")\n",
"question_embedding = Embedding(input_dim=VOCAB_SIZE, output_dim=EMBEDDING_DIM, mask_zero=True, name=\"question_embedding\")(question_decoder_input)\n",
"question_lstm = LSTM(LSTM_UNITS, return_sequences=True, return_state=True, name=\"question_lstm\")\n",
"question_output, _, _ = question_lstm(question_embedding, initial_state=[state_h, state_c])\n",
"question_dense = Dense(VOCAB_SIZE, activation=\"softmax\", name=\"question_output\")(question_output)\n",
"\n",
"# Answer Decoder\n",
"answer_lstm = LSTM(LSTM_UNITS, return_sequences=True, return_state=True, name=\"answer_lstm\")\n",
"answer_output, _, _ = answer_lstm(context_embedding, initial_state=[state_h, state_c])\n",
"answer_dense = Dense(VOCAB_SIZE, activation=\"softmax\", name=\"answer_output\")(answer_output)\n",
"\n",
"type_dense = Dense(128, activation=\"relu\")(encoder_output)\n",
"question_type_output = Dense(3, activation=\"softmax\", name=\"question_type_output\")(type_dense)\n",
"\n",
"model = Model(\n",
" inputs=[context_input, question_decoder_input],\n",
" outputs=[question_dense, answer_dense, question_type_output]\n",
")\n",
"\n",
"model.compile(\n",
" optimizer=\"adam\",\n",
" loss={\"question_output\": \"sparse_categorical_crossentropy\",\n",
" \"answer_output\": \"sparse_categorical_crossentropy\",\n",
" \"question_type_output\": \"sparse_categorical_crossentropy\"},\n",
" metrics={\"question_output\": [\"accuracy\"],\n",
" \"answer_output\": [\"accuracy\"],\n",
" \"question_type_output\": [\"accuracy\"]}\n",
")\n",
"\n",
"data_model = model.fit(\n",
" [context_train, question_train],\n",
" {\"question_output\": question_train, \"answer_output\": answer_train, \"question_type_output\": qtype_train},\n",
" batch_size=BATCH_SIZE,\n",
" epochs=EPOCHS,\n",
" validation_split=0.2\n",
")\n",
"\n",
"model.save(\"lstm_multi_output_model.keras\")\n",
"with open(\"tokenizer.pkl\", \"wb\") as handle:\n",
" pickle.dump(tokenizer, handle, protocol=pickle.HIGHEST_PROTOCOL)\n"
]
},
{
"cell_type": "code",
"execution_count": 121,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 1200x600 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"\n",
"# data epoch\n",
"\n",
"plt.figure(figsize=(12, 6))\n",
"\n",
"plt.subplot(1, 2, 1)\n",
"plt.plot(data_model.history['question_output_accuracy'], label='Question Train Accuracy')\n",
"plt.plot(data_model.history['val_question_output_accuracy'], label='Question Val Accuracy')\n",
"plt.plot(data_model.history['answer_output_accuracy'], label='Answer Train Accuracy')\n",
"plt.plot(data_model.history['val_answer_output_accuracy'], label='Answer Val Accuracy')\n",
"plt.plot(data_model.history['question_type_output_accuracy'], label='Question Type Train Accuracy')\n",
"plt.plot(data_model.history['val_question_type_output_accuracy'], label='Question Type Val Accuracy')\n",
"plt.title('Model Accuracy')\n",
"plt.xlabel('Epoch')\n",
"plt.ylabel('Accuracy')\n",
"plt.legend()\n",
"\n",
"plt.subplot(1, 2, 2)\n",
"plt.plot(data_model.history['question_output_loss'], label='Question Train Loss')\n",
"plt.plot(data_model.history['val_question_output_loss'], label='Question Val Loss')\n",
"plt.plot(data_model.history['answer_output_loss'], label='Answer Train Loss')\n",
"plt.plot(data_model.history['val_answer_output_loss'], label='Answer Val Loss')\n",
"plt.plot(data_model.history['question_type_output_loss'], label='Question Type Train Loss')\n",
"plt.plot(data_model.history['val_question_type_output_loss'], label='Question Type Val Loss')\n",
"plt.title('Model Loss')\n",
"plt.xlabel('Epoch')\n",
"plt.ylabel('Loss')\n",
"plt.legend()\n",
"\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 119,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"=== Evaluation on Test Data ===\n",
"\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 389ms/step\n",
"Classification Report for Question Type (Test Set):\n",
" precision recall f1-score support\n",
"\n",
" 0 0.00 0.00 0.00 4\n",
" 1 0.40 0.67 0.50 3\n",
" 2 0.20 0.33 0.25 3\n",
"\n",
" accuracy 0.30 10\n",
" macro avg 0.20 0.33 0.25 10\n",
"weighted avg 0.18 0.30 0.23 10\n",
"\n",
"Test Accuracy: 0.3\n",
"Test Precision: 0.18000000000000002\n",
"Test Recall: 0.3\n",
"BLEU Score for first test sample (question generation): 0.02664466031983166\n",
"BLEU Score for first test sample (answer generation): 0\n",
"\n",
"=== Evaluation on Validation Data ===\n",
"\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 92ms/step\n",
"Classification Report for Question Type (Validation Set):\n",
" precision recall f1-score support\n",
"\n",
" 0 0.00 0.00 0.00 4\n",
" 1 0.50 1.00 0.67 3\n",
" 2 0.25 0.33 0.29 3\n",
"\n",
" accuracy 0.40 10\n",
" macro avg 0.25 0.44 0.32 10\n",
"weighted avg 0.23 0.40 0.29 10\n",
"\n",
"Validation Accuracy: 0.4\n",
"Validation Precision: 0.225\n",
"Validation Recall: 0.4\n",
"BLEU Score for first validation sample (question generation): 0.008991061769415444\n",
"BLEU Score for first validation sample (answer generation): 0\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/mnt/disc1/code/thesis_quiz_project/lstm-quiz/myenv/lib64/python3.10/site-packages/sklearn/metrics/_classification.py:1565: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
" _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
"/mnt/disc1/code/thesis_quiz_project/lstm-quiz/myenv/lib64/python3.10/site-packages/sklearn/metrics/_classification.py:1565: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
" _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
"/mnt/disc1/code/thesis_quiz_project/lstm-quiz/myenv/lib64/python3.10/site-packages/sklearn/metrics/_classification.py:1565: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
" _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
"/mnt/disc1/code/thesis_quiz_project/lstm-quiz/myenv/lib64/python3.10/site-packages/sklearn/metrics/_classification.py:1565: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
" _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
"/mnt/disc1/code/thesis_quiz_project/lstm-quiz/myenv/lib64/python3.10/site-packages/sklearn/metrics/_classification.py:1565: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
" _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
"/mnt/disc1/code/thesis_quiz_project/lstm-quiz/myenv/lib64/python3.10/site-packages/sklearn/metrics/_classification.py:1565: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
" _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
"/mnt/disc1/code/thesis_quiz_project/lstm-quiz/myenv/lib64/python3.10/site-packages/sklearn/metrics/_classification.py:1565: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
" _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
"/mnt/disc1/code/thesis_quiz_project/lstm-quiz/myenv/lib64/python3.10/site-packages/sklearn/metrics/_classification.py:1565: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
" _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n"
]
}
],
"source": [
"\n",
"\n",
"\n",
"def sequence_to_text(sequence, tokenizer):\n",
" \"\"\" Converts a tokenized sequence back into readable text. \"\"\"\n",
" return \" \".join([tokenizer.index_word.get(idx, \"<OOV>\") for idx in sequence if idx != 0])\n",
"\n",
"print(\"\\n=== Evaluation on Test Data ===\")\n",
"pred_question_test, pred_answer_test, pred_qtype_test = model.predict([context_test, question_test])\n",
"pred_qtype_labels_test = np.argmax(pred_qtype_test, axis=1)\n",
"\n",
"print(\"Classification Report for Question Type (Test Set):\")\n",
"print(classification_report(qtype_test, pred_qtype_labels_test))\n",
"print(\"Test Accuracy:\", accuracy_score(qtype_test, pred_qtype_labels_test))\n",
"print(\"Test Precision:\", precision_score(qtype_test, pred_qtype_labels_test, average='weighted'))\n",
"print(\"Test Recall:\", recall_score(qtype_test, pred_qtype_labels_test, average='weighted'))\n",
"\n",
"reference_question_test = [sequence_to_text(question_test[0], tokenizer)]\n",
"candidate_question_test = sequence_to_text(np.argmax(pred_question_test[0], axis=-1), tokenizer)\n",
"bleu_score_question_test = nltk.translate.bleu_score.sentence_bleu(reference_question_test, candidate_question_test)\n",
"print(\"BLEU Score for first test sample (question generation):\", bleu_score_question_test)\n",
"\n",
"reference_answer_test = [sequence_to_text(answer_test[0], tokenizer)]\n",
"candidate_answer_test = sequence_to_text(np.argmax(pred_answer_test[0], axis=-1), tokenizer)\n",
"bleu_score_answer_test = nltk.translate.bleu_score.sentence_bleu(reference_answer_test, candidate_answer_test)\n",
"print(\"BLEU Score for first test sample (answer generation):\", bleu_score_answer_test)\n",
"\n",
"print(\"\\n=== Evaluation on Validation Data ===\")\n",
"pred_question_val, pred_answer_val, pred_qtype_val = model.predict([context_val, question_val])\n",
"pred_qtype_labels_val = np.argmax(pred_qtype_val, axis=1)\n",
"\n",
"print(\"Classification Report for Question Type (Validation Set):\")\n",
"print(classification_report(qtype_val, pred_qtype_labels_val))\n",
"print(\"Validation Accuracy:\", accuracy_score(qtype_val, pred_qtype_labels_val))\n",
"print(\"Validation Precision:\", precision_score(qtype_val, pred_qtype_labels_val, average='weighted'))\n",
"print(\"Validation Recall:\", recall_score(qtype_val, pred_qtype_labels_val, average='weighted'))\n",
"\n",
"reference_question_val = [sequence_to_text(question_val[0], tokenizer)]\n",
"candidate_question_val = sequence_to_text(np.argmax(pred_question_val[0], axis=-1), tokenizer)\n",
"bleu_score_question_val = nltk.translate.bleu_score.sentence_bleu(reference_question_val, candidate_question_val)\n",
"print(\"BLEU Score for first validation sample (question generation):\", bleu_score_question_val)\n",
"\n",
"reference_answer_val = [sequence_to_text(answer_val[0], tokenizer)]\n",
"candidate_answer_val = sequence_to_text(np.argmax(pred_answer_val[0], axis=-1), tokenizer)\n",
"bleu_score_answer_val = nltk.translate.bleu_score.sentence_bleu(reference_answer_val, candidate_answer_val)\n",
"print(\"BLEU Score for first validation sample (answer generation):\", bleu_score_answer_val)\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "myenv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.16"
}
},
"nbformat": 4,
"nbformat_minor": 2
}