TIF_E41211115_lstm-quiz-gen.../question_generation/qg_lstm.ipynb

867 lines
76 KiB
Plaintext
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "fb283f23",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Total flattened samples: 596\n"
]
}
],
"source": [
"import json\n",
"from pathlib import Path\n",
"from itertools import chain\n",
"\n",
"# Load the raw question-generation dataset (list of annotated sentences).\n",
"RAW = json.loads(Path(\"../dataset/dev_dataset_qg.json\").read_text())\n",
"\n",
"# Flatten: one training sample per (sentence, quiz possibility) pair.\n",
"samples = []\n",
"for item in RAW:\n",
"    for quiz in item[\"quiz_posibility\"]:\n",
"        answer = quiz[\"answer\"]\n",
"        # Answers may be a list of tokens or a single string.\n",
"        if isinstance(answer, list):\n",
"            answer_toks = [tok.lower() for tok in answer] + [\"<eos>\"]\n",
"        else:\n",
"            answer_toks = [answer.lower(), \"<eos>\"]\n",
"        samples.append(\n",
"            {\n",
"                \"tokens\": [tok.lower() for tok in item[\"tokens\"]],\n",
"                \"ner\": item[\"ner\"],\n",
"                \"srl\": item[\"srl\"],\n",
"                \"q_type\": quiz[\"type\"],  # isian / opsi / benar_salah\n",
"                # Question tokens, lowercased and terminated with <eos>.\n",
"                \"q_toks\": [tok.lower() for tok in quiz[\"question\"]] + [\"<eos>\"],\n",
"                \"a_toks\": answer_toks,\n",
"            }\n",
"        )\n",
"\n",
"print(\"Total flattened samples:\", len(samples))"
]
},
{
"cell_type": "code",
"execution_count": 102,
"id": "fa4f979d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'<pad>': 0, '<unk>': 1, '<sos>': 2, '<eos>': 3, 'indra': 4, 'rahmawati': 5, 'lahir': 6, 'di': 7, 'malang': 8, 'pada': 9, '27': 10, 'februari': 11, '1972': 12, '___': 13, 'dimana': 14, 'fitri': 15, 'pratama': 16, 'joko': 17, 'wijaya': 18, 'tangerang': 19, '26': 20, 'september': 21, '1970': 22, 'citra': 23, 'jakarta': 24, '17': 25, '1983': 26, 'saputra': 27, 'kota': 28, 'mana': 29, 'nina': 30, 'ridwan': 31, 'gilang': 32, 'purnama': 33, 'depok': 34, '21': 35, 'desember': 36, '1971': 37, 'budi': 38, 'eka': 39, 'laksmi': 40, 'surabaya': 41, '24': 42, 'oktober': 43, '1996': 44, 'sari': 45, '20': 46, 'mei': 47, '1986': 48, 'maulana': 49, 'hana': 50, 'makassar': 51, '16': 52, 'agustus': 53, '1993': 54, 'restu': 55, 'hasan': 56, 'lina': 57, 'siregar': 58, '22': 59, 'april': 60, '1987': 61, 'palembang': 62, '7': 63, 'januari': 64, '1997': 65, '4': 66, 'november': 67, '2009': 68, 'dina': 69, '1978': 70, 'nanda': 71, 'putra': 72, 'tanggal': 73, 'berapa': 74, 'kusuma': 75, '13': 76, '28': 77, '1984': 78, 'medan': 79, '3': 80, 'maret': 81, '1995': 82, '6': 83, '2000': 84, 'anggraini': 85, '10': 86, '2006': 87, 'semarang': 88, 'juni': 89, '1976': 90, 'mega': 91, 'dwiputra': 92, '1992': 93, 'yuni': 94, 'baskara': 95, 'utami': 96, 'diaz': 97, '14': 98, '2008': 99, 'vina': 100, 'lestari': 101, 'mariani': 102, 'islami': 103, '23': 104, 'juli': 105, 'permata': 106, '2007': 107, '1994': 108, '1': 109, '1981': 110, 'toni': 111, 'hakim': 112, 'hendra': 113, 'dwi': 114, 'santoso': 115, 'andi': 116, '25': 117, '1998': 118, '19': 119, 'bayu': 120, 'suryanto': 121, 'kartika': 122, '2010': 123, 'joni': 124, '1985': 125, 'qori': 126, '2003': 127, '1999': 128, 'bandung': 129, 'adi': 130, 'angga': 131, '18': 132, '2004': 133, 'putri': 134, 'laksono': 135, 'hani': 136, '12': 137, '15': 138, '11': 139, '1980': 140, '2002': 141, 'trisna': 142, '1973': 143, 'oki': 144, 'hidayat': 145, '1982': 146, 'suryadi': 147, 'galih': 148, '2': 149, '1977': 150, 'agus': 151, 'sudarto': 152, 'endah': 153, 
'1991': 154, 'yuliani': 155, 'seno': 156, 'nandito': 157, 'syah': 158, '5': 159, 'mira': 160, '9': 161, 'indah': 162, 'widodo': 163, '1974': 164, 'rio': 165, '1979': 166, '1988': 167, 'yana': 168, 'elisa': 169, 'irma': 170, 'irawan': 171, 'fajar': 172, '1990': 173, 'mustofa': 174, 'heri': 175, 'fina': 176, 'cahya': 177, 'ramadhan': 178, 'setiawan': 179, 'prasetyo': 180, 'ariani': 181, 'kurniawan': 182, 'farhan': 183, 'dwisaputra': 184, 'hardian': 185, '1975': 186, '8': 187, '2001': 188, 'iqbal': 189, 'kurnia': 190, 'cahyono': 191, '1989': 192, 'kiki': 193, 'suwandi': 194, 'gita': 195, '2005': 196, 'rahman': 197, 'danang': 198, 'maharani': 199, 'lisa': 200, 'firmansyah': 201, 'wulan': 202, 'zaki': 203, 'wicak': 204, 'sita': 205, 'ahmad': 206, 'wahyuni': 207, 'saputri': 208, 'tria': 209, 'fadhil': 210}\n"
]
}
],
"source": [
"def build_vocab(seq_iter, reserved=(\"<pad>\", \"<unk>\", \"<sos>\", \"<eos>\")):\n",
"    \"\"\"Build a token -> integer-id vocabulary from an iterable of sequences.\n",
"\n",
"    Args:\n",
"        seq_iter: iterable of token sequences (lists of strings).\n",
"        reserved: special tokens placed first, so e.g. <pad> is always id 0.\n",
"            (A tuple, not a list, to avoid a mutable default argument.)\n",
"\n",
"    Returns:\n",
"        dict mapping each distinct token to a unique integer id.\n",
"    \"\"\"\n",
"    vocab = {tok: idx for idx, tok in enumerate(reserved)}\n",
"    for tok in chain.from_iterable(seq_iter):\n",
"        if tok not in vocab:\n",
"            vocab[tok] = len(vocab)\n",
"    return vocab\n",
"\n",
"\n",
"vocab_tok = build_vocab((s[\"tokens\"] for s in samples))\n",
"vocab_ner = build_vocab((s[\"ner\"] for s in samples), reserved=(\"<pad>\", \"<unk>\"))\n",
"vocab_srl = build_vocab((s[\"srl\"] for s in samples), reserved=(\"<pad>\", \"<unk>\"))\n",
"vocab_q = build_vocab((s[\"q_toks\"] for s in samples))\n",
"vocab_a = build_vocab((s[\"a_toks\"] for s in samples))\n",
"\n",
"# Question-type labels. The dataset uses the spelling \"benar_salah\"\n",
"# (see the flattening cell: isian / opsi / benar_salah), but this mapping\n",
"# only contained \"true_false\", which makes vocab_typ[s[\"q_type\"]] raise\n",
"# KeyError for true/false questions. Accept both spellings for class 2.\n",
"vocab_typ = {\"isian\": 0, \"opsi\": 1, \"true_false\": 2, \"benar_salah\": 2}\n",
"\n",
"print(vocab_q)"
]
},
{
"cell_type": "code",
"execution_count": 103,
"id": "d1a5b324",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from tensorflow.keras.preprocessing.sequence import pad_sequences\n",
"\n",
"\n",
"def encode(seq, vmap):  # token -> id\n",
"    \"\"\"Map each token in seq to its id, falling back to <unk> for OOV tokens.\"\"\"\n",
"    return [vmap.get(t, vmap[\"<unk>\"]) for t in seq]\n",
"\n",
"\n",
"# Padded sequence widths for the encoder and the two decoders.\n",
"MAX_SENT = max(len(s[\"tokens\"]) for s in samples)\n",
"MAX_Q = max(len(s[\"q_toks\"]) for s in samples)\n",
"MAX_A = max(len(s[\"a_toks\"]) for s in samples)\n",
"\n",
"# Encoder inputs: token / NER / SRL id sequences, zero-padded at the end.\n",
"X_tok = pad_sequences(\n",
"    [encode(s[\"tokens\"], vocab_tok) for s in samples], maxlen=MAX_SENT, padding=\"post\"\n",
")\n",
"X_ner = pad_sequences(\n",
"    [encode(s[\"ner\"], vocab_ner) for s in samples], maxlen=MAX_SENT, padding=\"post\"\n",
")\n",
"X_srl = pad_sequences(\n",
"    [encode(s[\"srl\"], vocab_srl) for s in samples], maxlen=MAX_SENT, padding=\"post\"\n",
")\n",
"\n",
"# Teacher forcing: decoder input = <sos> + target[:-1]; decoder target = target.\n",
"dec_q_in = pad_sequences(\n",
"    [[vocab_q[\"<sos>\"], *encode(s[\"q_toks\"][:-1], vocab_q)] for s in samples],\n",
"    maxlen=MAX_Q,\n",
"    padding=\"post\",\n",
")\n",
"dec_q_out = pad_sequences(\n",
"    [encode(s[\"q_toks\"], vocab_q) for s in samples], maxlen=MAX_Q, padding=\"post\"\n",
")\n",
"\n",
"dec_a_in = pad_sequences(\n",
"    [[vocab_a[\"<sos>\"], *encode(s[\"a_toks\"][:-1], vocab_a)] for s in samples],\n",
"    maxlen=MAX_A,\n",
"    padding=\"post\",\n",
")\n",
"dec_a_out = pad_sequences(\n",
"    [encode(s[\"a_toks\"], vocab_a) for s in samples], maxlen=MAX_A, padding=\"post\"\n",
")\n",
"\n",
"# Integer class label for the question-type classifier head.\n",
"y_type = np.array([vocab_typ[s[\"q_type\"]] for s in samples])"
]
},
{
"cell_type": "code",
"execution_count": 104,
"id": "ff5bd85f",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Model: \"functional_12\"</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1mModel: \"functional_12\"\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓\n",
"┃<span style=\"font-weight: bold\"> Layer (type) </span>┃<span style=\"font-weight: bold\"> Output Shape </span>┃<span style=\"font-weight: bold\"> Param # </span>┃<span style=\"font-weight: bold\"> Connected to </span>┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩\n",
"│ tok_in (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">InputLayer</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">9</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ - │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ ner_in (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">InputLayer</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">9</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ - │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ srl_in (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">InputLayer</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">9</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ - │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ embedding_tok │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">9</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">32</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">7,808</span> │ tok_in[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Embedding</span>) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ embedding_ner │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">9</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">16</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">144</span> │ ner_in[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Embedding</span>) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ embedding_srl │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">9</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">16</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">112</span> │ srl_in[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Embedding</span>) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ dec_q_in │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">11</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ - │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">InputLayer</span>) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ concatenate_14 │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">9</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ embedding_tok[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>]… │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Concatenate</span>) │ │ │ embedding_ner[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>]… │\n",
"│ │ │ │ embedding_srl[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>]… │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ dec_a_in │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">4</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ - │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">InputLayer</span>) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ embedding_q_decoder │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">11</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">32</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">6,752</span> │ dec_q_in[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Embedding</span>) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ encoder_lstm (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">LSTM</span>) │ [(<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>), │ <span style=\"color: #00af00; text-decoration-color: #00af00\">33,024</span> │ concatenate_14[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>… │\n",
"│ │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>), │ │ │\n",
"│ │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>)] │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ embedding_a_decoder │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">4</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">32</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">3,584</span> │ dec_a_in[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Embedding</span>) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ lstm_q_decoder │ [(<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">11</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>), │ <span style=\"color: #00af00; text-decoration-color: #00af00\">24,832</span> │ embedding_q_deco… │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">LSTM</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>), │ │ encoder_lstm[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">…</span> │\n",
"│ │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>)] │ │ encoder_lstm[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">…</span> │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ not_equal_51 │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">11</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ dec_q_in[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">NotEqual</span>) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ lstm_a_decoder │ [(<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">4</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>), │ <span style=\"color: #00af00; text-decoration-color: #00af00\">24,832</span> │ embedding_a_deco… │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">LSTM</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>), │ │ encoder_lstm[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">…</span> │\n",
"│ │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>)] │ │ encoder_lstm[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">…</span> │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ not_equal_52 │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">4</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ dec_a_in[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">NotEqual</span>) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ q_output │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">11</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">211</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">13,715</span> │ lstm_q_decoder[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>… │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">TimeDistributed</span>) │ │ │ not_equal_51[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">…</span> │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ a_output │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">4</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">112</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">7,280</span> │ lstm_a_decoder[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>… │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">TimeDistributed</span>) │ │ │ not_equal_52[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">…</span> │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ type_output (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">3</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">195</span> │ encoder_lstm[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">…</span> │\n",
"└─────────────────────┴───────────────────┴────────────┴───────────────────┘\n",
"</pre>\n"
],
"text/plain": [
"┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓\n",
"┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mConnected to \u001b[0m\u001b[1m \u001b[0m┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩\n",
"│ tok_in (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m9\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ ner_in (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m9\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ srl_in (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m9\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ embedding_tok │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m9\u001b[0m, \u001b[38;5;34m32\u001b[0m) │ \u001b[38;5;34m7,808\u001b[0m │ tok_in[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"│ (\u001b[38;5;33mEmbedding\u001b[0m) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ embedding_ner │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m9\u001b[0m, \u001b[38;5;34m16\u001b[0m) │ \u001b[38;5;34m144\u001b[0m │ ner_in[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"│ (\u001b[38;5;33mEmbedding\u001b[0m) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ embedding_srl │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m9\u001b[0m, \u001b[38;5;34m16\u001b[0m) │ \u001b[38;5;34m112\u001b[0m │ srl_in[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"│ (\u001b[38;5;33mEmbedding\u001b[0m) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ dec_q_in │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m11\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n",
"│ (\u001b[38;5;33mInputLayer\u001b[0m) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ concatenate_14 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m9\u001b[0m, \u001b[38;5;34m64\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ embedding_tok[\u001b[38;5;34m0\u001b[0m]… │\n",
"│ (\u001b[38;5;33mConcatenate\u001b[0m) │ │ │ embedding_ner[\u001b[38;5;34m0\u001b[0m]… │\n",
"│ │ │ │ embedding_srl[\u001b[38;5;34m0\u001b[0m]… │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ dec_a_in │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m4\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n",
"│ (\u001b[38;5;33mInputLayer\u001b[0m) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ embedding_q_decoder │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m11\u001b[0m, \u001b[38;5;34m32\u001b[0m) │ \u001b[38;5;34m6,752\u001b[0m │ dec_q_in[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"│ (\u001b[38;5;33mEmbedding\u001b[0m) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ encoder_lstm (\u001b[38;5;33mLSTM\u001b[0m) │ [(\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m64\u001b[0m), │ \u001b[38;5;34m33,024\u001b[0m │ concatenate_14[\u001b[38;5;34m0\u001b[0m… │\n",
"│ │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m64\u001b[0m), │ │ │\n",
"│ │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m64\u001b[0m)] │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ embedding_a_decoder │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m4\u001b[0m, \u001b[38;5;34m32\u001b[0m) │ \u001b[38;5;34m3,584\u001b[0m │ dec_a_in[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"│ (\u001b[38;5;33mEmbedding\u001b[0m) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ lstm_q_decoder │ [(\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m11\u001b[0m, \u001b[38;5;34m64\u001b[0m), │ \u001b[38;5;34m24,832\u001b[0m │ embedding_q_deco… │\n",
"│ (\u001b[38;5;33mLSTM\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m64\u001b[0m), │ │ encoder_lstm[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m…\u001b[0m │\n",
"│ │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m64\u001b[0m)] │ │ encoder_lstm[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m…\u001b[0m │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ not_equal_51 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m11\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ dec_q_in[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"│ (\u001b[38;5;33mNotEqual\u001b[0m) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ lstm_a_decoder │ [(\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m4\u001b[0m, \u001b[38;5;34m64\u001b[0m), │ \u001b[38;5;34m24,832\u001b[0m │ embedding_a_deco… │\n",
"│ (\u001b[38;5;33mLSTM\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m64\u001b[0m), │ │ encoder_lstm[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m…\u001b[0m │\n",
"│ │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m64\u001b[0m)] │ │ encoder_lstm[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m…\u001b[0m │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ not_equal_52 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m4\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ dec_a_in[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"│ (\u001b[38;5;33mNotEqual\u001b[0m) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ q_output │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m11\u001b[0m, \u001b[38;5;34m211\u001b[0m) │ \u001b[38;5;34m13,715\u001b[0m │ lstm_q_decoder[\u001b[38;5;34m0\u001b[0m… │\n",
"│ (\u001b[38;5;33mTimeDistributed\u001b[0m) │ │ │ not_equal_51[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m…\u001b[0m │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ a_output │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m4\u001b[0m, \u001b[38;5;34m112\u001b[0m) │ \u001b[38;5;34m7,280\u001b[0m │ lstm_a_decoder[\u001b[38;5;34m0\u001b[0m… │\n",
"│ (\u001b[38;5;33mTimeDistributed\u001b[0m) │ │ │ not_equal_52[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m…\u001b[0m │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ type_output (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m3\u001b[0m) │ \u001b[38;5;34m195\u001b[0m │ encoder_lstm[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m…\u001b[0m │\n",
"└─────────────────────┴───────────────────┴────────────┴───────────────────┘\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Total params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">122,278</span> (477.65 KB)\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m Total params: \u001b[0m\u001b[38;5;34m122,278\u001b[0m (477.65 KB)\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">122,278</span> (477.65 KB)\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m122,278\u001b[0m (477.65 KB)\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Non-trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> (0.00 B)\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import tensorflow as tf\n",
"from tensorflow.keras.layers import (\n",
"    Input,\n",
"    Embedding,\n",
"    LSTM,\n",
"    Concatenate,\n",
"    Dense,\n",
"    TimeDistributed,\n",
")\n",
"from tensorflow.keras.models import Model\n",
"\n",
"# ---- constants ---------------------------------------------------\n",
"d_tok = 32  # token embedding dim\n",
"d_tag = 16  # NER / SRL embedding dim\n",
"units = 64  # LSTM hidden size (shared by encoder and both decoders)\n",
"\n",
"# ---- encoder -----------------------------------------------------\n",
"# Three parallel id sequences (same length MAX_SENT) describe each sentence.\n",
"inp_tok = Input((MAX_SENT,), name=\"tok_in\")\n",
"inp_ner = Input((MAX_SENT,), name=\"ner_in\")\n",
"inp_srl = Input((MAX_SENT,), name=\"srl_in\")\n",
"\n",
"# make ALL streams mask the same way (here: no masking,\n",
"# we'll just pad with 0s and let the LSTM ignore them)\n",
"emb_tok = Embedding(len(vocab_tok), d_tok, mask_zero=False, name=\"embedding_tok\")(\n",
"    inp_tok\n",
")\n",
"emb_ner = Embedding(len(vocab_ner), d_tag, mask_zero=False, name=\"embedding_ner\")(\n",
"    inp_ner\n",
")\n",
"emb_srl = Embedding(len(vocab_srl), d_tag, mask_zero=False, name=\"embedding_srl\")(\n",
"    inp_srl\n",
")\n",
"\n",
"# Concatenate the three embeddings per timestep (d_tok + 2*d_tag features),\n",
"# then encode; the final (h, c) states seed both decoders below.\n",
"enc_concat = Concatenate()([emb_tok, emb_ner, emb_srl])\n",
"enc_out, state_h, state_c = LSTM(units, return_state=True, name=\"encoder_lstm\")(\n",
"    enc_concat\n",
")\n",
"\n",
"\n",
"# ---------- DECODER : Question ----------\n",
"# Teacher-forced decoder; mask_zero=True so padded positions are ignored\n",
"# in the loss (the encoder side deliberately does not mask — see above).\n",
"dec_q_inp = Input(shape=(MAX_Q,), name=\"dec_q_in\")\n",
"dec_emb_q = Embedding(len(vocab_q), d_tok, mask_zero=True, name=\"embedding_q_decoder\")(\n",
"    dec_q_inp\n",
")\n",
"dec_q, _, _ = LSTM(\n",
"    units, return_state=True, return_sequences=True, name=\"lstm_q_decoder\"\n",
")(dec_emb_q, initial_state=[state_h, state_c])\n",
"# Per-timestep softmax over the question vocabulary.\n",
"q_out = TimeDistributed(\n",
"    Dense(len(vocab_q), activation=\"softmax\", name=\"dense_q_output\"), name=\"q_output\"\n",
")(dec_q)\n",
"\n",
"# ---------- DECODER : Answer ----------\n",
"# Same structure as the question decoder, over the answer vocabulary.\n",
"dec_a_inp = Input(shape=(MAX_A,), name=\"dec_a_in\")\n",
"dec_emb_a = Embedding(len(vocab_a), d_tok, mask_zero=True, name=\"embedding_a_decoder\")(\n",
"    dec_a_inp\n",
")\n",
"dec_a, _, _ = LSTM(\n",
"    units, return_state=True, return_sequences=True, name=\"lstm_a_decoder\"\n",
")(dec_emb_a, initial_state=[state_h, state_c])\n",
"a_out = TimeDistributed(\n",
"    Dense(len(vocab_a), activation=\"softmax\", name=\"dense_a_output\"), name=\"a_output\"\n",
")(dec_a)\n",
"\n",
"# ---------- CLASSIFIER : Question Type ----------\n",
"# Softmax over question types, read off the encoder's final output vector.\n",
"type_out = Dense(len(vocab_typ), activation=\"softmax\", name=\"type_output\")(enc_out)\n",
"\n",
"# One model, three heads: question tokens, answer tokens, question type.\n",
"model = Model(\n",
"    inputs=[inp_tok, inp_ner, inp_srl, dec_q_inp, dec_a_inp],\n",
"    outputs=[q_out, a_out, type_out],\n",
")\n",
"\n",
"model.summary()"
]
},
{
"cell_type": "code",
"execution_count": 105,
"id": "fece1ae9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/30\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 69ms/step - a_output_loss: 4.6957 - a_output_sparse_categorical_accuracy: 0.1413 - loss: 10.3613 - q_output_loss: 5.3441 - q_output_sparse_categorical_accuracy: 0.0670 - type_output_accuracy: 0.4939 - type_output_loss: 1.0668 - val_a_output_loss: 4.5668 - val_a_output_sparse_categorical_accuracy: 0.2500 - val_loss: 10.1527 - val_q_output_loss: 5.3034 - val_q_output_sparse_categorical_accuracy: 0.1182 - val_type_output_accuracy: 0.7000 - val_type_output_loss: 0.9419\n",
"Epoch 2/30\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - a_output_loss: 4.4732 - a_output_sparse_categorical_accuracy: 0.2500 - loss: 10.0049 - q_output_loss: 5.2579 - q_output_sparse_categorical_accuracy: 0.1001 - type_output_accuracy: 0.5591 - type_output_loss: 0.8840 - val_a_output_loss: 3.6835 - val_a_output_sparse_categorical_accuracy: 0.2500 - val_loss: 8.8571 - val_q_output_loss: 4.9562 - val_q_output_sparse_categorical_accuracy: 0.1000 - val_type_output_accuracy: 0.3833 - val_type_output_loss: 0.7245\n",
"Epoch 3/30\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - a_output_loss: 3.4959 - a_output_sparse_categorical_accuracy: 0.2500 - loss: 8.4809 - q_output_loss: 4.7656 - q_output_sparse_categorical_accuracy: 0.1299 - type_output_accuracy: 0.5308 - type_output_loss: 0.7076 - val_a_output_loss: 2.4516 - val_a_output_sparse_categorical_accuracy: 0.2500 - val_loss: 6.9437 - val_q_output_loss: 4.2829 - val_q_output_sparse_categorical_accuracy: 0.1758 - val_type_output_accuracy: 0.3833 - val_type_output_loss: 0.6973\n",
"Epoch 4/30\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - a_output_loss: 2.6025 - a_output_sparse_categorical_accuracy: 0.2620 - loss: 6.9572 - q_output_loss: 4.1438 - q_output_sparse_categorical_accuracy: 0.1763 - type_output_accuracy: 0.5010 - type_output_loss: 0.6949 - val_a_output_loss: 1.9285 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 5.9705 - val_q_output_loss: 3.8322 - val_q_output_sparse_categorical_accuracy: 0.1682 - val_type_output_accuracy: 0.3833 - val_type_output_loss: 0.6996\n",
"Epoch 5/30\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - a_output_loss: 2.3723 - a_output_sparse_categorical_accuracy: 0.3693 - loss: 6.3040 - q_output_loss: 3.7185 - q_output_sparse_categorical_accuracy: 0.1705 - type_output_accuracy: 0.5228 - type_output_loss: 0.6954 - val_a_output_loss: 1.7395 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 5.5526 - val_q_output_loss: 3.6050 - val_q_output_sparse_categorical_accuracy: 0.1818 - val_type_output_accuracy: 0.4333 - val_type_output_loss: 0.6936\n",
"Epoch 6/30\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - a_output_loss: 2.2466 - a_output_sparse_categorical_accuracy: 0.3660 - loss: 5.9522 - q_output_loss: 3.4987 - q_output_sparse_categorical_accuracy: 0.1843 - type_output_accuracy: 0.5311 - type_output_loss: 0.6928 - val_a_output_loss: 1.6680 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 5.3542 - val_q_output_loss: 3.4792 - val_q_output_sparse_categorical_accuracy: 0.2591 - val_type_output_accuracy: 0.6167 - val_type_output_loss: 0.6898\n",
"Epoch 7/30\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - a_output_loss: 2.1706 - a_output_sparse_categorical_accuracy: 0.3748 - loss: 5.7833 - q_output_loss: 3.3972 - q_output_sparse_categorical_accuracy: 0.2581 - type_output_accuracy: 0.4983 - type_output_loss: 0.6943 - val_a_output_loss: 1.6312 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 5.2140 - val_q_output_loss: 3.3728 - val_q_output_sparse_categorical_accuracy: 0.2591 - val_type_output_accuracy: 0.3833 - val_type_output_loss: 0.7001\n",
"Epoch 8/30\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - a_output_loss: 2.0769 - a_output_sparse_categorical_accuracy: 0.3709 - loss: 5.5707 - q_output_loss: 3.2835 - q_output_sparse_categorical_accuracy: 0.2579 - type_output_accuracy: 0.5165 - type_output_loss: 0.6925 - val_a_output_loss: 1.5953 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 5.0738 - val_q_output_loss: 3.2679 - val_q_output_sparse_categorical_accuracy: 0.2591 - val_type_output_accuracy: 0.3833 - val_type_output_loss: 0.7022\n",
"Epoch 9/30\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - a_output_loss: 2.0811 - a_output_sparse_categorical_accuracy: 0.3700 - loss: 5.4626 - q_output_loss: 3.1783 - q_output_sparse_categorical_accuracy: 0.2586 - type_output_accuracy: 0.5294 - type_output_loss: 0.6897 - val_a_output_loss: 1.5771 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 4.9584 - val_q_output_loss: 3.1740 - val_q_output_sparse_categorical_accuracy: 0.2591 - val_type_output_accuracy: 0.5667 - val_type_output_loss: 0.6911\n",
"Epoch 10/30\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - a_output_loss: 2.1436 - a_output_sparse_categorical_accuracy: 0.3635 - loss: 5.4602 - q_output_loss: 3.0972 - q_output_sparse_categorical_accuracy: 0.2603 - type_output_accuracy: 0.5954 - type_output_loss: 0.6906 - val_a_output_loss: 1.5554 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 4.8437 - val_q_output_loss: 3.0796 - val_q_output_sparse_categorical_accuracy: 0.2742 - val_type_output_accuracy: 0.4333 - val_type_output_loss: 0.6955\n",
"Epoch 11/30\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - a_output_loss: 2.0414 - a_output_sparse_categorical_accuracy: 0.3698 - loss: 5.2346 - q_output_loss: 2.9916 - q_output_sparse_categorical_accuracy: 0.2940 - type_output_accuracy: 0.5497 - type_output_loss: 0.6878 - val_a_output_loss: 1.5271 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 4.7076 - val_q_output_loss: 2.9767 - val_q_output_sparse_categorical_accuracy: 0.3788 - val_type_output_accuracy: 0.6167 - val_type_output_loss: 0.6791\n",
"Epoch 12/30\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - a_output_loss: 1.9689 - a_output_sparse_categorical_accuracy: 0.3747 - loss: 5.0683 - q_output_loss: 2.8947 - q_output_sparse_categorical_accuracy: 0.3778 - type_output_accuracy: 0.5652 - type_output_loss: 0.6839 - val_a_output_loss: 1.5182 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 4.6171 - val_q_output_loss: 2.8857 - val_q_output_sparse_categorical_accuracy: 0.3985 - val_type_output_accuracy: 0.3833 - val_type_output_loss: 0.7105\n",
"Epoch 13/30\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - a_output_loss: 1.9573 - a_output_sparse_categorical_accuracy: 0.3675 - loss: 4.9542 - q_output_loss: 2.7911 - q_output_sparse_categorical_accuracy: 0.3790 - type_output_accuracy: 0.5301 - type_output_loss: 0.6814 - val_a_output_loss: 1.5033 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 4.5030 - val_q_output_loss: 2.7936 - val_q_output_sparse_categorical_accuracy: 0.3985 - val_type_output_accuracy: 0.5333 - val_type_output_loss: 0.6871\n",
"Epoch 14/30\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - a_output_loss: 1.9229 - a_output_sparse_categorical_accuracy: 0.3745 - loss: 4.8469 - q_output_loss: 2.7145 - q_output_sparse_categorical_accuracy: 0.3810 - type_output_accuracy: 0.5883 - type_output_loss: 0.6817 - val_a_output_loss: 1.4761 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 4.3857 - val_q_output_loss: 2.7074 - val_q_output_sparse_categorical_accuracy: 0.3985 - val_type_output_accuracy: 0.6833 - val_type_output_loss: 0.6741\n",
"Epoch 15/30\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - a_output_loss: 1.9625 - a_output_sparse_categorical_accuracy: 0.3703 - loss: 4.7880 - q_output_loss: 2.6232 - q_output_sparse_categorical_accuracy: 0.3790 - type_output_accuracy: 0.6521 - type_output_loss: 0.6739 - val_a_output_loss: 1.4591 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 4.2924 - val_q_output_loss: 2.6260 - val_q_output_sparse_categorical_accuracy: 0.3985 - val_type_output_accuracy: 0.5167 - val_type_output_loss: 0.6910\n",
"Epoch 16/30\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - a_output_loss: 1.9763 - a_output_sparse_categorical_accuracy: 0.3661 - loss: 4.7331 - q_output_loss: 2.5511 - q_output_sparse_categorical_accuracy: 0.3784 - type_output_accuracy: 0.6424 - type_output_loss: 0.6679 - val_a_output_loss: 1.4432 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 4.1898 - val_q_output_loss: 2.5484 - val_q_output_sparse_categorical_accuracy: 0.4061 - val_type_output_accuracy: 0.6500 - val_type_output_loss: 0.6605\n",
"Epoch 17/30\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - a_output_loss: 1.8243 - a_output_sparse_categorical_accuracy: 0.3712 - loss: 4.4878 - q_output_loss: 2.4615 - q_output_sparse_categorical_accuracy: 0.3942 - type_output_accuracy: 0.6414 - type_output_loss: 0.6622 - val_a_output_loss: 1.4167 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 4.0918 - val_q_output_loss: 2.4764 - val_q_output_sparse_categorical_accuracy: 0.4061 - val_type_output_accuracy: 0.6667 - val_type_output_loss: 0.6624\n",
"Epoch 18/30\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 16ms/step - a_output_loss: 1.7940 - a_output_sparse_categorical_accuracy: 0.3750 - loss: 4.3987 - q_output_loss: 2.4076 - q_output_sparse_categorical_accuracy: 0.3949 - type_output_accuracy: 0.6488 - type_output_loss: 0.6552 - val_a_output_loss: 1.4059 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 4.0153 - val_q_output_loss: 2.4094 - val_q_output_sparse_categorical_accuracy: 0.4197 - val_type_output_accuracy: 0.6167 - val_type_output_loss: 0.6666\n",
"Epoch 19/30\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - a_output_loss: 1.8070 - a_output_sparse_categorical_accuracy: 0.3722 - loss: 4.3404 - q_output_loss: 2.3366 - q_output_sparse_categorical_accuracy: 0.4081 - type_output_accuracy: 0.6783 - type_output_loss: 0.6439 - val_a_output_loss: 1.3923 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 3.9348 - val_q_output_loss: 2.3484 - val_q_output_sparse_categorical_accuracy: 0.4197 - val_type_output_accuracy: 0.6667 - val_type_output_loss: 0.6468\n",
"Epoch 20/30\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - a_output_loss: 1.7701 - a_output_sparse_categorical_accuracy: 0.3707 - loss: 4.2359 - q_output_loss: 2.2700 - q_output_sparse_categorical_accuracy: 0.4075 - type_output_accuracy: 0.7000 - type_output_loss: 0.6303 - val_a_output_loss: 1.3730 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 3.8547 - val_q_output_loss: 2.2902 - val_q_output_sparse_categorical_accuracy: 0.4197 - val_type_output_accuracy: 0.6667 - val_type_output_loss: 0.6385\n",
"Epoch 21/30\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - a_output_loss: 1.7757 - a_output_sparse_categorical_accuracy: 0.3722 - loss: 4.1755 - q_output_loss: 2.2117 - q_output_sparse_categorical_accuracy: 0.4081 - type_output_accuracy: 0.6587 - type_output_loss: 0.6294 - val_a_output_loss: 1.3681 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 3.7954 - val_q_output_loss: 2.2381 - val_q_output_sparse_categorical_accuracy: 0.4212 - val_type_output_accuracy: 0.6667 - val_type_output_loss: 0.6310\n",
"Epoch 22/30\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - a_output_loss: 1.7072 - a_output_sparse_categorical_accuracy: 0.3720 - loss: 4.0450 - q_output_loss: 2.1546 - q_output_sparse_categorical_accuracy: 0.4168 - type_output_accuracy: 0.6542 - type_output_loss: 0.6179 - val_a_output_loss: 1.3568 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 3.7286 - val_q_output_loss: 2.1905 - val_q_output_sparse_categorical_accuracy: 0.4333 - val_type_output_accuracy: 0.7167 - val_type_output_loss: 0.6042\n",
"Epoch 23/30\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - a_output_loss: 1.7266 - a_output_sparse_categorical_accuracy: 0.3729 - loss: 4.0193 - q_output_loss: 2.1160 - q_output_sparse_categorical_accuracy: 0.4216 - type_output_accuracy: 0.6948 - type_output_loss: 0.5943 - val_a_output_loss: 1.3405 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 3.6712 - val_q_output_loss: 2.1447 - val_q_output_sparse_categorical_accuracy: 0.4333 - val_type_output_accuracy: 0.6667 - val_type_output_loss: 0.6197\n",
"Epoch 24/30\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - a_output_loss: 1.6801 - a_output_sparse_categorical_accuracy: 0.3773 - loss: 3.9227 - q_output_loss: 2.0804 - q_output_sparse_categorical_accuracy: 0.4252 - type_output_accuracy: 0.7457 - type_output_loss: 0.5462 - val_a_output_loss: 1.3245 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 3.6140 - val_q_output_loss: 2.1041 - val_q_output_sparse_categorical_accuracy: 0.4333 - val_type_output_accuracy: 0.6667 - val_type_output_loss: 0.6180\n",
"Epoch 25/30\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - a_output_loss: 1.6800 - a_output_sparse_categorical_accuracy: 0.3685 - loss: 3.8625 - q_output_loss: 2.0179 - q_output_sparse_categorical_accuracy: 0.4229 - type_output_accuracy: 0.7322 - type_output_loss: 0.5264 - val_a_output_loss: 1.2944 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 3.5249 - val_q_output_loss: 2.0656 - val_q_output_sparse_categorical_accuracy: 0.4333 - val_type_output_accuracy: 0.7000 - val_type_output_loss: 0.5497\n",
"Epoch 26/30\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - a_output_loss: 1.5786 - a_output_sparse_categorical_accuracy: 0.3724 - loss: 3.7165 - q_output_loss: 1.9876 - q_output_sparse_categorical_accuracy: 0.4253 - type_output_accuracy: 0.7853 - type_output_loss: 0.4956 - val_a_output_loss: 1.2680 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 3.4525 - val_q_output_loss: 2.0282 - val_q_output_sparse_categorical_accuracy: 0.4333 - val_type_output_accuracy: 0.7167 - val_type_output_loss: 0.5212\n",
"Epoch 27/30\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - a_output_loss: 1.6078 - a_output_sparse_categorical_accuracy: 0.3703 - loss: 3.7078 - q_output_loss: 1.9502 - q_output_sparse_categorical_accuracy: 0.4227 - type_output_accuracy: 0.7938 - type_output_loss: 0.4730 - val_a_output_loss: 1.2467 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 3.3875 - val_q_output_loss: 1.9954 - val_q_output_sparse_categorical_accuracy: 0.4364 - val_type_output_accuracy: 0.7833 - val_type_output_loss: 0.4848\n",
"Epoch 28/30\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 14ms/step - a_output_loss: 1.5627 - a_output_sparse_categorical_accuracy: 0.3749 - loss: 3.6136 - q_output_loss: 1.9109 - q_output_sparse_categorical_accuracy: 0.4312 - type_output_accuracy: 0.7908 - type_output_loss: 0.4706 - val_a_output_loss: 1.2410 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 3.3625 - val_q_output_loss: 1.9613 - val_q_output_sparse_categorical_accuracy: 0.4409 - val_type_output_accuracy: 0.7167 - val_type_output_loss: 0.5339\n",
"Epoch 29/30\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - a_output_loss: 1.6156 - a_output_sparse_categorical_accuracy: 0.3713 - loss: 3.6313 - q_output_loss: 1.8821 - q_output_sparse_categorical_accuracy: 0.4360 - type_output_accuracy: 0.7989 - type_output_loss: 0.4600 - val_a_output_loss: 1.2161 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 3.2887 - val_q_output_loss: 1.9327 - val_q_output_sparse_categorical_accuracy: 0.4424 - val_type_output_accuracy: 0.7833 - val_type_output_loss: 0.4663\n",
"Epoch 30/30\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - a_output_loss: 1.5910 - a_output_sparse_categorical_accuracy: 0.3707 - loss: 3.5723 - q_output_loss: 1.8495 - q_output_sparse_categorical_accuracy: 0.4395 - type_output_accuracy: 0.8028 - type_output_loss: 0.4485 - val_a_output_loss: 1.2147 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 3.2756 - val_q_output_loss: 1.9057 - val_q_output_sparse_categorical_accuracy: 0.4455 - val_type_output_accuracy: 0.7000 - val_type_output_loss: 0.5174\n"
]
}
],
"source": [
"# --- Multi-task loss configuration -------------------------------------\n",
"# Question and answer heads use token-level sparse CE; the type head is a\n",
"# single softmax over question types, down-weighted (0.3) so it does not\n",
"# dominate the two sequence losses.\n",
"losses = {\n",
"    \"q_output\": \"sparse_categorical_crossentropy\",\n",
"    \"a_output\": \"sparse_categorical_crossentropy\",\n",
"    \"type_output\": \"sparse_categorical_crossentropy\",\n",
"}\n",
"loss_weights = {\"q_output\": 1.0, \"a_output\": 1.0, \"type_output\": 0.3}\n",
"\n",
"model.compile(\n",
"    optimizer=\"adam\",\n",
"    loss=losses,\n",
"    loss_weights=loss_weights,\n",
"    metrics={\n",
"        \"q_output\": \"sparse_categorical_accuracy\",\n",
"        \"a_output\": \"sparse_categorical_accuracy\",\n",
"        \"type_output\": \"accuracy\",\n",
"    },\n",
")\n",
"\n",
"# Teacher forcing: decoder inputs are the shifted target sequences.\n",
"# Early stopping restores the best weights seen on the 10% validation split.\n",
"history = model.fit(\n",
"    [X_tok, X_ner, X_srl, dec_q_in, dec_a_in],\n",
"    [dec_q_out, dec_a_out, y_type],\n",
"    validation_split=0.1,\n",
"    epochs=30,\n",
"    batch_size=64,\n",
"    callbacks=[tf.keras.callbacks.EarlyStopping(patience=4, restore_best_weights=True)],\n",
"    verbose=1,\n",
")\n",
"\n",
"model.save(\"full_seq2seq.keras\")\n",
"\n",
"import json\n",
"import pickle\n",
"\n",
"\n",
"def save_vocab_pkl(vocab, path):\n",
"    \"\"\"Pickle a token->id vocabulary dict to `path`.\"\"\"\n",
"    with open(path, \"wb\") as f:\n",
"        pickle.dump(vocab, f)\n",
"\n",
"\n",
"# Persist every vocabulary so the inference cells can reload them\n",
"# without re-running the preprocessing cells.\n",
"save_vocab_pkl(vocab_tok, \"vocab_tok.pkl\")\n",
"save_vocab_pkl(vocab_ner, \"vocab_ner.pkl\")\n",
"save_vocab_pkl(vocab_srl, \"vocab_srl.pkl\")\n",
"save_vocab_pkl(vocab_q, \"vocab_q.pkl\")\n",
"save_vocab_pkl(vocab_a, \"vocab_a.pkl\")\n",
"save_vocab_pkl(vocab_typ, \"vocab_typ.pkl\")"
]
},
{
"cell_type": "code",
"execution_count": 106,
"id": "3355c0c7",
"metadata": {},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"import numpy as np\n",
"import pickle\n",
"from tensorflow.keras.models import load_model, Model\n",
"from tensorflow.keras.layers import Input, Concatenate\n",
"\n",
"# === Load the trained multi-task seq2seq model ===\n",
"model = load_model(\"full_seq2seq.keras\")\n",
"\n",
"\n",
"# === Load vocabularies pickled by the training cell ===\n",
"def load_vocab(path):\n",
"    \"\"\"Load a pickled token->id vocabulary dict from `path`.\"\"\"\n",
"    with open(path, \"rb\") as f:\n",
"        return pickle.load(f)\n",
"\n",
"\n",
"vocab_tok = load_vocab(\"vocab_tok.pkl\")\n",
"vocab_ner = load_vocab(\"vocab_ner.pkl\")\n",
"vocab_srl = load_vocab(\"vocab_srl.pkl\")\n",
"vocab_q = load_vocab(\"vocab_q.pkl\")\n",
"vocab_a = load_vocab(\"vocab_a.pkl\")\n",
"vocab_typ = load_vocab(\"vocab_typ.pkl\")\n",
"\n",
"inv_vocab_q = {v: k for k, v in vocab_q.items()}\n",
"inv_vocab_a = {v: k for k, v in vocab_a.items()}\n",
"\n",
"# === Derive all sizes from the loaded model so this cell also works in a\n",
"# fresh kernel (previously `units` leaked in from the training cells and\n",
"# would raise NameError after a kernel restart) ===\n",
"MAX_SENT = model.input_shape[0][1]  # encoder input length\n",
"MAX_Q = model.input_shape[3][1]  # max question length\n",
"MAX_A = model.input_shape[4][1]  # max answer length\n",
"units = model.get_layer(\"encoder_lstm\").units  # LSTM hidden size\n",
"\n",
"# === Build Encoder Model ===\n",
"inp_tok_g = Input(shape=(MAX_SENT,), name=\"tok_in_g\")\n",
"inp_ner_g = Input(shape=(MAX_SENT,), name=\"ner_in_g\")\n",
"inp_srl_g = Input(shape=(MAX_SENT,), name=\"srl_in_g\")\n",
"\n",
"emb_tok = model.get_layer(\"embedding_tok\").call(inp_tok_g)\n",
"emb_ner = model.get_layer(\"embedding_ner\").call(inp_ner_g)\n",
"emb_srl = model.get_layer(\"embedding_srl\").call(inp_srl_g)\n",
"\n",
"enc_concat = Concatenate(name=\"concat_encoder\")([emb_tok, emb_ner, emb_srl])\n",
"\n",
"encoder_lstm = model.get_layer(\"encoder_lstm\")\n",
"enc_out, state_h, state_c = encoder_lstm(enc_concat)\n",
"\n",
"# Encoder model exposes the sequence output AND final states; the states\n",
"# seed both single-step decoders below.\n",
"encoder_model = Model(\n",
"    inputs=[inp_tok_g, inp_ner_g, inp_srl_g],\n",
"    outputs=[enc_out, state_h, state_c],\n",
"    name=\"encoder_model\",\n",
")\n",
"\n",
"# === Build single-step Decoder for the Question ===\n",
"dec_q_inp = Input(shape=(1,), name=\"dec_q_in\")\n",
"dec_emb_q = model.get_layer(\"embedding_q_decoder\").call(dec_q_inp)\n",
"\n",
"state_h_dec = Input(shape=(units,), name=\"state_h_dec\")\n",
"state_c_dec = Input(shape=(units,), name=\"state_c_dec\")\n",
"\n",
"lstm_decoder_q = model.get_layer(\"lstm_q_decoder\")\n",
"\n",
"dec_out_q, state_h_q, state_c_q = lstm_decoder_q(\n",
"    dec_emb_q, initial_state=[state_h_dec, state_c_dec]\n",
")\n",
"\n",
"# Unwrap the Dense inside the TimeDistributed wrapper used at train time.\n",
"q_time_dist_layer = model.get_layer(\"q_output\")\n",
"dense_q = q_time_dist_layer.layer\n",
"q_output = dense_q(dec_out_q)\n",
"\n",
"decoder_q = Model(\n",
"    inputs=[dec_q_inp, state_h_dec, state_c_dec],\n",
"    outputs=[q_output, state_h_q, state_c_q],\n",
"    name=\"decoder_question_model\",\n",
")\n",
"\n",
"# === Build single-step Decoder for the Answer ===\n",
"dec_a_inp = Input(shape=(1,), name=\"dec_a_in\")\n",
"dec_emb_a = model.get_layer(\"embedding_a_decoder\").call(dec_a_inp)\n",
"\n",
"state_h_a = Input(shape=(units,), name=\"state_h_a\")\n",
"state_c_a = Input(shape=(units,), name=\"state_c_a\")\n",
"\n",
"lstm_decoder_a = model.get_layer(\"lstm_a_decoder\")\n",
"\n",
"dec_out_a, state_h_a_out, state_c_a_out = lstm_decoder_a(\n",
"    dec_emb_a, initial_state=[state_h_a, state_c_a]\n",
")\n",
"\n",
"a_time_dist_layer = model.get_layer(\"a_output\")\n",
"dense_a = a_time_dist_layer.layer\n",
"a_output = dense_a(dec_out_a)\n",
"\n",
"decoder_a = Model(\n",
"    inputs=[dec_a_inp, state_h_a, state_c_a],\n",
"    outputs=[a_output, state_h_a_out, state_c_a_out],\n",
"    name=\"decoder_answer_model\",\n",
")\n",
"\n",
"# === Build Classifier for the Question Type (reads encoder output) ===\n",
"type_dense = model.get_layer(\"type_output\")\n",
"type_out = type_dense(enc_out)\n",
"\n",
"classifier_model = Model(\n",
"    inputs=[inp_tok_g, inp_ner_g, inp_srl_g], outputs=type_out, name=\"classifier_model\"\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 107,
"id": "d406e6ff",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Generated Question: dimana dimana lahir ___\n",
"Generated Answer : true\n",
"Question Type : true_false\n"
]
}
],
"source": [
"def encode(seq, vmap):\n",
"    \"\"\"Map tokens to ids, falling back to <unk> for unseen tokens.\"\"\"\n",
"    return [vmap.get(tok, vmap[\"<unk>\"]) for tok in seq]\n",
"\n",
"\n",
"def encode_and_pad(seq, vmap, max_len=MAX_SENT):\n",
"    \"\"\"Encode `seq` and right-pad with <pad> (or truncate) to `max_len` ids.\"\"\"\n",
"    encoded = [vmap.get(tok, vmap[\"<unk>\"]) for tok in seq]\n",
"    padded = encoded + [vmap[\"<pad>\"]] * (max_len - len(encoded))\n",
"    return padded[:max_len]  # ensure it doesn't exceed max_len\n",
"\n",
"\n",
"def greedy_decode(tokens, ner, srl, max_q=20, max_a=10):\n",
"    \"\"\"Greedy-decode a (question, answer, question_type) triple.\n",
"\n",
"    `tokens` / `ner` / `srl` may be raw token lists or already-encoded\n",
"    (1, MAX_SENT) arrays.  Returns (question_tokens, answer_tokens, q_type).\n",
"    \"\"\"\n",
"    # --- encode encoder inputs -------------------------------------------\n",
"    if isinstance(tokens, np.ndarray):\n",
"        enc_tok, enc_ner, enc_srl = tokens, ner, srl\n",
"    else:\n",
"        enc_tok = np.array([encode_and_pad(tokens, vocab_tok, MAX_SENT)])\n",
"        enc_ner = np.array([encode_and_pad(ner, vocab_ner, MAX_SENT)])\n",
"        enc_srl = np.array([encode_and_pad(srl, vocab_srl, MAX_SENT)])\n",
"\n",
"    # --- run the encoder ONCE and keep the initial states; both decoders\n",
"    # start from the same (h0, c0), so re-running the encoder is wasted work\n",
"    _, h0, c0 = encoder_model.predict([enc_tok, enc_ner, enc_srl], verbose=0)\n",
"\n",
"    # --- QUESTION decoding (token-by-token greedy argmax) ---\n",
"    h, c = h0, c0\n",
"    tgt = np.array([[vocab_q[\"<sos>\"]]])\n",
"    question_ids = []\n",
"    for _ in range(max_q):\n",
"        logits, h, c = decoder_q.predict([tgt, h, c], verbose=0)\n",
"        next_id = int(logits[0, 0].argmax())  # predicted token id\n",
"        if next_id == vocab_q[\"<eos>\"]:\n",
"            break\n",
"        question_ids.append(next_id)\n",
"        tgt = np.array([[next_id]])  # feed the prediction back as input\n",
"\n",
"    # --- ANSWER decoding: restart from the cached encoder states ---\n",
"    h, c = h0, c0\n",
"    tgt = np.array([[vocab_a[\"<sos>\"]]])\n",
"    answer_ids = []\n",
"    for _ in range(max_a):\n",
"        logits, h, c = decoder_a.predict([tgt, h, c], verbose=0)\n",
"        next_id = int(logits[0, 0].argmax())\n",
"        if next_id == vocab_a[\"<eos>\"]:\n",
"            break\n",
"        answer_ids.append(next_id)\n",
"        tgt = np.array([[next_id]])\n",
"\n",
"    # --- question-type classification (runs the encoder + type head) ---\n",
"    qtype_logits = classifier_model.predict([enc_tok, enc_ner, enc_srl], verbose=0)\n",
"    qtype_id = int(qtype_logits.argmax())\n",
"\n",
"    # Map ids back to surface tokens / type label.\n",
"    question = [inv_vocab_q.get(i, \"<unk>\") for i in question_ids]\n",
"    answer = [inv_vocab_a.get(i, \"<unk>\") for i in answer_ids]\n",
"    q_type = [k for k, v in vocab_typ.items() if v == qtype_id][0]\n",
"\n",
"    return question, answer, q_type\n",
"\n",
"\n",
"def test_model():\n",
"    \"\"\"Smoke-test generation on one hand-built tagged sentence.\"\"\"\n",
"    test_data = {\n",
"        \"tokens\": [\n",
"            \"indra\",\n",
"            \"maulana\",\n",
"            \"lahir\",\n",
"            \"di\",\n",
"            \"semarang\",\n",
"            \"pada\",\n",
"            \"12\",\n",
"            \"maret\",\n",
"            \"1977\",\n",
"        ],\n",
"        \"ner\": [\"B-PER\", \"I-PER\", \"V\", \"O\", \"B-LOC\", \"O\", \"B-DATE\", \"I-DATE\", \"I-DATE\"],\n",
"        \"srl\": [\n",
"            \"ARG0\",\n",
"            \"ARG0\",\n",
"            \"V\",\n",
"            \"O\",\n",
"            \"ARGM-LOC\",\n",
"            \"O\",\n",
"            \"ARGM-TMP\",\n",
"            \"ARGM-TMP\",\n",
"            \"ARGM-TMP\",\n",
"        ],\n",
"    }\n",
"\n",
"    question, answer, q_type = greedy_decode(\n",
"        test_data[\"tokens\"], test_data[\"ner\"], test_data[\"srl\"]\n",
"    )\n",
"    print(f\"Generated Question: {' '.join(question)}\")\n",
"    print(f\"Generated Answer : {' '.join(answer)}\")\n",
"    print(f\"Question Type : {q_type}\")\n",
"\n",
"\n",
"test_model()"
]
},
{
"cell_type": "code",
"execution_count": 108,
"id": "5adde3c3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"BLEU : 0.0447\n",
"ROUGE1: 0.2281 | ROUGE-L: 0.2281\n"
]
}
],
"source": [
"from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction\n",
"from rouge_score import rouge_scorer\n",
"import random\n",
"\n",
"smoothie = SmoothingFunction().method4\n",
"scorer = rouge_scorer.RougeScorer([\"rouge1\", \"rougeL\"], use_stemmer=True)\n",
"\n",
"\n",
"def strip_special(ids, vocab):\n",
"    \"\"\"Drop <pad>/<eos> ids from a target sequence before scoring.\"\"\"\n",
"    specials = {vocab.get(\"<pad>\"), vocab[\"<eos>\"]}\n",
"    return [i for i in ids if i not in specials]\n",
"\n",
"\n",
"def ids_to_text(ids, inv_vocab):\n",
"    \"\"\"Join a list of ids into a space-separated token string.\"\"\"\n",
"    return \" \".join(inv_vocab[i] for i in ids)\n",
"\n",
"\n",
"def evaluate(indices=None):\n",
"    \"\"\"Score generated questions against references with BLEU and ROUGE.\n",
"\n",
"    When `indices` is None, samples up to 100 random rows from the\n",
"    encoded training arrays.\n",
"    \"\"\"\n",
"    if indices is None:\n",
"        indices = random.sample(range(len(X_tok)), k=min(100, len(X_tok)))\n",
"\n",
"    bleu_scores = []\n",
"    rouge1_scores = []\n",
"    rougeL_scores = []\n",
"    for idx in indices:\n",
"        # Reference question/answer token ids, specials removed.\n",
"        gt_q = strip_special(dec_q_out[idx], vocab_q)\n",
"        gt_a = strip_special(dec_a_out[idx], vocab_a)\n",
"        # Model prediction from the already-encoded row.\n",
"        q_pred, a_pred, _ = greedy_decode(\n",
"            X_tok[idx : idx + 1], X_ner[idx : idx + 1], X_srl[idx : idx + 1]\n",
"        )\n",
"\n",
"        # BLEU on question tokens (smoothed for short sequences).\n",
"        reference = [inv_vocab_q[i] for i in gt_q]\n",
"        bleu_scores.append(\n",
"            sentence_bleu([reference], q_pred, smoothing_function=smoothie)\n",
"        )\n",
"        # ROUGE on question strings.\n",
"        r = scorer.score(ids_to_text(gt_q, inv_vocab_q), \" \".join(q_pred))\n",
"        rouge1_scores.append(r[\"rouge1\"].fmeasure)\n",
"        rougeL_scores.append(r[\"rougeL\"].fmeasure)\n",
"\n",
"    print(f\"BLEU : {np.mean(bleu_scores):.4f}\")\n",
"    print(f\"ROUGE1: {np.mean(rouge1_scores):.4f} | ROUGE-L: {np.mean(rougeL_scores):.4f}\")\n",
"\n",
"\n",
"evaluate()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "myenv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}