TIF_E41211115_lstm-quiz-gen.../old/QC/qg_train.ipynb

704 lines
57 KiB
Plaintext
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

{
"cells": [
{
"cell_type": "code",
"execution_count": 20,
"id": "9bf2159a",
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"import numpy as np\n",
"from pathlib import Path\n",
"from sklearn.model_selection import train_test_split\n",
"from tensorflow.keras.preprocessing.text import Tokenizer\n",
"from tensorflow.keras.preprocessing.sequence import pad_sequences\n",
"from tensorflow.keras.utils import to_categorical\n",
"\n",
"from tensorflow.keras.models import Model\n",
"from tensorflow.keras.layers import (\n",
" Input,\n",
" Embedding,\n",
" LSTM,\n",
" Concatenate,\n",
" Dense,\n",
" TimeDistributed,\n",
")\n",
"from tensorflow.keras.callbacks import EarlyStopping\n",
"from sklearn.metrics import classification_report\n",
"from collections import Counter"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "50118278",
"metadata": {},
"outputs": [],
"source": [
"# # Load raw data\n",
"# with open(\"qg_dataset.json\", encoding=\"utf-8\") as f:\n",
"# raw_data = json.load(f)\n",
"\n",
"# # Validasi lengkap\n",
"# required_keys = {\"tokens\", \"ner\", \"srl\", \"question\", \"answer\", \"type\"}\n",
"# valid_data = []\n",
"# invalid_data = []\n",
"\n",
"# for idx, item in enumerate(raw_data):\n",
"# error_messages = []\n",
"\n",
"# if not isinstance(item, dict):\n",
"# error_messages.append(\"bukan dictionary\")\n",
"\n",
"# missing_keys = required_keys - item.keys()\n",
"# if missing_keys:\n",
"# error_messages.append(f\"missing keys: {missing_keys}\")\n",
"\n",
"# if not error_messages:\n",
"# # Cek tipe data dan None\n",
"# if (not isinstance(item[\"tokens\"], list) or\n",
"# not isinstance(item[\"ner\"], list) or\n",
"# not isinstance(item[\"srl\"], list) or\n",
"# not isinstance(item[\"question\"], list) or\n",
"# not isinstance(item[\"answer\"], list) or\n",
"# not isinstance(item[\"type\"], str)):\n",
"# error_messages.append(\"field type tidak sesuai\")\n",
" \n",
"# if error_messages:\n",
"# print(f\"\\n Index {idx} | Masalah: {', '.join(error_messages)}\")\n",
"# print(json.dumps(item, indent=2, ensure_ascii=False))\n",
"# invalid_data.append(item)\n",
"# continue\n",
"\n",
"# valid_data.append(item)\n",
"\n",
"# # Statistik\n",
"# print(f\"\\n Jumlah data valid: {len(valid_data)} / {len(raw_data)}\")\n",
"# print(f\" Jumlah data tidak valid: {len(invalid_data)}\")\n",
"\n",
"# # Proses data valid\n",
"# tokens = [[t.lower().strip() for t in item[\"tokens\"]] for item in valid_data]\n",
"# ner_tags = [item[\"ner\"] for item in valid_data]\n",
"# srl_tags = [item[\"srl\"] for item in valid_data]\n",
"# questions = [[token.lower().strip() for token in item[\"question\"]] for item in valid_data]\n",
"# answers = [[token.lower().strip() for token in item[\"answer\"]] for item in valid_data]\n",
"# types = [item[\"type\"] for item in valid_data]\n",
"\n",
"# type_counts = Counter(types)\n",
"\n",
"# print(type_counts)\n"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "970867e2",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Jumlah data valid: 396 / 397\n",
"Jumlah data tidak valid: 1\n",
"\n",
"Distribusi Tipe Soal:\n",
"- isian: 390\n",
"- opsi: 4\n",
"- true_false: 2\n"
]
}
],
"source": [
"import json\n",
"from collections import Counter\n",
"\n",
"# Load raw data\n",
"with open(\"../../dataset/dev_dataset_qg.json\", encoding=\"utf-8\") as f:\n",
" raw_data = json.load(f)\n",
"\n",
"# Validasi lengkap\n",
"required_keys = {\"tokens\", \"ner\", \"srl\", \"quiz_possibility\"}\n",
"valid_data = []\n",
"invalid_data = []\n",
"\n",
"for idx, item in enumerate(raw_data):\n",
" error_messages = []\n",
"\n",
" if not isinstance(item, dict):\n",
" error_messages.append(\"bukan dictionary\")\n",
" invalid_data.append(item)\n",
" continue\n",
"\n",
" missing_keys = required_keys - item.keys()\n",
" if missing_keys:\n",
" error_messages.append(f\"missing keys: {missing_keys}\")\n",
"\n",
" if not error_messages:\n",
" # Cek tipe data utama\n",
" if (not isinstance(item[\"tokens\"], list) or\n",
" not isinstance(item[\"ner\"], list) or\n",
" not isinstance(item[\"srl\"], list) or\n",
" not isinstance(item[\"quiz_possibility\"], list)):\n",
" error_messages.append(\"field type tidak sesuai di level utama\")\n",
"\n",
" # Validasi quiz_possibility\n",
" if not error_messages:\n",
" if not item[\"quiz_possibility\"]:\n",
" error_messages.append(\"quiz_possibility kosong\")\n",
" else:\n",
" quiz_item = item[\"quiz_possibility\"][0]\n",
"\n",
" # Validasi kunci di dalam quiz_possibility[0]\n",
" expected_quiz_keys = {\"type\", \"question\", \"answer\"}\n",
" missing_quiz_keys = expected_quiz_keys - quiz_item.keys()\n",
"\n",
" if missing_quiz_keys:\n",
" error_messages.append(f\"missing keys di quiz_possibility[0]: {missing_quiz_keys}\")\n",
" else:\n",
" # Cek tipe data di quiz_possibility[0]\n",
" if (not isinstance(quiz_item[\"type\"], str) or\n",
" not isinstance(quiz_item[\"question\"], list) or\n",
" not isinstance(quiz_item[\"answer\"], list)):\n",
" error_messages.append(\"field type tidak sesuai di quiz_possibility[0]\")\n",
" else:\n",
" # Flatten ke struktur lama untuk konsistensi\n",
" item[\"type\"] = quiz_item[\"type\"]\n",
" item[\"question\"] = quiz_item[\"question\"]\n",
" item[\"answer\"] = quiz_item[\"answer\"]\n",
"\n",
" if error_messages:\n",
" print(f\"\\nIndex {idx} | Masalah: {', '.join(error_messages)}\")\n",
" print(json.dumps(item, indent=2, ensure_ascii=False))\n",
" invalid_data.append(item)\n",
" continue\n",
"\n",
" valid_data.append(item)\n",
"\n",
"# Statistik\n",
"print(f\"\\nJumlah data valid: {len(valid_data)} / {len(raw_data)}\")\n",
"print(f\"Jumlah data tidak valid: {len(invalid_data)}\")\n",
"\n",
"# Proses data valid\n",
"tokens = [[t.lower().strip() for t in item[\"tokens\"]] for item in valid_data]\n",
"ner_tags = [item[\"ner\"] for item in valid_data]\n",
"srl_tags = [item[\"srl\"] for item in valid_data]\n",
"questions = [[token.lower().strip() for token in item[\"question\"]] for item in valid_data]\n",
"answers = [[token.lower().strip() for token in item[\"answer\"]] for item in valid_data]\n",
"types = [item[\"type\"].lower().strip() for item in valid_data] # Konsistensi lowercase untuk tipe\n",
"\n",
"# Statistik tipe soal\n",
"type_counts = Counter(types)\n",
"print(\"\\nDistribusi Tipe Soal:\")\n",
"for t, count in type_counts.items():\n",
" print(f\"- {t}: {count}\")\n",
"\n",
"# (Opsional) Simpan data valid\n",
"with open(\"cleaned_qg_dataset.json\", \"w\", encoding=\"utf-8\") as f:\n",
" json.dump(valid_data, f, ensure_ascii=False, indent=2)\n",
"\n",
"# (Opsional) Simpan data tidak valid untuk analisa\n",
"with open(\"invalid_qg_dataset.json\", \"w\", encoding=\"utf-8\") as f:\n",
" json.dump(invalid_data, f, ensure_ascii=False, indent=2)\n"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "4e3a0088",
"metadata": {},
"outputs": [],
"source": [
"# tokenize\n",
"token_tok = Tokenizer(lower=False, oov_token=\"UNK\")\n",
"token_ner = Tokenizer(lower=False)\n",
"token_srl = Tokenizer(lower=False)\n",
"token_q = Tokenizer(lower=False)\n",
"token_a = Tokenizer(lower=False)\n",
"token_type = Tokenizer(lower=False)\n",
"\n",
"token_tok.fit_on_texts(tokens)\n",
"token_ner.fit_on_texts(ner_tags)\n",
"token_srl.fit_on_texts(srl_tags)\n",
"token_q.fit_on_texts(questions)\n",
"token_a.fit_on_texts(answers)\n",
"token_type.fit_on_texts(types)\n",
"\n",
"\n",
"maxlen = 20"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "555f9e22",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'opsi', 'isian', 'true_false'}\n"
]
}
],
"source": [
"\n",
"X_tok = pad_sequences(\n",
" token_tok.texts_to_sequences(tokens), padding=\"post\", maxlen=maxlen\n",
")\n",
"X_ner = pad_sequences(\n",
" token_ner.texts_to_sequences(ner_tags), padding=\"post\", maxlen=maxlen\n",
")\n",
"X_srl = pad_sequences(\n",
" token_srl.texts_to_sequences(srl_tags), padding=\"post\", maxlen=maxlen\n",
")\n",
"y_q = pad_sequences(token_q.texts_to_sequences(questions), padding=\"post\", maxlen=maxlen)\n",
"y_a = pad_sequences(token_a.texts_to_sequences(answers), padding=\"post\", maxlen=maxlen)\n",
"\n",
"print(set(types))\n",
"\n",
"y_type = [seq[0] for seq in token_type.texts_to_sequences(types)] # list of int\n",
"y_type = to_categorical(np.array(y_type) - 1, num_classes=len(token_type.word_index))\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "f530cfe7",
"metadata": {},
"outputs": [],
"source": [
"X_tok_train, X_tok_test, X_ner_train, X_ner_test, X_srl_train, X_srl_test, \\\n",
"y_q_train, y_q_test, y_a_train, y_a_test, y_type_train, y_type_test = train_test_split(\n",
" X_tok, X_ner, X_srl, y_q, y_a, y_type, test_size=0.2, random_state=42\n",
")\n",
"\n",
"X_train = [X_tok_train, X_ner_train, X_srl_train]\n",
"X_test = [X_tok_test, X_ner_test, X_srl_test]"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "255e2a9a",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Model: \"functional_1\"</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1mModel: \"functional_1\"\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓\n",
"┃<span style=\"font-weight: bold\"> Layer (type) </span>┃<span style=\"font-weight: bold\"> Output Shape </span>┃<span style=\"font-weight: bold\"> Param # </span>┃<span style=\"font-weight: bold\"> Connected to </span>┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩\n",
"│ tok_input │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ - │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">InputLayer</span>) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ ner_input │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ - │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">InputLayer</span>) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ srl_input │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ - │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">InputLayer</span>) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ embedding_3 │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">116,992</span> │ tok_input[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Embedding</span>) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ embedding_4 │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">16</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">704</span> │ ner_input[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Embedding</span>) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ embedding_5 │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">16</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">336</span> │ srl_input[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Embedding</span>) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ concatenate_1 │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">160</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ embedding_3[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>… │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Concatenate</span>) │ │ │ embedding_4[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>… │\n",
"│ │ │ │ embedding_5[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ lstm_1 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">LSTM</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">256</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">427,008</span> │ concatenate_1[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>]… │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ get_item_1 │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">256</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ lstm_1[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">GetItem</span>) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ question_output │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">479</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">123,103</span> │ lstm_1[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">TimeDistributed</span>) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ answer_output │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">308</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">79,156</span> │ lstm_1[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">TimeDistributed</span>) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ type_output (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">4</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">1,028</span> │ get_item_1[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"└─────────────────────┴───────────────────┴────────────┴───────────────────┘\n",
"</pre>\n"
],
"text/plain": [
"┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓\n",
"┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mConnected to \u001b[0m\u001b[1m \u001b[0m┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩\n",
"│ tok_input │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n",
"│ (\u001b[38;5;33mInputLayer\u001b[0m) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ ner_input │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n",
"│ (\u001b[38;5;33mInputLayer\u001b[0m) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ srl_input │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n",
"│ (\u001b[38;5;33mInputLayer\u001b[0m) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ embedding_3 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m116,992\u001b[0m │ tok_input[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"│ (\u001b[38;5;33mEmbedding\u001b[0m) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ embedding_4 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m16\u001b[0m) │ \u001b[38;5;34m704\u001b[0m │ ner_input[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"│ (\u001b[38;5;33mEmbedding\u001b[0m) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ embedding_5 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m16\u001b[0m) │ \u001b[38;5;34m336\u001b[0m │ srl_input[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"│ (\u001b[38;5;33mEmbedding\u001b[0m) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ concatenate_1 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m160\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ embedding_3[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m… │\n",
"│ (\u001b[38;5;33mConcatenate\u001b[0m) │ │ │ embedding_4[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m… │\n",
"│ │ │ │ embedding_5[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ lstm_1 (\u001b[38;5;33mLSTM\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256\u001b[0m) │ \u001b[38;5;34m427,008\u001b[0m │ concatenate_1[\u001b[38;5;34m0\u001b[0m]… │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ get_item_1 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ lstm_1[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"│ (\u001b[38;5;33mGetItem\u001b[0m) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ question_output │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m479\u001b[0m) │ \u001b[38;5;34m123,103\u001b[0m │ lstm_1[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"│ (\u001b[38;5;33mTimeDistributed\u001b[0m) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ answer_output │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m308\u001b[0m) │ \u001b[38;5;34m79,156\u001b[0m │ lstm_1[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"│ (\u001b[38;5;33mTimeDistributed\u001b[0m) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ type_output (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m4\u001b[0m) │ \u001b[38;5;34m1,028\u001b[0m │ get_item_1[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"└─────────────────────┴───────────────────┴────────────┴───────────────────┘\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Total params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">748,327</span> (2.85 MB)\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m Total params: \u001b[0m\u001b[38;5;34m748,327\u001b[0m (2.85 MB)\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">748,327</span> (2.85 MB)\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m748,327\u001b[0m (2.85 MB)\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Non-trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> (0.00 B)\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/30\n",
"\u001b[1m5/5\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 176ms/step - answer_output_accuracy: 0.4544 - answer_output_loss: 5.6455 - loss: 13.1436 - question_output_accuracy: 0.3565 - question_output_loss: 6.1017 - type_output_accuracy: 0.6386 - type_output_loss: 1.3766 - val_answer_output_accuracy: 0.9141 - val_answer_output_loss: 5.0547 - val_loss: 12.0109 - val_question_output_accuracy: 0.6844 - val_question_output_loss: 5.6110 - val_type_output_accuracy: 1.0000 - val_type_output_loss: 1.3453\n",
"Epoch 2/30\n",
"\u001b[1m5/5\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 54ms/step - answer_output_accuracy: 0.9145 - answer_output_loss: 4.3849 - loss: 10.8584 - question_output_accuracy: 0.6760 - question_output_loss: 5.0255 - type_output_accuracy: 0.9758 - type_output_loss: 1.3371 - val_answer_output_accuracy: 0.9141 - val_answer_output_loss: 2.1055 - val_loss: 6.1782 - val_question_output_accuracy: 0.6844 - val_question_output_loss: 2.7704 - val_type_output_accuracy: 1.0000 - val_type_output_loss: 1.3023\n",
"Epoch 3/30\n",
"\u001b[1m5/5\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 55ms/step - answer_output_accuracy: 0.9095 - answer_output_loss: 1.7129 - loss: 5.4664 - question_output_accuracy: 0.6777 - question_output_loss: 2.4346 - type_output_accuracy: 0.9795 - type_output_loss: 1.2889 - val_answer_output_accuracy: 0.9141 - val_answer_output_loss: 1.0023 - val_loss: 4.2358 - val_question_output_accuracy: 0.6844 - val_question_output_loss: 2.0019 - val_type_output_accuracy: 1.0000 - val_type_output_loss: 1.2316\n",
"Epoch 4/30\n",
"\u001b[1m5/5\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 54ms/step - answer_output_accuracy: 0.9140 - answer_output_loss: 0.9210 - loss: 4.2240 - question_output_accuracy: 0.6804 - question_output_loss: 2.1028 - type_output_accuracy: 0.9812 - type_output_loss: 1.2037 - val_answer_output_accuracy: 0.9141 - val_answer_output_loss: 0.7526 - val_loss: 4.0127 - val_question_output_accuracy: 0.6844 - val_question_output_loss: 2.1652 - val_type_output_accuracy: 1.0000 - val_type_output_loss: 1.0949\n",
"Epoch 5/30\n",
"\u001b[1m5/5\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 52ms/step - answer_output_accuracy: 0.9131 - answer_output_loss: 0.7388 - loss: 4.0409 - question_output_accuracy: 0.6753 - question_output_loss: 2.2497 - type_output_accuracy: 0.9832 - type_output_loss: 1.0455 - val_answer_output_accuracy: 0.9141 - val_answer_output_loss: 0.6789 - val_loss: 3.6821 - val_question_output_accuracy: 0.6844 - val_question_output_loss: 2.1028 - val_type_output_accuracy: 1.0000 - val_type_output_loss: 0.9003\n",
"Epoch 6/30\n",
"\u001b[1m5/5\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 52ms/step - answer_output_accuracy: 0.9190 - answer_output_loss: 0.6585 - loss: 3.5809 - question_output_accuracy: 0.6788 - question_output_loss: 2.0865 - type_output_accuracy: 0.9797 - type_output_loss: 0.8341 - val_answer_output_accuracy: 0.9141 - val_answer_output_loss: 0.6491 - val_loss: 3.3418 - val_question_output_accuracy: 0.6844 - val_question_output_loss: 2.0148 - val_type_output_accuracy: 1.0000 - val_type_output_loss: 0.6779\n",
"Epoch 7/30\n",
"\u001b[1m5/5\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 54ms/step - answer_output_accuracy: 0.9165 - answer_output_loss: 0.6312 - loss: 3.2776 - question_output_accuracy: 0.6763 - question_output_loss: 2.0259 - type_output_accuracy: 0.9695 - type_output_loss: 0.6233 - val_answer_output_accuracy: 0.9141 - val_answer_output_loss: 0.6313 - val_loss: 3.1431 - val_question_output_accuracy: 0.6844 - val_question_output_loss: 2.0432 - val_type_output_accuracy: 1.0000 - val_type_output_loss: 0.4687\n",
"Epoch 8/30\n",
"\u001b[1m5/5\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 53ms/step - answer_output_accuracy: 0.9148 - answer_output_loss: 0.6209 - loss: 3.0631 - question_output_accuracy: 0.6762 - question_output_loss: 2.0136 - type_output_accuracy: 0.9708 - type_output_loss: 0.4301 - val_answer_output_accuracy: 0.9141 - val_answer_output_loss: 0.6193 - val_loss: 2.9071 - val_question_output_accuracy: 0.6844 - val_question_output_loss: 1.9849 - val_type_output_accuracy: 1.0000 - val_type_output_loss: 0.3029\n",
"Epoch 9/30\n",
"\u001b[1m5/5\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 54ms/step - answer_output_accuracy: 0.9155 - answer_output_loss: 0.6067 - loss: 2.7923 - question_output_accuracy: 0.6799 - question_output_loss: 1.9057 - type_output_accuracy: 0.9747 - type_output_loss: 0.2789 - val_answer_output_accuracy: 0.9141 - val_answer_output_loss: 0.6109 - val_loss: 2.7805 - val_question_output_accuracy: 0.6844 - val_question_output_loss: 1.9768 - val_type_output_accuracy: 1.0000 - val_type_output_loss: 0.1928\n",
"Epoch 10/30\n",
"\u001b[1m5/5\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 54ms/step - answer_output_accuracy: 0.9160 - answer_output_loss: 0.5715 - loss: 2.6738 - question_output_accuracy: 0.6770 - question_output_loss: 1.9091 - type_output_accuracy: 0.9784 - type_output_loss: 0.1873 - val_answer_output_accuracy: 0.9141 - val_answer_output_loss: 0.6033 - val_loss: 2.6801 - val_question_output_accuracy: 0.6844 - val_question_output_loss: 1.9506 - val_type_output_accuracy: 1.0000 - val_type_output_loss: 0.1262\n",
"Epoch 11/30\n",
"\u001b[1m5/5\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 55ms/step - answer_output_accuracy: 0.9159 - answer_output_loss: 0.5691 - loss: 2.5854 - question_output_accuracy: 0.6791 - question_output_loss: 1.8621 - type_output_accuracy: 0.9743 - type_output_loss: 0.1495 - val_answer_output_accuracy: 0.9141 - val_answer_output_loss: 0.5962 - val_loss: 2.5971 - val_question_output_accuracy: 0.7031 - val_question_output_loss: 1.9119 - val_type_output_accuracy: 1.0000 - val_type_output_loss: 0.0890\n",
"Epoch 12/30\n",
"\u001b[1m5/5\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 54ms/step - answer_output_accuracy: 0.9151 - answer_output_loss: 0.5528 - loss: 2.4857 - question_output_accuracy: 0.6954 - question_output_loss: 1.8064 - type_output_accuracy: 0.9765 - type_output_loss: 0.1240 - val_answer_output_accuracy: 0.9141 - val_answer_output_loss: 0.5907 - val_loss: 2.5231 - val_question_output_accuracy: 0.7031 - val_question_output_loss: 1.8654 - val_type_output_accuracy: 1.0000 - val_type_output_loss: 0.0670\n",
"Epoch 13/30\n",
"\u001b[1m5/5\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 55ms/step - answer_output_accuracy: 0.9116 - answer_output_loss: 0.5741 - loss: 2.4910 - question_output_accuracy: 0.6913 - question_output_loss: 1.7912 - type_output_accuracy: 0.9721 - type_output_loss: 0.1279 - val_answer_output_accuracy: 0.9141 - val_answer_output_loss: 0.5874 - val_loss: 2.4624 - val_question_output_accuracy: 0.7031 - val_question_output_loss: 1.8207 - val_type_output_accuracy: 1.0000 - val_type_output_loss: 0.0543\n",
"Epoch 14/30\n",
"\u001b[1m5/5\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 59ms/step - answer_output_accuracy: 0.9142 - answer_output_loss: 0.5370 - loss: 2.4278 - question_output_accuracy: 0.6900 - question_output_loss: 1.7686 - type_output_accuracy: 0.9730 - type_output_loss: 0.1186 - val_answer_output_accuracy: 0.9141 - val_answer_output_loss: 0.5837 - val_loss: 2.4136 - val_question_output_accuracy: 0.7031 - val_question_output_loss: 1.7833 - val_type_output_accuracy: 1.0000 - val_type_output_loss: 0.0466\n",
"Epoch 15/30\n",
"\u001b[1m5/5\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 55ms/step - answer_output_accuracy: 0.9160 - answer_output_loss: 0.5186 - loss: 2.3183 - question_output_accuracy: 0.6898 - question_output_loss: 1.7028 - type_output_accuracy: 0.9784 - type_output_loss: 0.1001 - val_answer_output_accuracy: 0.9141 - val_answer_output_loss: 0.5794 - val_loss: 2.3714 - val_question_output_accuracy: 0.7109 - val_question_output_loss: 1.7506 - val_type_output_accuracy: 1.0000 - val_type_output_loss: 0.0414\n",
"Epoch 16/30\n",
"\u001b[1m5/5\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 54ms/step - answer_output_accuracy: 0.9171 - answer_output_loss: 0.5077 - loss: 2.2275 - question_output_accuracy: 0.7036 - question_output_loss: 1.6393 - type_output_accuracy: 0.9791 - type_output_loss: 0.0876 - val_answer_output_accuracy: 0.9141 - val_answer_output_loss: 0.5748 - val_loss: 2.3340 - val_question_output_accuracy: 0.7172 - val_question_output_loss: 1.7214 - val_type_output_accuracy: 1.0000 - val_type_output_loss: 0.0379\n",
"Epoch 17/30\n",
"\u001b[1m5/5\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 55ms/step - answer_output_accuracy: 0.9137 - answer_output_loss: 0.5248 - loss: 2.2290 - question_output_accuracy: 0.7070 - question_output_loss: 1.6285 - type_output_accuracy: 0.9828 - type_output_loss: 0.0771 - val_answer_output_accuracy: 0.9219 - val_answer_output_loss: 0.5716 - val_loss: 2.3017 - val_question_output_accuracy: 0.7172 - val_question_output_loss: 1.6946 - val_type_output_accuracy: 1.0000 - val_type_output_loss: 0.0355\n",
"Epoch 18/30\n",
"\u001b[1m5/5\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 54ms/step - answer_output_accuracy: 0.9233 - answer_output_loss: 0.5080 - loss: 2.2392 - question_output_accuracy: 0.7059 - question_output_loss: 1.6139 - type_output_accuracy: 0.9678 - type_output_loss: 0.1205 - val_answer_output_accuracy: 0.9219 - val_answer_output_loss: 0.5676 - val_loss: 2.2777 - val_question_output_accuracy: 0.7219 - val_question_output_loss: 1.6760 - val_type_output_accuracy: 1.0000 - val_type_output_loss: 0.0341\n",
"Epoch 19/30\n",
"\u001b[1m5/5\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 53ms/step - answer_output_accuracy: 0.9221 - answer_output_loss: 0.5038 - loss: 2.1188 - question_output_accuracy: 0.7131 - question_output_loss: 1.5706 - type_output_accuracy: 0.9854 - type_output_loss: 0.0616 - val_answer_output_accuracy: 0.9219 - val_answer_output_loss: 0.5639 - val_loss: 2.2545 - val_question_output_accuracy: 0.7203 - val_question_output_loss: 1.6580 - val_type_output_accuracy: 1.0000 - val_type_output_loss: 0.0326\n",
"Epoch 20/30\n",
"\u001b[1m5/5\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 54ms/step - answer_output_accuracy: 0.9175 - answer_output_loss: 0.5233 - loss: 2.1645 - question_output_accuracy: 0.7128 - question_output_loss: 1.5526 - type_output_accuracy: 0.9775 - type_output_loss: 0.0858 - val_answer_output_accuracy: 0.9219 - val_answer_output_loss: 0.5603 - val_loss: 2.2376 - val_question_output_accuracy: 0.7234 - val_question_output_loss: 1.6450 - val_type_output_accuracy: 1.0000 - val_type_output_loss: 0.0323\n",
"Epoch 21/30\n",
"\u001b[1m5/5\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 54ms/step - answer_output_accuracy: 0.9193 - answer_output_loss: 0.5090 - loss: 2.1288 - question_output_accuracy: 0.7118 - question_output_loss: 1.5447 - type_output_accuracy: 0.9828 - type_output_loss: 0.0644 - val_answer_output_accuracy: 0.9219 - val_answer_output_loss: 0.5568 - val_loss: 2.2206 - val_question_output_accuracy: 0.7219 - val_question_output_loss: 1.6317 - val_type_output_accuracy: 1.0000 - val_type_output_loss: 0.0321\n",
"Epoch 22/30\n",
"\u001b[1m5/5\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 54ms/step - answer_output_accuracy: 0.9204 - answer_output_loss: 0.4971 - loss: 2.0726 - question_output_accuracy: 0.7128 - question_output_loss: 1.5100 - type_output_accuracy: 0.9817 - type_output_loss: 0.0626 - val_answer_output_accuracy: 0.9219 - val_answer_output_loss: 0.5535 - val_loss: 2.2055 - val_question_output_accuracy: 0.7359 - val_question_output_loss: 1.6200 - val_type_output_accuracy: 1.0000 - val_type_output_loss: 0.0320\n",
"Epoch 23/30\n",
"\u001b[1m5/5\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 52ms/step - answer_output_accuracy: 0.9191 - answer_output_loss: 0.5003 - loss: 2.1218 - question_output_accuracy: 0.7108 - question_output_loss: 1.5310 - type_output_accuracy: 0.9762 - type_output_loss: 0.0771 - val_answer_output_accuracy: 0.9219 - val_answer_output_loss: 0.5517 - val_loss: 2.1920 - val_question_output_accuracy: 0.7234 - val_question_output_loss: 1.6081 - val_type_output_accuracy: 1.0000 - val_type_output_loss: 0.0322\n",
"Epoch 24/30\n",
"\u001b[1m5/5\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 53ms/step - answer_output_accuracy: 0.9220 - answer_output_loss: 0.4808 - loss: 2.0044 - question_output_accuracy: 0.7175 - question_output_loss: 1.4722 - type_output_accuracy: 0.9810 - type_output_loss: 0.0608 - val_answer_output_accuracy: 0.9219 - val_answer_output_loss: 0.5494 - val_loss: 2.1723 - val_question_output_accuracy: 0.7312 - val_question_output_loss: 1.5905 - val_type_output_accuracy: 1.0000 - val_type_output_loss: 0.0323\n",
"Epoch 25/30\n",
"\u001b[1m5/5\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 58ms/step - answer_output_accuracy: 0.9183 - answer_output_loss: 0.4965 - loss: 2.0500 - question_output_accuracy: 0.7174 - question_output_loss: 1.4835 - type_output_accuracy: 0.9775 - type_output_loss: 0.0676 - val_answer_output_accuracy: 0.9219 - val_answer_output_loss: 0.5473 - val_loss: 2.1609 - val_question_output_accuracy: 0.7328 - val_question_output_loss: 1.5810 - val_type_output_accuracy: 1.0000 - val_type_output_loss: 0.0326\n",
"Epoch 26/30\n",
"\u001b[1m5/5\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 52ms/step - answer_output_accuracy: 0.9236 - answer_output_loss: 0.4672 - loss: 1.9620 - question_output_accuracy: 0.7220 - question_output_loss: 1.4313 - type_output_accuracy: 0.9780 - type_output_loss: 0.0672 - val_answer_output_accuracy: 0.9219 - val_answer_output_loss: 0.5454 - val_loss: 2.1488 - val_question_output_accuracy: 0.7344 - val_question_output_loss: 1.5705 - val_type_output_accuracy: 1.0000 - val_type_output_loss: 0.0328\n",
"Epoch 27/30\n",
"\u001b[1m5/5\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 70ms/step - answer_output_accuracy: 0.9219 - answer_output_loss: 0.4671 - loss: 1.9415 - question_output_accuracy: 0.7288 - question_output_loss: 1.4130 - type_output_accuracy: 0.9765 - type_output_loss: 0.0605 - val_answer_output_accuracy: 0.9219 - val_answer_output_loss: 0.5440 - val_loss: 2.1382 - val_question_output_accuracy: 0.7359 - val_question_output_loss: 1.5615 - val_type_output_accuracy: 1.0000 - val_type_output_loss: 0.0327\n",
"Epoch 28/30\n",
"\u001b[1m5/5\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 52ms/step - answer_output_accuracy: 0.9212 - answer_output_loss: 0.4676 - loss: 1.9277 - question_output_accuracy: 0.7271 - question_output_loss: 1.4106 - type_output_accuracy: 0.9823 - type_output_loss: 0.0526 - val_answer_output_accuracy: 0.9219 - val_answer_output_loss: 0.5435 - val_loss: 2.1317 - val_question_output_accuracy: 0.7422 - val_question_output_loss: 1.5559 - val_type_output_accuracy: 1.0000 - val_type_output_loss: 0.0323\n",
"Epoch 29/30\n",
"\u001b[1m5/5\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 57ms/step - answer_output_accuracy: 0.9228 - answer_output_loss: 0.4658 - loss: 1.8773 - question_output_accuracy: 0.7397 - question_output_loss: 1.3683 - type_output_accuracy: 0.9823 - type_output_loss: 0.0487 - val_answer_output_accuracy: 0.9219 - val_answer_output_loss: 0.5428 - val_loss: 2.1239 - val_question_output_accuracy: 0.7437 - val_question_output_loss: 1.5493 - val_type_output_accuracy: 1.0000 - val_type_output_loss: 0.0319\n",
"Epoch 30/30\n",
"\u001b[1m5/5\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 56ms/step - answer_output_accuracy: 0.9207 - answer_output_loss: 0.4658 - loss: 1.9146 - question_output_accuracy: 0.7355 - question_output_loss: 1.3799 - type_output_accuracy: 0.9795 - type_output_loss: 0.0563 - val_answer_output_accuracy: 0.9219 - val_answer_output_loss: 0.5421 - val_loss: 2.1174 - val_question_output_accuracy: 0.7437 - val_question_output_loss: 1.5436 - val_type_output_accuracy: 1.0000 - val_type_output_loss: 0.0317\n"
]
}
],
"source": [
"\n",
"inp_tok = Input(shape=(None,), name=\"tok_input\")\n",
"inp_ner = Input(shape=(None,), name=\"ner_input\")\n",
"inp_srl = Input(shape=(None,), name=\"srl_input\")\n",
"\n",
"emb_tok = Embedding(input_dim=len(token_tok.word_index) + 1, output_dim=128)(inp_tok)\n",
"emb_ner = Embedding(input_dim=len(token_ner.word_index) + 1, output_dim=16)(inp_ner)\n",
"emb_srl = Embedding(input_dim=len(token_srl.word_index) + 1, output_dim=16)(inp_srl)\n",
"\n",
"# emb_tok = Embedding(input_dim=..., output_dim=..., mask_zero=True)(inp_tok)\n",
"# emb_ner = Embedding(input_dim=..., output_dim=..., mask_zero=True)(inp_ner)\n",
"# emb_srl = Embedding(input_dim=..., output_dim=..., mask_zero=True)(inp_srl)\n",
"\n",
"merged = Concatenate()([emb_tok, emb_ner, emb_srl])\n",
"\n",
"x = LSTM(256, return_sequences=True)(merged)\n",
"\n",
"out_question = TimeDistributed(Dense(len(token_q.word_index) + 1, activation=\"softmax\"), name=\"question_output\")(x)\n",
"out_answer = TimeDistributed(Dense(len(token_a.word_index) + 1, activation=\"softmax\"), name=\"answer_output\")(x)\n",
"out_type = Dense(len(token_type.word_index), activation=\"softmax\", name=\"type_output\")(\n",
" x[:, 0, :]\n",
") # gunakan step pertama\n",
"\n",
"model = Model(\n",
" inputs=[inp_tok, inp_ner, inp_srl], outputs=[out_question, out_answer, out_type]\n",
")\n",
"model.compile(\n",
" optimizer=\"adam\",\n",
" loss={\n",
" \"question_output\": \"sparse_categorical_crossentropy\",\n",
" \"answer_output\": \"sparse_categorical_crossentropy\",\n",
" \"type_output\": \"categorical_crossentropy\",\n",
" },\n",
" metrics={\n",
" \"question_output\": \"accuracy\",\n",
" \"answer_output\": \"accuracy\",\n",
" \"type_output\": \"accuracy\",\n",
" },\n",
")\n",
"\n",
"model.summary()\n",
"\n",
"# ----------------------------------------------------------------------------\n",
"# 5. TRAINING\n",
"# ----------------------------------------------------------------------------\n",
"model.fit(\n",
" X_train,\n",
" {\n",
" \"question_output\": np.expand_dims(y_q_train, -1),\n",
" \"answer_output\": np.expand_dims(y_a_train, -1),\n",
" \"type_output\": y_type_train,\n",
" },\n",
" batch_size=64,\n",
" epochs=30,\n",
" validation_split=0.1,\n",
" callbacks=[EarlyStopping(patience=3, restore_best_weights=True)],\n",
")\n",
"\n",
"import pickle\n",
"\n",
"\n",
"model.save(\"new_model_lstm_qg.keras\")\n",
"with open(\"tokenizers.pkl\", \"wb\") as f:\n",
" pickle.dump({\n",
" \"token\": token_tok,\n",
" \"ner\": token_ner,\n",
" \"srl\": token_srl,\n",
" \"question\": token_q,\n",
" \"answer\": token_a,\n",
" \"type\": token_type\n",
" }, f)\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "06fd86c7",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[1m3/3\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 137ms/step\n",
"\n",
"=== Akurasi Detail ===\n",
"Question Accuracy (Token-level): 0.1519\n",
"Answer Accuracy (Token-level) : 0.0638\n",
"Type Accuracy (Class-level) : 1.00\n"
]
}
],
"source": [
"\n",
"def token_level_accuracy(y_true, y_pred):\n",
" correct = 0\n",
" total = 0\n",
" for true_seq, pred_seq in zip(y_true, y_pred):\n",
" for t, p in zip(true_seq, pred_seq):\n",
" if t != 0: # ignore padding\n",
" total += 1\n",
" if t == p:\n",
" correct += 1\n",
" return correct / total if total > 0 else 0\n",
"\n",
"\n",
"# Predict on test set\n",
"y_pred_q, y_pred_a, y_pred_type = model.predict(X_test)\n",
"\n",
"# Decode predictions to class indices\n",
"y_pred_q = np.argmax(y_pred_q, axis=-1)\n",
"y_pred_a = np.argmax(y_pred_a, axis=-1)\n",
"y_pred_type = np.argmax(y_pred_type, axis=-1)\n",
"y_true_type = np.argmax(y_type_test, axis=-1)\n",
"\n",
"# Calculate token-level accuracy\n",
"acc_q = token_level_accuracy(y_q_test, y_pred_q)\n",
"acc_a = token_level_accuracy(y_a_test, y_pred_a)\n",
"\n",
"# Type classification report\n",
"report_type = classification_report(y_true_type, y_pred_type, zero_division=0)\n",
"\n",
"# Print Results\n",
"print(\"\\n=== Akurasi Detail ===\")\n",
"print(f\"Question Accuracy (Token-level): {acc_q:.4f}\")\n",
"print(f\"Answer Accuracy (Token-level) : {acc_a:.4f}\")\n",
"print(f\"Type Accuracy (Class-level) : {np.mean(y_true_type == y_pred_type):.2f}\")"
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "b17b6470",
"metadata": {},
"outputs": [],
"source": [
"# import sacrebleu\n",
"# from sacrebleu.metrics import BLEU # optional kalau mau smoothing/effective_order\n",
"\n",
"# idx2tok = {v:k for k,v in word2idx.items()}\n",
"# PAD_ID = word2idx[\"PAD\"]\n",
"# SOS_ID = word2idx.get(\"SOS\", None)\n",
"# EOS_ID = word2idx.get(\"EOS\", None)\n",
"\n",
"# def seq2str(seq):\n",
"# \"\"\"Konversi list index -> kalimat string, sambil buang token spesial.\"\"\"\n",
"# toks = [idx2tok[i] for i in seq\n",
"# if i not in {PAD_ID, SOS_ID, EOS_ID}]\n",
"# return \" \".join(toks).strip().lower()\n",
"\n",
"# bleu_metric = BLEU(effective_order=True) # lebih stabil utk kalimat pendek\n",
"\n",
"# def bleu_corpus(pred_seqs, true_seqs):\n",
"# preds = [seq2str(p) for p in pred_seqs]\n",
"# refs = [[seq2str(t)] for t in true_seqs] # listoflist, satu ref/kalimat\n",
"# return bleu_metric.corpus_score(preds, refs).score\n"
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "d5ed106c",
"metadata": {},
"outputs": [],
"source": [
"\n",
"# flat_true_a, flat_pred_a = flatten_valid(y_a_test, y_pred_a_class)\n",
"# print(\"\\n=== Classification Report: ANSWER ===\")\n",
"# print(classification_report(flat_true_a, flat_pred_a))\n"
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "aa3860de",
"metadata": {},
"outputs": [],
"source": [
"\n",
"# print(\"\\n=== Classification Report: TYPE ===\")\n",
"# print(classification_report(y_true_type_class, y_pred_type_class))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "myenv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}