feat: adding dataset

akhdanre 2025-05-13 13:42:02 +07:00
parent 1fd5a3dc95
commit ad4b6d6137
2 changed files with 492 additions and 99 deletions


@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 32,
"execution_count": null,
"id": "fb283f23",
"metadata": {},
"outputs": [
@@ -20,7 +20,7 @@
"from itertools import chain\n",
"\n",
"RAW = json.loads(\n",
" Path(\"../dataset/dev_dataset_test.json\").read_text()\n",
" Path(\"../dataset/dev_dataset_qg.json\").read_text()\n",
") # ← your example file\n",
"\n",
"samples = []\n",
@@ -46,7 +46,7 @@
},
{
"cell_type": "code",
"execution_count": 33,
"execution_count": 102,
"id": "fa4f979d",
"metadata": {},
"outputs": [
@@ -80,7 +80,7 @@
},
{
"cell_type": "code",
"execution_count": 34,
"execution_count": 103,
"id": "d1a5b324",
"metadata": {},
"outputs": [],
@@ -125,27 +125,27 @@
"dec_a_out = pad_sequences(\n",
" [encode(s[\"a_toks\"], vocab_a) for s in samples], maxlen=MAX_A, padding=\"post\"\n",
")\n",
"y_type = np.array([vocab_typ[s[\"q_type\"]] for s in samples])\n",
"\n",
"MAX_SENT = max(len(s[\"tokens\"]) for s in samples)\n",
"MAX_Q = max(len(s[\"q_toks\"]) for s in samples)\n",
"MAX_A = max(len(s[\"a_toks\"]) for s in samples)\n",
"y_type = np.array([vocab_typ[s[\"q_type\"]] for s in samples])"
"MAX_A = max(len(s[\"a_toks\"]) for s in samples)"
]
},
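(In the cell above, dec_a_out is padded with maxlen=MAX_A before the MAX_SENT/MAX_Q/MAX_A lines appear, so the lengths must already be defined earlier in the full cell. Below is a sketch of the intended ordering, assuming the encode helper and the vocab_q/vocab_a/vocab_typ lookups defined in the hidden context; dec_q_out simply mirrors the dec_a_out call shown in the diff.)

# Hedged sketch: compute max lengths first, then pad, then the type labels.
MAX_SENT = max(len(s["tokens"]) for s in samples)
MAX_Q    = max(len(s["q_toks"]) for s in samples)
MAX_A    = max(len(s["a_toks"]) for s in samples)

dec_q_out = pad_sequences(
    [encode(s["q_toks"], vocab_q) for s in samples], maxlen=MAX_Q, padding="post"
)
dec_a_out = pad_sequences(
    [encode(s["a_toks"], vocab_a) for s in samples], maxlen=MAX_A, padding="post"
)
y_type = np.array([vocab_typ[s["q_type"]] for s in samples])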
{
"cell_type": "code",
"execution_count": 35,
"execution_count": 104,
"id": "ff5bd85f",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Model: \"functional_3\"</span>\n",
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Model: \"functional_12\"</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1mModel: \"functional_3\"\u001b[0m\n"
"\u001b[1mModel: \"functional_12\"\u001b[0m\n"
]
},
"metadata": {},
@@ -163,56 +163,56 @@
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ srl_in (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">InputLayer</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">9</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ - │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ embedding_tok │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">9</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">31,232</span> │ tok_in[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"│ embedding_tok │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">9</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">32</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">7,808</span> │ tok_in[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Embedding</span>) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ embedding_ner │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">9</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">32</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">288</span> │ ner_in[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"│ embedding_ner │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">9</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">16</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">144</span> │ ner_in[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Embedding</span>) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ embedding_srl │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">9</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">32</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">224</span> │ srl_in[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"│ embedding_srl │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">9</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">16</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">112</span> │ srl_in[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Embedding</span>) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ dec_q_in │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">11</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ - │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">InputLayer</span>) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ concatenate_3 │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">9</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">192</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ embedding_tok[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>]… │\n",
"│ concatenate_14 │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">9</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ embedding_tok[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>]… │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Concatenate</span>) │ │ │ embedding_ner[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>]… │\n",
"│ │ │ │ embedding_srl[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>]… │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ dec_a_in │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">4</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ - │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">InputLayer</span>) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ embedding_q_decoder │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">11</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">27,008</span> │ dec_q_in[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"│ embedding_q_decoder │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">11</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">32</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">6,752</span> │ dec_q_in[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Embedding</span>) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ encoder_lstm (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">LSTM</span>) │ [(<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">256</span>), │ <span style=\"color: #00af00; text-decoration-color: #00af00\">459,776</span> │ concatenate_3[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>]… │\n",
"│ │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">256</span>), │ │ │\n",
"│ │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">256</span>)] │ │ │\n",
"│ encoder_lstm (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">LSTM</span>) │ [(<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>), <span style=\"color: #00af00; text-decoration-color: #00af00\">33,024</span> │ concatenate_14[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>… │\n",
"│ │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>), │ │ │\n",
"│ │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>)] │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ embedding_a_decoder │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">4</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">14,336</span> │ dec_a_in[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"│ embedding_a_decoder │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">4</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">32</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">3,584</span> │ dec_a_in[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Embedding</span>) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ lstm_q_decoder │ [(<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">11</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">256</span>), │ <span style=\"color: #00af00; text-decoration-color: #00af00\">394,240</span> │ embedding_q_deco… │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">LSTM</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">256</span>), │ │ encoder_lstm[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">…</span> │\n",
"│ │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">256</span>)] │ │ encoder_lstm[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">…</span> │\n",
"│ lstm_q_decoder │ [(<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">11</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>), <span style=\"color: #00af00; text-decoration-color: #00af00\">24,832</span> │ embedding_q_deco… │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">LSTM</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>), │ │ encoder_lstm[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">…</span> │\n",
"│ │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>)] │ │ encoder_lstm[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">…</span> │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ not_equal_12 │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">11</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ dec_q_in[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"│ not_equal_51 │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">11</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ dec_q_in[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">NotEqual</span>) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ lstm_a_decoder │ [(<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">4</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">256</span>), │ <span style=\"color: #00af00; text-decoration-color: #00af00\">394,240</span> │ embedding_a_deco… │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">LSTM</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">256</span>), │ │ encoder_lstm[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">…</span> │\n",
"│ │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">256</span>)] │ │ encoder_lstm[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">…</span> │\n",
"│ lstm_a_decoder │ [(<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">4</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>), <span style=\"color: #00af00; text-decoration-color: #00af00\">24,832</span> │ embedding_a_deco… │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">LSTM</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>), │ │ encoder_lstm[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">…</span> │\n",
"│ │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>)] │ │ encoder_lstm[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">…</span> │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ not_equal_13 │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">4</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ dec_a_in[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"│ not_equal_52 │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">4</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ dec_a_in[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">NotEqual</span>) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ q_output │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">11</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">211</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">54,227</span> │ lstm_q_decoder[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>… │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">TimeDistributed</span>) │ │ │ not_equal_12[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">…</span> │\n",
"│ q_output │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">11</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">211</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">13,715</span> │ lstm_q_decoder[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>… │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">TimeDistributed</span>) │ │ │ not_equal_51[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">…</span> │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ a_output │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">4</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">112</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">28,784</span> │ lstm_a_decoder[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>… │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">TimeDistributed</span>) │ │ │ not_equal_13[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">…</span> │\n",
"│ a_output │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">4</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">112</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">7,280</span> │ lstm_a_decoder[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>… │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">TimeDistributed</span>) │ │ │ not_equal_52[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">…</span> │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ type_output (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">3</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">771</span> │ encoder_lstm[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">…</span> │\n",
"│ type_output (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">3</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">195</span> │ encoder_lstm[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">…</span> │\n",
"└─────────────────────┴───────────────────┴────────────┴───────────────────┘\n",
"</pre>\n"
],
@@ -226,56 +226,56 @@
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ srl_in (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m9\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ embedding_tok │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m9\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m31,232\u001b[0m │ tok_in[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"│ embedding_tok │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m9\u001b[0m, \u001b[38;5;34m32\u001b[0m) │ \u001b[38;5;34m7,808\u001b[0m │ tok_in[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"│ (\u001b[38;5;33mEmbedding\u001b[0m) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ embedding_ner │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m9\u001b[0m, \u001b[38;5;34m32\u001b[0m) │ \u001b[38;5;34m288\u001b[0m │ ner_in[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"│ embedding_ner │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m9\u001b[0m, \u001b[38;5;34m16\u001b[0m) │ \u001b[38;5;34m144\u001b[0m │ ner_in[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"│ (\u001b[38;5;33mEmbedding\u001b[0m) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ embedding_srl │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m9\u001b[0m, \u001b[38;5;34m32\u001b[0m) │ \u001b[38;5;34m224\u001b[0m │ srl_in[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"│ embedding_srl │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m9\u001b[0m, \u001b[38;5;34m16\u001b[0m) │ \u001b[38;5;34m112\u001b[0m │ srl_in[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"│ (\u001b[38;5;33mEmbedding\u001b[0m) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ dec_q_in │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m11\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n",
"│ (\u001b[38;5;33mInputLayer\u001b[0m) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ concatenate_3 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m9\u001b[0m, \u001b[38;5;34m192\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ embedding_tok[\u001b[38;5;34m0\u001b[0m]… │\n",
"│ concatenate_14 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m9\u001b[0m, \u001b[38;5;34m64\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ embedding_tok[\u001b[38;5;34m0\u001b[0m]… │\n",
"│ (\u001b[38;5;33mConcatenate\u001b[0m) │ │ │ embedding_ner[\u001b[38;5;34m0\u001b[0m]… │\n",
"│ │ │ │ embedding_srl[\u001b[38;5;34m0\u001b[0m]… │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ dec_a_in │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m4\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n",
"│ (\u001b[38;5;33mInputLayer\u001b[0m) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ embedding_q_decoder │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m11\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m27,008\u001b[0m │ dec_q_in[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"│ embedding_q_decoder │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m11\u001b[0m, \u001b[38;5;34m32\u001b[0m) │ \u001b[38;5;34m6,752\u001b[0m │ dec_q_in[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"│ (\u001b[38;5;33mEmbedding\u001b[0m) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ encoder_lstm (\u001b[38;5;33mLSTM\u001b[0m) │ [(\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256\u001b[0m), │ \u001b[38;5;34m459,776\u001b[0m │ concatenate_3[\u001b[38;5;34m0\u001b[0m]… │\n",
"│ │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256\u001b[0m), │ │ │\n",
"│ │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256\u001b[0m)] │ │ │\n",
"│ encoder_lstm (\u001b[38;5;33mLSTM\u001b[0m) │ [(\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m64\u001b[0m), │ \u001b[38;5;34m33,024\u001b[0m │ concatenate_14[\u001b[38;5;34m0\u001b[0m… │\n",
"│ │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m64\u001b[0m), │ │ │\n",
"│ │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m64\u001b[0m)] │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ embedding_a_decoder │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m4\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m14,336\u001b[0m │ dec_a_in[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"│ embedding_a_decoder │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m4\u001b[0m, \u001b[38;5;34m32\u001b[0m) │ \u001b[38;5;34m3,584\u001b[0m │ dec_a_in[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"│ (\u001b[38;5;33mEmbedding\u001b[0m) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ lstm_q_decoder │ [(\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m11\u001b[0m, \u001b[38;5;34m256\u001b[0m), │ \u001b[38;5;34m394,240\u001b[0m │ embedding_q_deco… │\n",
"│ (\u001b[38;5;33mLSTM\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256\u001b[0m), │ │ encoder_lstm[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m…\u001b[0m │\n",
"│ │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256\u001b[0m)] │ │ encoder_lstm[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m…\u001b[0m │\n",
"│ lstm_q_decoder │ [(\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m11\u001b[0m, \u001b[38;5;34m64\u001b[0m), │ \u001b[38;5;34m24,832\u001b[0m │ embedding_q_deco… │\n",
"│ (\u001b[38;5;33mLSTM\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m64\u001b[0m), │ │ encoder_lstm[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m…\u001b[0m │\n",
"│ │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m64\u001b[0m)] │ │ encoder_lstm[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m…\u001b[0m │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ not_equal_12 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m11\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ dec_q_in[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"│ not_equal_51 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m11\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ dec_q_in[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"│ (\u001b[38;5;33mNotEqual\u001b[0m) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ lstm_a_decoder │ [(\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m4\u001b[0m, \u001b[38;5;34m256\u001b[0m), │ \u001b[38;5;34m394,240\u001b[0m │ embedding_a_deco… │\n",
"│ (\u001b[38;5;33mLSTM\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256\u001b[0m), │ │ encoder_lstm[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m…\u001b[0m │\n",
"│ │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256\u001b[0m)] │ │ encoder_lstm[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m…\u001b[0m │\n",
"│ lstm_a_decoder │ [(\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m4\u001b[0m, \u001b[38;5;34m64\u001b[0m), │ \u001b[38;5;34m24,832\u001b[0m │ embedding_a_deco… │\n",
"│ (\u001b[38;5;33mLSTM\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m64\u001b[0m), │ │ encoder_lstm[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m…\u001b[0m │\n",
"│ │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m64\u001b[0m)] │ │ encoder_lstm[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m…\u001b[0m │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ not_equal_13 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m4\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ dec_a_in[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"│ not_equal_52 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m4\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ dec_a_in[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"│ (\u001b[38;5;33mNotEqual\u001b[0m) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ q_output │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m11\u001b[0m, \u001b[38;5;34m211\u001b[0m) │ \u001b[38;5;34m54,227\u001b[0m │ lstm_q_decoder[\u001b[38;5;34m0\u001b[0m… │\n",
"│ (\u001b[38;5;33mTimeDistributed\u001b[0m) │ │ │ not_equal_12[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m…\u001b[0m │\n",
"│ q_output │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m11\u001b[0m, \u001b[38;5;34m211\u001b[0m) │ \u001b[38;5;34m13,715\u001b[0m │ lstm_q_decoder[\u001b[38;5;34m0\u001b[0m… │\n",
"│ (\u001b[38;5;33mTimeDistributed\u001b[0m) │ │ │ not_equal_51[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m…\u001b[0m │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ a_output │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m4\u001b[0m, \u001b[38;5;34m112\u001b[0m) │ \u001b[38;5;34m28,784\u001b[0m │ lstm_a_decoder[\u001b[38;5;34m0\u001b[0m… │\n",
"│ (\u001b[38;5;33mTimeDistributed\u001b[0m) │ │ │ not_equal_13[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m…\u001b[0m │\n",
"│ a_output │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m4\u001b[0m, \u001b[38;5;34m112\u001b[0m) │ \u001b[38;5;34m7,280\u001b[0m │ lstm_a_decoder[\u001b[38;5;34m0\u001b[0m… │\n",
"│ (\u001b[38;5;33mTimeDistributed\u001b[0m) │ │ │ not_equal_52[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m…\u001b[0m │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ type_output (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m3\u001b[0m) │ \u001b[38;5;34m771\u001b[0m │ encoder_lstm[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m…\u001b[0m │\n",
"│ type_output (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m3\u001b[0m) │ \u001b[38;5;34m195\u001b[0m │ encoder_lstm[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m…\u001b[0m │\n",
"└─────────────────────┴───────────────────┴────────────┴───────────────────┘\n"
]
},
@@ -285,11 +285,11 @@
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Total params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">1,405,126</span> (5.36 MB)\n",
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Total params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">122,278</span> (477.65 KB)\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m Total params: \u001b[0m\u001b[38;5;34m1,405,126\u001b[0m (5.36 MB)\n"
"\u001b[1m Total params: \u001b[0m\u001b[38;5;34m122,278\u001b[0m (477.65 KB)\n"
]
},
"metadata": {},
@@ -298,11 +298,11 @@
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">1,405,126</span> (5.36 MB)\n",
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">122,278</span> (477.65 KB)\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m1,405,126\u001b[0m (5.36 MB)\n"
"\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m122,278\u001b[0m (477.65 KB)\n"
]
},
"metadata": {},
@@ -335,9 +335,9 @@
"from tensorflow.keras.models import Model\n",
"\n",
"# ---- constants ---------------------------------------------------\n",
"d_tok = 128 # token embedding dim\n",
"d_tag = 32 # NER / SRL embedding dim\n",
"units = 256\n",
"d_tok = 32 # token embedding dim\n",
"d_tag = 16 # NER / SRL embedding dim\n",
"units = 64\n",
"\n",
"# ---- encoder -----------------------------------------------------\n",
"inp_tok = Input((MAX_SENT,), name=\"tok_in\")\n",
@@ -399,7 +399,7 @@
},
{
"cell_type": "code",
"execution_count": 36,
"execution_count": 105,
"id": "fece1ae9",
"metadata": {},
"outputs": [
@@ -408,61 +408,65 @@
"output_type": "stream",
"text": [
"Epoch 1/30\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 83ms/step - a_output_loss: 4.5185 - a_output_sparse_categorical_accuracy: 0.1853 - loss: 10.1289 - q_output_loss: 5.2751 - q_output_sparse_categorical_accuracy: 0.1679 - type_output_accuracy: 0.3966 - type_output_loss: 1.0344 - val_a_output_loss: 2.2993 - val_a_output_sparse_categorical_accuracy: 0.2500 - val_loss: 6.6556 - val_q_output_loss: 4.1554 - val_q_output_sparse_categorical_accuracy: 0.1606 - val_type_output_accuracy: 0.6167 - val_type_output_loss: 0.6698\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 69ms/step - a_output_loss: 4.6957 - a_output_sparse_categorical_accuracy: 0.1413 - loss: 10.3613 - q_output_loss: 5.3441 - q_output_sparse_categorical_accuracy: 0.0670 - type_output_accuracy: 0.4939 - type_output_loss: 1.0668 - val_a_output_loss: 4.5668 - val_a_output_sparse_categorical_accuracy: 0.2500 - val_loss: 10.1527 - val_q_output_loss: 5.3034 - val_q_output_sparse_categorical_accuracy: 0.1182 - val_type_output_accuracy: 0.7000 - val_type_output_loss: 0.9419\n",
"Epoch 2/30\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 28ms/step - a_output_loss: 2.7148 - a_output_sparse_categorical_accuracy: 0.2674 - loss: 6.7993 - q_output_loss: 3.8706 - q_output_sparse_categorical_accuracy: 0.1625 - type_output_accuracy: 0.5096 - type_output_loss: 0.7111 - val_a_output_loss: 1.9687 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 5.6731 - val_q_output_loss: 3.4983 - val_q_output_sparse_categorical_accuracy: 0.2848 - val_type_output_accuracy: 0.8167 - val_type_output_loss: 0.6870\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - a_output_loss: 4.4732 - a_output_sparse_categorical_accuracy: 0.2500 - loss: 10.0049 - q_output_loss: 5.2579 - q_output_sparse_categorical_accuracy: 0.1001 - type_output_accuracy: 0.5591 - type_output_loss: 0.8840 - val_a_output_loss: 3.6835 - val_a_output_sparse_categorical_accuracy: 0.2500 - val_loss: 8.8571 - val_q_output_loss: 4.9562 - val_q_output_sparse_categorical_accuracy: 0.1000 - val_type_output_accuracy: 0.3833 - val_type_output_loss: 0.7245\n",
"Epoch 3/30\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 29ms/step - a_output_loss: 2.4494 - a_output_sparse_categorical_accuracy: 0.3225 - loss: 6.0211 - q_output_loss: 3.3610 - q_output_sparse_categorical_accuracy: 0.2195 - type_output_accuracy: 0.6316 - type_output_loss: 0.6890 - val_a_output_loss: 1.8200 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 5.1733 - val_q_output_loss: 3.1541 - val_q_output_sparse_categorical_accuracy: 0.2136 - val_type_output_accuracy: 0.6167 - val_type_output_loss: 0.6640\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - a_output_loss: 3.4959 - a_output_sparse_categorical_accuracy: 0.2500 - loss: 8.4809 - q_output_loss: 4.7656 - q_output_sparse_categorical_accuracy: 0.1299 - type_output_accuracy: 0.5308 - type_output_loss: 0.7076 - val_a_output_loss: 2.4516 - val_a_output_sparse_categorical_accuracy: 0.2500 - val_loss: 6.9437 - val_q_output_loss: 4.2829 - val_q_output_sparse_categorical_accuracy: 0.1758 - val_type_output_accuracy: 0.3833 - val_type_output_loss: 0.6973\n",
"Epoch 4/30\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 29ms/step - a_output_loss: 2.2031 - a_output_sparse_categorical_accuracy: 0.3726 - loss: 5.4427 - q_output_loss: 3.0396 - q_output_sparse_categorical_accuracy: 0.2869 - type_output_accuracy: 0.5478 - type_output_loss: 0.6786 - val_a_output_loss: 1.5951 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 4.6514 - val_q_output_loss: 2.8462 - val_q_output_sparse_categorical_accuracy: 0.3758 - val_type_output_accuracy: 0.3833 - val_type_output_loss: 0.7003\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - a_output_loss: 2.6025 - a_output_sparse_categorical_accuracy: 0.2620 - loss: 6.9572 - q_output_loss: 4.1438 - q_output_sparse_categorical_accuracy: 0.1763 - type_output_accuracy: 0.5010 - type_output_loss: 0.6949 - val_a_output_loss: 1.9285 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 5.9705 - val_q_output_loss: 3.8322 - val_q_output_sparse_categorical_accuracy: 0.1682 - val_type_output_accuracy: 0.3833 - val_type_output_loss: 0.6996\n",
"Epoch 5/30\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 30ms/step - a_output_loss: 2.0493 - a_output_sparse_categorical_accuracy: 0.3739 - loss: 4.9806 - q_output_loss: 2.7312 - q_output_sparse_categorical_accuracy: 0.3659 - type_output_accuracy: 0.5789 - type_output_loss: 0.6663 - val_a_output_loss: 1.5595 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 4.3330 - val_q_output_loss: 2.5741 - val_q_output_sparse_categorical_accuracy: 0.4197 - val_type_output_accuracy: 0.4833 - val_type_output_loss: 0.6650\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - a_output_loss: 2.3723 - a_output_sparse_categorical_accuracy: 0.3693 - loss: 6.3040 - q_output_loss: 3.7185 - q_output_sparse_categorical_accuracy: 0.1705 - type_output_accuracy: 0.5228 - type_output_loss: 0.6954 - val_a_output_loss: 1.7395 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 5.5526 - val_q_output_loss: 3.6050 - val_q_output_sparse_categorical_accuracy: 0.1818 - val_type_output_accuracy: 0.4333 - val_type_output_loss: 0.6936\n",
"Epoch 6/30\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 28ms/step - a_output_loss: 1.9310 - a_output_sparse_categorical_accuracy: 0.3769 - loss: 4.5682 - q_output_loss: 2.4562 - q_output_sparse_categorical_accuracy: 0.4147 - type_output_accuracy: 0.6767 - type_output_loss: 0.6249 - val_a_output_loss: 1.4318 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 3.9148 - val_q_output_loss: 2.3074 - val_q_output_sparse_categorical_accuracy: 0.4333 - val_type_output_accuracy: 0.8167 - val_type_output_loss: 0.5853\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - a_output_loss: 2.2466 - a_output_sparse_categorical_accuracy: 0.3660 - loss: 5.9522 - q_output_loss: 3.4987 - q_output_sparse_categorical_accuracy: 0.1843 - type_output_accuracy: 0.5311 - type_output_loss: 0.6928 - val_a_output_loss: 1.6680 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 5.3542 - val_q_output_loss: 3.4792 - val_q_output_sparse_categorical_accuracy: 0.2591 - val_type_output_accuracy: 0.6167 - val_type_output_loss: 0.6898\n",
"Epoch 7/30\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 28ms/step - a_output_loss: 1.8407 - a_output_sparse_categorical_accuracy: 0.3722 - loss: 4.1857 - q_output_loss: 2.1853 - q_output_sparse_categorical_accuracy: 0.4229 - type_output_accuracy: 0.7938 - type_output_loss: 0.5382 - val_a_output_loss: 1.3413 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 3.5705 - val_q_output_loss: 2.0979 - val_q_output_sparse_categorical_accuracy: 0.4515 - val_type_output_accuracy: 0.8500 - val_type_output_loss: 0.4377\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - a_output_loss: 2.1706 - a_output_sparse_categorical_accuracy: 0.3748 - loss: 5.7833 - q_output_loss: 3.3972 - q_output_sparse_categorical_accuracy: 0.2581 - type_output_accuracy: 0.4983 - type_output_loss: 0.6943 - val_a_output_loss: 1.6312 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 5.2140 - val_q_output_loss: 3.3728 - val_q_output_sparse_categorical_accuracy: 0.2591 - val_type_output_accuracy: 0.3833 - val_type_output_loss: 0.7001\n",
"Epoch 8/30\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 28ms/step - a_output_loss: 1.8257 - a_output_sparse_categorical_accuracy: 0.3606 - loss: 3.9578 - q_output_loss: 1.9912 - q_output_sparse_categorical_accuracy: 0.4426 - type_output_accuracy: 0.7711 - type_output_loss: 0.4701 - val_a_output_loss: 1.2847 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 3.3498 - val_q_output_loss: 1.9433 - val_q_output_sparse_categorical_accuracy: 0.4712 - val_type_output_accuracy: 0.8500 - val_type_output_loss: 0.4062\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - a_output_loss: 2.0769 - a_output_sparse_categorical_accuracy: 0.3709 - loss: 5.5707 - q_output_loss: 3.2835 - q_output_sparse_categorical_accuracy: 0.2579 - type_output_accuracy: 0.5165 - type_output_loss: 0.6925 - val_a_output_loss: 1.5953 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 5.0738 - val_q_output_loss: 3.2679 - val_q_output_sparse_categorical_accuracy: 0.2591 - val_type_output_accuracy: 0.3833 - val_type_output_loss: 0.7022\n",
"Epoch 9/30\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 29ms/step - a_output_loss: 1.6802 - a_output_sparse_categorical_accuracy: 0.3683 - loss: 3.6282 - q_output_loss: 1.8210 - q_output_sparse_categorical_accuracy: 0.4542 - type_output_accuracy: 0.7996 - type_output_loss: 0.4131 - val_a_output_loss: 1.2484 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 3.1973 - val_q_output_loss: 1.8318 - val_q_output_sparse_categorical_accuracy: 0.4667 - val_type_output_accuracy: 0.8500 - val_type_output_loss: 0.3905\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - a_output_loss: 2.0811 - a_output_sparse_categorical_accuracy: 0.3700 - loss: 5.4626 - q_output_loss: 3.1783 - q_output_sparse_categorical_accuracy: 0.2586 - type_output_accuracy: 0.5294 - type_output_loss: 0.6897 - val_a_output_loss: 1.5771 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 4.9584 - val_q_output_loss: 3.1740 - val_q_output_sparse_categorical_accuracy: 0.2591 - val_type_output_accuracy: 0.5667 - val_type_output_loss: 0.6911\n",
"Epoch 10/30\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 29ms/step - a_output_loss: 1.6225 - a_output_sparse_categorical_accuracy: 0.3738 - loss: 3.4468 - q_output_loss: 1.7095 - q_output_sparse_categorical_accuracy: 0.4516 - type_output_accuracy: 0.8104 - type_output_loss: 0.3915 - val_a_output_loss: 1.2075 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 3.0751 - val_q_output_loss: 1.7371 - val_q_output_sparse_categorical_accuracy: 0.4758 - val_type_output_accuracy: 0.8167 - val_type_output_loss: 0.4349\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - a_output_loss: 2.1436 - a_output_sparse_categorical_accuracy: 0.3635 - loss: 5.4602 - q_output_loss: 3.0972 - q_output_sparse_categorical_accuracy: 0.2603 - type_output_accuracy: 0.5954 - type_output_loss: 0.6906 - val_a_output_loss: 1.5554 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 4.8437 - val_q_output_loss: 3.0796 - val_q_output_sparse_categorical_accuracy: 0.2742 - val_type_output_accuracy: 0.4333 - val_type_output_loss: 0.6955\n",
"Epoch 11/30\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 28ms/step - a_output_loss: 1.5862 - a_output_sparse_categorical_accuracy: 0.3721 - loss: 3.3171 - q_output_loss: 1.6153 - q_output_sparse_categorical_accuracy: 0.4538 - type_output_accuracy: 0.8102 - type_output_loss: 0.3785 - val_a_output_loss: 1.1730 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 2.9582 - val_q_output_loss: 1.6659 - val_q_output_sparse_categorical_accuracy: 0.4697 - val_type_output_accuracy: 0.8500 - val_type_output_loss: 0.3979\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - a_output_loss: 2.0414 - a_output_sparse_categorical_accuracy: 0.3698 - loss: 5.2346 - q_output_loss: 2.9916 - q_output_sparse_categorical_accuracy: 0.2940 - type_output_accuracy: 0.5497 - type_output_loss: 0.6878 - val_a_output_loss: 1.5271 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 4.7076 - val_q_output_loss: 2.9767 - val_q_output_sparse_categorical_accuracy: 0.3788 - val_type_output_accuracy: 0.6167 - val_type_output_loss: 0.6791\n",
"Epoch 12/30\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 29ms/step - a_output_loss: 1.5335 - a_output_sparse_categorical_accuracy: 0.3664 - loss: 3.1939 - q_output_loss: 1.5335 - q_output_sparse_categorical_accuracy: 0.4518 - type_output_accuracy: 0.7775 - type_output_loss: 0.4105 - val_a_output_loss: 1.1304 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 2.8616 - val_q_output_loss: 1.6143 - val_q_output_sparse_categorical_accuracy: 0.4727 - val_type_output_accuracy: 0.8500 - val_type_output_loss: 0.3897\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - a_output_loss: 1.9689 - a_output_sparse_categorical_accuracy: 0.3747 - loss: 5.0683 - q_output_loss: 2.8947 - q_output_sparse_categorical_accuracy: 0.3778 - type_output_accuracy: 0.5652 - type_output_loss: 0.6839 - val_a_output_loss: 1.5182 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 4.6171 - val_q_output_loss: 2.8857 - val_q_output_sparse_categorical_accuracy: 0.3985 - val_type_output_accuracy: 0.3833 - val_type_output_loss: 0.7105\n",
"Epoch 13/30\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 28ms/step - a_output_loss: 1.4389 - a_output_sparse_categorical_accuracy: 0.3754 - loss: 3.0570 - q_output_loss: 1.5013 - q_output_sparse_categorical_accuracy: 0.4627 - type_output_accuracy: 0.8116 - type_output_loss: 0.3873 - val_a_output_loss: 1.1223 - val_a_output_sparse_categorical_accuracy: 0.4083 - val_loss: 2.8335 - val_q_output_loss: 1.5881 - val_q_output_sparse_categorical_accuracy: 0.4515 - val_type_output_accuracy: 0.8333 - val_type_output_loss: 0.4107\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - a_output_loss: 1.9573 - a_output_sparse_categorical_accuracy: 0.3675 - loss: 4.9542 - q_output_loss: 2.7911 - q_output_sparse_categorical_accuracy: 0.3790 - type_output_accuracy: 0.5301 - type_output_loss: 0.6814 - val_a_output_loss: 1.5033 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 4.5030 - val_q_output_loss: 2.7936 - val_q_output_sparse_categorical_accuracy: 0.3985 - val_type_output_accuracy: 0.5333 - val_type_output_loss: 0.6871\n",
"Epoch 14/30\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 30ms/step - a_output_loss: 1.4312 - a_output_sparse_categorical_accuracy: 0.3714 - loss: 2.9962 - q_output_loss: 1.4529 - q_output_sparse_categorical_accuracy: 0.4608 - type_output_accuracy: 0.7943 - type_output_loss: 0.3844 - val_a_output_loss: 1.1235 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 2.8020 - val_q_output_loss: 1.5603 - val_q_output_sparse_categorical_accuracy: 0.4652 - val_type_output_accuracy: 0.8500 - val_type_output_loss: 0.3939\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - a_output_loss: 1.9229 - a_output_sparse_categorical_accuracy: 0.3745 - loss: 4.8469 - q_output_loss: 2.7145 - q_output_sparse_categorical_accuracy: 0.3810 - type_output_accuracy: 0.5883 - type_output_loss: 0.6817 - val_a_output_loss: 1.4761 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 4.3857 - val_q_output_loss: 2.7074 - val_q_output_sparse_categorical_accuracy: 0.3985 - val_type_output_accuracy: 0.6833 - val_type_output_loss: 0.6741\n",
"Epoch 15/30\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 28ms/step - a_output_loss: 1.3670 - a_output_sparse_categorical_accuracy: 0.3787 - loss: 2.9339 - q_output_loss: 1.4391 - q_output_sparse_categorical_accuracy: 0.4609 - type_output_accuracy: 0.7846 - type_output_loss: 0.4245 - val_a_output_loss: 1.1046 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 2.7626 - val_q_output_loss: 1.5316 - val_q_output_sparse_categorical_accuracy: 0.4697 - val_type_output_accuracy: 0.8000 - val_type_output_loss: 0.4212\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - a_output_loss: 1.9625 - a_output_sparse_categorical_accuracy: 0.3703 - loss: 4.7880 - q_output_loss: 2.6232 - q_output_sparse_categorical_accuracy: 0.3790 - type_output_accuracy: 0.6521 - type_output_loss: 0.6739 - val_a_output_loss: 1.4591 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 4.2924 - val_q_output_loss: 2.6260 - val_q_output_sparse_categorical_accuracy: 0.3985 - val_type_output_accuracy: 0.5167 - val_type_output_loss: 0.6910\n",
"Epoch 16/30\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 28ms/step - a_output_loss: 1.3007 - a_output_sparse_categorical_accuracy: 0.3830 - loss: 2.8316 - q_output_loss: 1.4239 - q_output_sparse_categorical_accuracy: 0.4647 - type_output_accuracy: 0.8005 - type_output_loss: 0.3726 - val_a_output_loss: 1.1026 - val_a_output_sparse_categorical_accuracy: 0.4083 - val_loss: 2.7498 - val_q_output_loss: 1.5221 - val_q_output_sparse_categorical_accuracy: 0.4652 - val_type_output_accuracy: 0.8000 - val_type_output_loss: 0.4171\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - a_output_loss: 1.9763 - a_output_sparse_categorical_accuracy: 0.3661 - loss: 4.7331 - q_output_loss: 2.5511 - q_output_sparse_categorical_accuracy: 0.3784 - type_output_accuracy: 0.6424 - type_output_loss: 0.6679 - val_a_output_loss: 1.4432 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 4.1898 - val_q_output_loss: 2.5484 - val_q_output_sparse_categorical_accuracy: 0.4061 - val_type_output_accuracy: 0.6500 - val_type_output_loss: 0.6605\n",
"Epoch 17/30\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 31ms/step - a_output_loss: 1.3868 - a_output_sparse_categorical_accuracy: 0.3768 - loss: 2.8855 - q_output_loss: 1.3850 - q_output_sparse_categorical_accuracy: 0.4683 - type_output_accuracy: 0.8190 - type_output_loss: 0.3552 - val_a_output_loss: 1.0983 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 2.7381 - val_q_output_loss: 1.5079 - val_q_output_sparse_categorical_accuracy: 0.4636 - val_type_output_accuracy: 0.8000 - val_type_output_loss: 0.4401\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - a_output_loss: 1.8243 - a_output_sparse_categorical_accuracy: 0.3712 - loss: 4.4878 - q_output_loss: 2.4615 - q_output_sparse_categorical_accuracy: 0.3942 - type_output_accuracy: 0.6414 - type_output_loss: 0.6622 - val_a_output_loss: 1.4167 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 4.0918 - val_q_output_loss: 2.4764 - val_q_output_sparse_categorical_accuracy: 0.4061 - val_type_output_accuracy: 0.6667 - val_type_output_loss: 0.6624\n",
"Epoch 18/30\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 28ms/step - a_output_loss: 1.2897 - a_output_sparse_categorical_accuracy: 0.3837 - loss: 2.7760 - q_output_loss: 1.3759 - q_output_sparse_categorical_accuracy: 0.4749 - type_output_accuracy: 0.8087 - type_output_loss: 0.3802 - val_a_output_loss: 1.1044 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 2.7490 - val_q_output_loss: 1.5155 - val_q_output_sparse_categorical_accuracy: 0.4636 - val_type_output_accuracy: 0.8000 - val_type_output_loss: 0.4305\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 16ms/step - a_output_loss: 1.7940 - a_output_sparse_categorical_accuracy: 0.3750 - loss: 4.3987 - q_output_loss: 2.4076 - q_output_sparse_categorical_accuracy: 0.3949 - type_output_accuracy: 0.6488 - type_output_loss: 0.6552 - val_a_output_loss: 1.4059 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 4.0153 - val_q_output_loss: 2.4094 - val_q_output_sparse_categorical_accuracy: 0.4197 - val_type_output_accuracy: 0.6167 - val_type_output_loss: 0.6666\n",
"Epoch 19/30\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 30ms/step - a_output_loss: 1.3054 - a_output_sparse_categorical_accuracy: 0.3802 - loss: 2.7626 - q_output_loss: 1.3517 - q_output_sparse_categorical_accuracy: 0.4699 - type_output_accuracy: 0.8228 - type_output_loss: 0.3645 - val_a_output_loss: 1.0990 - val_a_output_sparse_categorical_accuracy: 0.4125 - val_loss: 2.7261 - val_q_output_loss: 1.4992 - val_q_output_sparse_categorical_accuracy: 0.4667 - val_type_output_accuracy: 0.8000 - val_type_output_loss: 0.4264\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - a_output_loss: 1.8070 - a_output_sparse_categorical_accuracy: 0.3722 - loss: 4.3404 - q_output_loss: 2.3366 - q_output_sparse_categorical_accuracy: 0.4081 - type_output_accuracy: 0.6783 - type_output_loss: 0.6439 - val_a_output_loss: 1.3923 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 3.9348 - val_q_output_loss: 2.3484 - val_q_output_sparse_categorical_accuracy: 0.4197 - val_type_output_accuracy: 0.6667 - val_type_output_loss: 0.6468\n",
"Epoch 20/30\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 28ms/step - a_output_loss: 1.2454 - a_output_sparse_categorical_accuracy: 0.3856 - loss: 2.7195 - q_output_loss: 1.3633 - q_output_sparse_categorical_accuracy: 0.4751 - type_output_accuracy: 0.8284 - type_output_loss: 0.3665 - val_a_output_loss: 1.1154 - val_a_output_sparse_categorical_accuracy: 0.4125 - val_loss: 2.7285 - val_q_output_loss: 1.4769 - val_q_output_sparse_categorical_accuracy: 0.4652 - val_type_output_accuracy: 0.7833 - val_type_output_loss: 0.4540\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - a_output_loss: 1.7701 - a_output_sparse_categorical_accuracy: 0.3707 - loss: 4.2359 - q_output_loss: 2.2700 - q_output_sparse_categorical_accuracy: 0.4075 - type_output_accuracy: 0.7000 - type_output_loss: 0.6303 - val_a_output_loss: 1.3730 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 3.8547 - val_q_output_loss: 2.2902 - val_q_output_sparse_categorical_accuracy: 0.4197 - val_type_output_accuracy: 0.6667 - val_type_output_loss: 0.6385\n",
"Epoch 21/30\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 29ms/step - a_output_loss: 1.2583 - a_output_sparse_categorical_accuracy: 0.3896 - loss: 2.7032 - q_output_loss: 1.3383 - q_output_sparse_categorical_accuracy: 0.4719 - type_output_accuracy: 0.8246 - type_output_loss: 0.3594 - val_a_output_loss: 1.1381 - val_a_output_sparse_categorical_accuracy: 0.4125 - val_loss: 2.8064 - val_q_output_loss: 1.4643 - val_q_output_sparse_categorical_accuracy: 0.4788 - val_type_output_accuracy: 0.6500 - val_type_output_loss: 0.6802\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - a_output_loss: 1.7757 - a_output_sparse_categorical_accuracy: 0.3722 - loss: 4.1755 - q_output_loss: 2.2117 - q_output_sparse_categorical_accuracy: 0.4081 - type_output_accuracy: 0.6587 - type_output_loss: 0.6294 - val_a_output_loss: 1.3681 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 3.7954 - val_q_output_loss: 2.2381 - val_q_output_sparse_categorical_accuracy: 0.4212 - val_type_output_accuracy: 0.6667 - val_type_output_loss: 0.6310\n",
"Epoch 22/30\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 28ms/step - a_output_loss: 1.3214 - a_output_sparse_categorical_accuracy: 0.3743 - loss: 2.7428 - q_output_loss: 1.3119 - q_output_sparse_categorical_accuracy: 0.4669 - type_output_accuracy: 0.8213 - type_output_loss: 0.3563 - val_a_output_loss: 1.1002 - val_a_output_sparse_categorical_accuracy: 0.4125 - val_loss: 2.7029 - val_q_output_loss: 1.4733 - val_q_output_sparse_categorical_accuracy: 0.4727 - val_type_output_accuracy: 0.8333 - val_type_output_loss: 0.4315\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - a_output_loss: 1.7072 - a_output_sparse_categorical_accuracy: 0.3720 - loss: 4.0450 - q_output_loss: 2.1546 - q_output_sparse_categorical_accuracy: 0.4168 - type_output_accuracy: 0.6542 - type_output_loss: 0.6179 - val_a_output_loss: 1.3568 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 3.7286 - val_q_output_loss: 2.1905 - val_q_output_sparse_categorical_accuracy: 0.4333 - val_type_output_accuracy: 0.7167 - val_type_output_loss: 0.6042\n",
"Epoch 23/30\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 29ms/step - a_output_loss: 1.2769 - a_output_sparse_categorical_accuracy: 0.3907 - loss: 2.7114 - q_output_loss: 1.3188 - q_output_sparse_categorical_accuracy: 0.4772 - type_output_accuracy: 0.8281 - type_output_loss: 0.3824 - val_a_output_loss: 1.1039 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 2.7224 - val_q_output_loss: 1.4924 - val_q_output_sparse_categorical_accuracy: 0.4727 - val_type_output_accuracy: 0.8167 - val_type_output_loss: 0.4205\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - a_output_loss: 1.7266 - a_output_sparse_categorical_accuracy: 0.3729 - loss: 4.0193 - q_output_loss: 2.1160 - q_output_sparse_categorical_accuracy: 0.4216 - type_output_accuracy: 0.6948 - type_output_loss: 0.5943 - val_a_output_loss: 1.3405 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 3.6712 - val_q_output_loss: 2.1447 - val_q_output_sparse_categorical_accuracy: 0.4333 - val_type_output_accuracy: 0.6667 - val_type_output_loss: 0.6197\n",
"Epoch 24/30\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 29ms/step - a_output_loss: 1.1957 - a_output_sparse_categorical_accuracy: 0.3969 - loss: 2.6335 - q_output_loss: 1.3299 - q_output_sparse_categorical_accuracy: 0.4776 - type_output_accuracy: 0.8342 - type_output_loss: 0.3683 - val_a_output_loss: 1.0974 - val_a_output_sparse_categorical_accuracy: 0.4167 - val_loss: 2.7003 - val_q_output_loss: 1.4677 - val_q_output_sparse_categorical_accuracy: 0.4667 - val_type_output_accuracy: 0.8000 - val_type_output_loss: 0.4504\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - a_output_loss: 1.6801 - a_output_sparse_categorical_accuracy: 0.3773 - loss: 3.9227 - q_output_loss: 2.0804 - q_output_sparse_categorical_accuracy: 0.4252 - type_output_accuracy: 0.7457 - type_output_loss: 0.5462 - val_a_output_loss: 1.3245 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 3.6140 - val_q_output_loss: 2.1041 - val_q_output_sparse_categorical_accuracy: 0.4333 - val_type_output_accuracy: 0.6667 - val_type_output_loss: 0.6180\n",
"Epoch 25/30\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 28ms/step - a_output_loss: 1.2393 - a_output_sparse_categorical_accuracy: 0.3898 - loss: 2.6442 - q_output_loss: 1.2937 - q_output_sparse_categorical_accuracy: 0.4804 - type_output_accuracy: 0.8273 - type_output_loss: 0.3564 - val_a_output_loss: 1.1525 - val_a_output_sparse_categorical_accuracy: 0.4125 - val_loss: 2.7743 - val_q_output_loss: 1.4611 - val_q_output_sparse_categorical_accuracy: 0.4682 - val_type_output_accuracy: 0.7833 - val_type_output_loss: 0.5356\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - a_output_loss: 1.6800 - a_output_sparse_categorical_accuracy: 0.3685 - loss: 3.8625 - q_output_loss: 2.0179 - q_output_sparse_categorical_accuracy: 0.4229 - type_output_accuracy: 0.7322 - type_output_loss: 0.5264 - val_a_output_loss: 1.2944 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 3.5249 - val_q_output_loss: 2.0656 - val_q_output_sparse_categorical_accuracy: 0.4333 - val_type_output_accuracy: 0.7000 - val_type_output_loss: 0.5497\n",
"Epoch 26/30\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 29ms/step - a_output_loss: 1.1872 - a_output_sparse_categorical_accuracy: 0.4001 - loss: 2.5917 - q_output_loss: 1.2930 - q_output_sparse_categorical_accuracy: 0.4802 - type_output_accuracy: 0.8280 - type_output_loss: 0.3555 - val_a_output_loss: 1.1505 - val_a_output_sparse_categorical_accuracy: 0.4083 - val_loss: 2.7684 - val_q_output_loss: 1.4587 - val_q_output_sparse_categorical_accuracy: 0.4742 - val_type_output_accuracy: 0.8000 - val_type_output_loss: 0.5307\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - a_output_loss: 1.5786 - a_output_sparse_categorical_accuracy: 0.3724 - loss: 3.7165 - q_output_loss: 1.9876 - q_output_sparse_categorical_accuracy: 0.4253 - type_output_accuracy: 0.7853 - type_output_loss: 0.4956 - val_a_output_loss: 1.2680 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 3.4525 - val_q_output_loss: 2.0282 - val_q_output_sparse_categorical_accuracy: 0.4333 - val_type_output_accuracy: 0.7167 - val_type_output_loss: 0.5212\n",
"Epoch 27/30\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 28ms/step - a_output_loss: 1.1337 - a_output_sparse_categorical_accuracy: 0.4077 - loss: 2.5274 - q_output_loss: 1.2917 - q_output_sparse_categorical_accuracy: 0.4856 - type_output_accuracy: 0.8328 - type_output_loss: 0.3424 - val_a_output_loss: 1.1274 - val_a_output_sparse_categorical_accuracy: 0.4208 - val_loss: 2.7139 - val_q_output_loss: 1.4500 - val_q_output_sparse_categorical_accuracy: 0.4788 - val_type_output_accuracy: 0.8000 - val_type_output_loss: 0.4548\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - a_output_loss: 1.6078 - a_output_sparse_categorical_accuracy: 0.3703 - loss: 3.7078 - q_output_loss: 1.9502 - q_output_sparse_categorical_accuracy: 0.4227 - type_output_accuracy: 0.7938 - type_output_loss: 0.4730 - val_a_output_loss: 1.2467 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 3.3875 - val_q_output_loss: 1.9954 - val_q_output_sparse_categorical_accuracy: 0.4364 - val_type_output_accuracy: 0.7833 - val_type_output_loss: 0.4848\n",
"Epoch 28/30\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 30ms/step - a_output_loss: 1.1081 - a_output_sparse_categorical_accuracy: 0.4104 - loss: 2.4859 - q_output_loss: 1.2903 - q_output_sparse_categorical_accuracy: 0.4845 - type_output_accuracy: 0.8714 - type_output_loss: 0.2936 - val_a_output_loss: 1.1394 - val_a_output_sparse_categorical_accuracy: 0.4167 - val_loss: 2.7244 - val_q_output_loss: 1.4512 - val_q_output_sparse_categorical_accuracy: 0.4652 - val_type_output_accuracy: 0.8167 - val_type_output_loss: 0.4457\n"
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 14ms/step - a_output_loss: 1.5627 - a_output_sparse_categorical_accuracy: 0.3749 - loss: 3.6136 - q_output_loss: 1.9109 - q_output_sparse_categorical_accuracy: 0.4312 - type_output_accuracy: 0.7908 - type_output_loss: 0.4706 - val_a_output_loss: 1.2410 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 3.3625 - val_q_output_loss: 1.9613 - val_q_output_sparse_categorical_accuracy: 0.4409 - val_type_output_accuracy: 0.7167 - val_type_output_loss: 0.5339\n",
"Epoch 29/30\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - a_output_loss: 1.6156 - a_output_sparse_categorical_accuracy: 0.3713 - loss: 3.6313 - q_output_loss: 1.8821 - q_output_sparse_categorical_accuracy: 0.4360 - type_output_accuracy: 0.7989 - type_output_loss: 0.4600 - val_a_output_loss: 1.2161 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 3.2887 - val_q_output_loss: 1.9327 - val_q_output_sparse_categorical_accuracy: 0.4424 - val_type_output_accuracy: 0.7833 - val_type_output_loss: 0.4663\n",
"Epoch 30/30\n",
"\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - a_output_loss: 1.5910 - a_output_sparse_categorical_accuracy: 0.3707 - loss: 3.5723 - q_output_loss: 1.8495 - q_output_sparse_categorical_accuracy: 0.4395 - type_output_accuracy: 0.8028 - type_output_loss: 0.4485 - val_a_output_loss: 1.2147 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 3.2756 - val_q_output_loss: 1.9057 - val_q_output_sparse_categorical_accuracy: 0.4455 - val_type_output_accuracy: 0.7000 - val_type_output_loss: 0.5174\n"
]
}
],
@ -529,7 +533,7 @@
},
{
"cell_type": "code",
"execution_count": 37,
"execution_count": 106,
"id": "3355c0c7",
"metadata": {},
"outputs": [],
@ -589,8 +593,8 @@
"dec_q_inp = Input(shape=(1,), name=\"dec_q_in\")\n",
"dec_emb_q = model.get_layer(\"embedding_q_decoder\").call(dec_q_inp)\n",
"\n",
"state_h_dec = Input(shape=(256,), name=\"state_h_dec\")\n",
"state_c_dec = Input(shape=(256,), name=\"state_c_dec\")\n",
"state_h_dec = Input(shape=(units,), name=\"state_h_dec\")\n",
"state_c_dec = Input(shape=(units,), name=\"state_c_dec\")\n",
"\n",
"lstm_decoder_q = model.get_layer(\"lstm_q_decoder\")\n",
"\n",
@ -612,8 +616,8 @@
"dec_a_inp = Input(shape=(1,), name=\"dec_a_in\")\n",
"dec_emb_a = model.get_layer(\"embedding_a_decoder\").call(dec_a_inp)\n",
"\n",
"state_h_a = Input(shape=(256,), name=\"state_h_a\")\n",
"state_c_a = Input(shape=(256,), name=\"state_c_a\")\n",
"state_h_a = Input(shape=(units,), name=\"state_h_a\")\n",
"state_c_a = Input(shape=(units,), name=\"state_c_a\")\n",
"\n",
"lstm_decoder_a = model.get_layer(\"lstm_a_decoder\")\n",
"\n",
@ -642,7 +646,7 @@
},
{
"cell_type": "code",
"execution_count": 38,
"execution_count": 107,
"id": "d406e6ff",
"metadata": {},
"outputs": [
@ -650,7 +654,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Generated Question: gilang maulana lahir di semarang pada 26 november 1983 ___\n",
"Generated Question: dimana dimana lahir ___\n",
"Generated Answer : true\n",
"Question Type : true_false\n"
]
@ -769,7 +773,7 @@
},
{
"cell_type": "code",
"execution_count": 39,
"execution_count": 108,
"id": "5adde3c3",
"metadata": {},
"outputs": [
@ -777,8 +781,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
"BLEU : 0.1185\n",
"ROUGE1: 0.3967 | ROUGE-L: 0.3967\n"
"BLEU : 0.0447\n",
"ROUGE1: 0.2281 | ROUGE-L: 0.2281\n"
]
}
],

389
question_generation/uji.py Normal file
View File

@ -0,0 +1,389 @@
# ===============================================================
# Seq2Seq LSTM + Luong Attention for a Question-Answer Generator
# + Greedy & Beam Search decoding + BLEU-4 evaluation
# ===============================================================
# • All embeddings use mask_zero=True (padding is masked)
# • Encoder = Bidirectional LSTM (return_sequences=True)
# • Decoder = LSTM + Luong Attention (keras.layers.Attention).
# • Greedy & beam-search inference sub-models are built separately
#   (encoder, decoder Q-step, decoder A-step).
# • BLEU score (nltk.corpus_bleu) for evaluating questions & answers.
# ---------------------------------------------------------------
# INSTRUCTIONS
# 1. pip install nltk
# 2. python seq2seq_qa_attention.py   # train + save the model
# 3. call evaluate_bleu()             # compute BLEU on the validation/test split
# ===============================================================
import json
from pathlib import Path
from itertools import chain
import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.layers import (
Input, Embedding, LSTM, Bidirectional, Concatenate,
Dense, TimeDistributed, Attention
)
from tensorflow.keras.models import Model
from nltk.translate.bleu_score import corpus_bleu # pip install nltk
# ----------------------- 1. Load & flatten data ----------------------------
RAW = json.loads(Path("../dataset/dev_dataset_test.json").read_text())
samples = []
for item in RAW:
for qp in item["quiz_posibility"]:
samp = {
"tokens": [t.lower() for t in item["tokens"]],
"ner": item["ner"],
"srl": item["srl"],
"q_type": qp["type"],
"q_toks": [t.lower() for t in qp["question"]] + ["<eos>"],
}
if isinstance(qp["answer"], list):
samp["a_toks"] = [t.lower() for t in qp["answer"]] + ["<eos>"]
else:
samp["a_toks"] = [qp["answer"].lower(), "<eos>"]
samples.append(samp)
print("Total flattened samples:", len(samples))
# ----------------------- 2. Build vocabularies -----------------------------
def build_vocab(seq_iter, reserved=("<pad>", "<unk>", "<sos>", "<eos>")):
vocab = {tok: idx for idx, tok in enumerate(reserved)}
for tok in chain.from_iterable(seq_iter):
if tok not in vocab:
vocab[tok] = len(vocab)
return vocab
v_tok = build_vocab((s["tokens"] for s in samples))
v_ner = build_vocab((s["ner"] for s in samples), reserved=("<pad>", "<unk>"))
v_srl = build_vocab((s["srl"] for s in samples), reserved=("<pad>", "<unk>"))
v_q = build_vocab((s["q_toks"] for s in samples))
v_a = build_vocab((s["a_toks"] for s in samples))
v_typ = {"isian": 0, "opsi": 1, "true_false": 2}
iv_q = {i: t for t, i in v_q.items()}
iv_a = {i: t for t, i in v_a.items()}
# ----------------------- 3. Vectorise + pad -------------------------------
def encode(seq, vmap):
return [vmap.get(tok, vmap["<unk>"]) for tok in seq]
MAX_SENT = max(len(s["tokens"]) for s in samples)
MAX_Q = max(len(s["q_toks"]) for s in samples)
MAX_A = max(len(s["a_toks"]) for s in samples)
X_tok_ids = pad_sequences([encode(s["tokens"], v_tok) for s in samples],
maxlen=MAX_SENT, padding="post")
X_ner_ids = pad_sequences([encode(s["ner"], v_ner) for s in samples],
maxlen=MAX_SENT, padding="post")
X_srl_ids = pad_sequences([encode(s["srl"], v_srl) for s in samples],
maxlen=MAX_SENT, padding="post")
q_in_ids = pad_sequences([[v_q["<sos>"], *encode(s["q_toks"][:-1], v_q)]
for s in samples], maxlen=MAX_Q, padding="post")
q_out_ids = pad_sequences([encode(s["q_toks"], v_q) for s in samples],
maxlen=MAX_Q, padding="post")
a_in_ids = pad_sequences([[v_a["<sos>"], *encode(s["a_toks"][:-1], v_a)]
for s in samples], maxlen=MAX_A, padding="post")
a_out_ids = pad_sequences([encode(s["a_toks"], v_a) for s in samples],
maxlen=MAX_A, padding="post")
y_type_ids = np.array([v_typ[s["q_type"]] for s in samples])
# ----------------------- 4. Hyperparams ----------------------------------
d_tok = 32 # token embedding dim
d_tag = 16 # NER / SRL embedding dim
units = 64 # per direction of BiLSTM
lat_dim = units * 2
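# Note: lat_dim is 2 * units so the decoder LSTMs match the width of the
# concatenated forward/backward encoder states used as their initial states.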
# ----------------------- 5. Build model -----------------------------------
# Encoder ----------------------------------------------------------
tok_in = Input((MAX_SENT,), dtype="int32", name="tok_in")
ner_in = Input((MAX_SENT,), dtype="int32", name="ner_in")
srl_in = Input((MAX_SENT,), dtype="int32", name="srl_in")
emb_tok = Embedding(len(v_tok), d_tok, mask_zero=True, name="emb_tok")(tok_in)
emb_ner = Embedding(len(v_ner), d_tag, mask_zero=True, name="emb_ner")(ner_in)
emb_srl = Embedding(len(v_srl), d_tag, mask_zero=True, name="emb_srl")(srl_in)
enc_concat = Concatenate(name="enc_concat")([emb_tok, emb_ner, emb_srl])
bi_lstm = Bidirectional(LSTM(units, return_sequences=True, return_state=True),
name="encoder_bi_lstm")
enc_seq, f_h, f_c, b_h, b_c = bi_lstm(enc_concat)
enc_h = Concatenate()([f_h, b_h])  # (B, lat_dim)
enc_c = Concatenate()([f_c, b_c])
# Decoder - QUESTION ----------------------------------------------
q_in = Input((MAX_Q,), dtype="int32", name="q_in")
# 💡 mask_zero=False so the Attention layer does not clash with the encoder mask
q_emb = Embedding(len(v_q), d_tok, mask_zero=False, name="q_emb")(q_in)
dec_q_lstm = LSTM(lat_dim, return_sequences=True, return_state=True,
name="decoder_q_lstm")
q_seq, q_h, q_c = dec_q_lstm(q_emb, initial_state=[enc_h, enc_c])
enc_proj_q = TimeDistributed(Dense(lat_dim), name="enc_proj_q")(enc_seq)
attn_q = Attention(name="attn_q")([q_seq, enc_proj_q])
q_concat = Concatenate(name="q_concat")([q_seq, attn_q])
q_out = TimeDistributed(Dense(len(v_q), activation="softmax"), name="q_out")(q_concat)
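# Note on the attention wiring: keras.layers.Attention computes dot-product
# (Luong-style) attention with q_seq as queries and the projected encoder
# sequence as keys/values; the context is concatenated with the decoder
# outputs before the per-timestep softmax.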
# Decoder - ANSWER -------------------------------------------------
a_in = Input((MAX_A,), dtype="int32", name="a_in")
# also mask_zero=False, for the same reason as in the question decoder
a_emb = Embedding(len(v_a), d_tok, mask_zero=False, name="a_emb")(a_in)
dec_a_lstm = LSTM(lat_dim, return_sequences=True, return_state=True,
name="decoder_a_lstm")
a_seq, _, _ = dec_a_lstm(a_emb, initial_state=[q_h, q_c])
enc_proj_a = TimeDistributed(Dense(lat_dim), name="enc_proj_a")(enc_seq)
attn_a = Attention(name="attn_a")([a_seq, enc_proj_a])
a_concat = Concatenate(name="a_concat")([a_seq, attn_a])
a_out = TimeDistributed(Dense(len(v_a), activation="softmax"), name="a_out")(a_concat)
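# Note: the answer decoder is initialised from the question decoder's final
# states (q_h, q_c), so the generated answer is conditioned on both the encoded
# sentence and the decoded question.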
# Classifier -------------------------------------------------------
type_dense = Dense(len(v_typ), activation="softmax", name="type_out")(enc_h)
model = Model(inputs=[tok_in, ner_in, srl_in, q_in, a_in],
outputs=[q_out, a_out, type_dense])
model.summary()
# ----------------------- 6. Compile & train ------------------------------
losses = {
"q_out": "sparse_categorical_crossentropy",
"a_out": "sparse_categorical_crossentropy",
"type_out": "sparse_categorical_crossentropy",
}
loss_weights = {"q_out": 1.0, "a_out": 1.0, "type_out": 0.3}
model.compile(optimizer="adam", loss=losses, loss_weights=loss_weights,
metrics={"q_out": "sparse_categorical_accuracy",
"a_out": "sparse_categorical_accuracy",
"type_out": "accuracy"})
history = model.fit(
[X_tok_ids, X_ner_ids, X_srl_ids, q_in_ids, a_in_ids],
[q_out_ids, a_out_ids, y_type_ids],
validation_split=0.1,
epochs=30,
batch_size=64,
callbacks=[tf.keras.callbacks.EarlyStopping(patience=4, restore_best_weights=True)],
verbose=1,
)
model.save("seq2seq_attn.keras")
print("Model saved to seq2seq_attn.keras")
# ----------------------- 7. Inference submodels --------------------------
# Encoder model
encoder_model = Model([tok_in, ner_in, srl_in], [enc_seq, enc_h, enc_c])
# Question decoder step model ------------------------------------------------
# Inputs
q_token_in = Input((1,), dtype="int32", name="q_token_in")
enc_seq_in = Input((MAX_SENT, lat_dim), name="enc_seq_in")
enc_proj_q_in = Input((MAX_SENT, lat_dim), name="enc_proj_q_in")
state_h_in = Input((lat_dim,), name="state_h_in")
state_c_in = Input((lat_dim,), name="state_c_in")
# Embedding
q_emb_step = model.get_layer("q_emb")(q_token_in)
# LSTM (reuse weights)
q_lstm_step, h_out, c_out = model.get_layer("decoder_q_lstm")(q_emb_step,
initial_state=[state_h_in, state_c_in])
# Attention
attn_step = model.get_layer("attn_q")([q_lstm_step, enc_proj_q_in])
q_concat_step = Concatenate()([q_lstm_step, attn_step])
q_logits_step = model.get_layer("q_out")(q_concat_step)
decoder_q_step = Model([q_token_in, enc_seq_in, enc_proj_q_in, state_h_in, state_c_in],
[q_logits_step, h_out, c_out])
# Answer decoder step model --------------------------------------------------
a_token_in = Input((1,), dtype="int32", name="a_token_in")
enc_proj_a_in = Input((MAX_SENT, lat_dim), name="enc_proj_a_in")
state_h_a_in = Input((lat_dim,), name="state_h_a_in")
state_c_a_in = Input((lat_dim,), name="state_c_a_in")
# Embedding reuse
a_emb_step = model.get_layer("a_emb")(a_token_in)
# LSTM reuse
a_lstm_step, h_a_out, c_a_out = model.get_layer("decoder_a_lstm")(a_emb_step,
initial_state=[state_h_a_in, state_c_a_in])
# Attention reuse
attn_a_step = model.get_layer("attn_a")([a_lstm_step, enc_proj_a_in])
a_concat_step = Concatenate()([a_lstm_step, attn_a_step])
a_logits_step = model.get_layer("a_out")(a_concat_step)
decoder_a_step = Model([a_token_in, enc_proj_a_in, state_h_a_in, state_c_a_in],
[a_logits_step, h_a_out, c_a_out])
# ----------------------- 8. Decoding helpers ------------------------------
def encode_and_pad(seq, vmap, max_len):
ids = encode(seq, vmap)
return ids + [vmap["<pad>"]] * (max_len - len(ids))
def greedy_decode(tokens, ner, srl, max_q=20, max_a=10):
"""Return generated (question_tokens, answer_tokens, q_type_str)"""
# --- encoder ---------------------------------------------------------
enc_tok = np.array([encode_and_pad(tokens, v_tok, MAX_SENT)])
enc_ner = np.array([encode_and_pad(ner, v_ner, MAX_SENT)])
enc_srl = np.array([encode_and_pad(srl, v_srl, MAX_SENT)])
enc_seq_val, h, c = encoder_model.predict([enc_tok, enc_ner, enc_srl], verbose=0)
enc_proj_q_val = model.get_layer("enc_proj_q")(enc_seq_val)
enc_proj_a_val = model.get_layer("enc_proj_a")(enc_seq_val)
# --- greedy Question --------------------------------------------------
q_ids = []
tgt = np.array([[v_q["<sos>"]]])
for _ in range(max_q):
logits, h, c = decoder_q_step.predict([tgt, enc_seq_val, enc_proj_q_val, h, c], verbose=0)
next_id = int(logits[0, 0].argmax())
if next_id == v_q["<eos>"]:
break
q_ids.append(next_id)
tgt = np.array([[next_id]])
    # --- Answer: continue from the question decoder's final state ----------
    # (h, c already hold the last question-decoder states, as in training)
a_ids = []
tgt_a = np.array([[v_a["<sos>"]]])
for _ in range(max_a):
logits_a, h, c = decoder_a_step.predict([tgt_a, enc_proj_a_val, h, c], verbose=0)
next_a = int(logits_a[0, 0].argmax())
if next_a == v_a["<eos>"]:
break
a_ids.append(next_a)
tgt_a = np.array([[next_a]])
# Question type
typ_logits = model.predict([enc_tok, enc_ner, enc_srl, np.zeros((1, MAX_Q)), np.zeros((1, MAX_A))], verbose=0)[2]
typ_id = int(typ_logits.argmax())
q_type = [k for k, v in v_typ.items() if v == typ_id][0]
question = [iv_q.get(i, "<unk>") for i in q_ids]
answer = [iv_a.get(i, "<unk>") for i in a_ids]
return question, answer, q_type
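# Illustrative usage sketch (assumes `samples` from section 1 is in scope):
#   q, a, t = greedy_decode(samples[0]["tokens"], samples[0]["ner"], samples[0]["srl"])
#   print("Q:", " ".join(q), "| A:", " ".join(a), "| type:", t)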
def beam_decode(tokens, ner, srl, beam_width=5, max_q=20, max_a=10):
"""Beam search decoding. Returns best (question_tokens, answer_tokens, q_type)"""
enc_tok = np.array([encode_and_pad(tokens, v_tok, MAX_SENT)])
enc_ner = np.array([encode_and_pad(ner, v_ner, MAX_SENT)])
enc_srl = np.array([encode_and_pad(srl, v_srl, MAX_SENT)])
enc_seq_val, h0, c0 = encoder_model.predict([enc_tok, enc_ner, enc_srl], verbose=0)
enc_proj_q_val = model.get_layer("enc_proj_q")(enc_seq_val)
enc_proj_a_val = model.get_layer("enc_proj_a")(enc_seq_val)
# ----- Beam for Question ----------------------------------------------
Beam = [( [v_q["<sos>"]], 0.0, h0, c0 )] # (sequence, logP, h, c)
completed_q = []
for _ in range(max_q):
new_beam = []
for seq, logp, h, c in Beam:
tgt = np.array([[seq[-1]]])
logits, next_h, next_c = decoder_q_step.predict([tgt, enc_seq_val, enc_proj_q_val, h, c], verbose=0)
log_probs = np.log(logits[0, 0] + 1e-8)
top_ids = np.argsort(log_probs)[-beam_width:]
for nid in top_ids:
new_seq = seq + [int(nid)]
new_logp = logp + log_probs[nid]
new_beam.append( (new_seq, new_logp, next_h, next_c) )
# keep best beam_width
Beam = sorted(new_beam, key=lambda x: x[1], reverse=True)[:beam_width]
        # move completed hypotheses out of the active beam
        Beam, done = [], Beam  # re-filter the candidates collected above
for seq, logp, h, c in done:
if seq[-1] == v_q["<eos>"] or len(seq) >= max_q:
completed_q.append( (seq, logp, h, c) )
else:
Beam.append( (seq, logp, h, c) )
if not Beam:
break
if completed_q:
best_q = max(completed_q, key=lambda x: x[1])
else:
best_q = max(Beam, key=lambda x: x[1])
q_seq_ids, _, h_q, c_q = best_q
q_ids = [i for i in q_seq_ids[1:] if i != v_q["<eos>"]]
# ----- Beam for Answer --------------------------------------------------
Beam = [( [v_a["<sos>"]], 0.0, h_q, c_q )]
completed_a = []
for _ in range(max_a):
new_beam = []
for seq, logp, h, c in Beam:
tgt = np.array([[seq[-1]]])
logits, next_h, next_c = decoder_a_step.predict([tgt, enc_proj_a_val, h, c], verbose=0)
log_probs = np.log(logits[0, 0] + 1e-8)
top_ids = np.argsort(log_probs)[-beam_width:]
for nid in top_ids:
new_seq = seq + [int(nid)]
new_logp = logp + log_probs[nid]
new_beam.append( (new_seq, new_logp, next_h, next_c) )
Beam = sorted(new_beam, key=lambda x: x[1], reverse=True)[:beam_width]
        Beam, done = [], Beam  # re-filter the candidates collected above
for seq, logp, h, c in done:
if seq[-1] == v_a["<eos>"] or len(seq) >= max_a:
completed_a.append( (seq, logp) )
else:
Beam.append( (seq, logp, h, c) )
if not Beam:
break
if completed_a:
best_a_seq, _ = max(completed_a, key=lambda x: x[1])
else:
best_a_seq, _ = max(Beam, key=lambda x: x[1])
a_ids = [i for i in best_a_seq[1:] if i != v_a["<eos>"]]
# Question type classification
typ_logits = model.predict([enc_tok, enc_ner, enc_srl, np.zeros((1, MAX_Q)), np.zeros((1, MAX_A))], verbose=0)[2]
typ_id = int(typ_logits.argmax())
q_type = [k for k, v in v_typ.items() if v == typ_id][0]
question = [iv_q.get(i, "<unk>") for i in q_ids]
answer = [iv_a.get(i, "<unk>") for i in a_ids]
return question, answer, q_type
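# Illustrative usage sketch (same inputs as greedy_decode, plus a beam width):
#   q, a, t = beam_decode(samples[0]["tokens"], samples[0]["ner"], samples[0]["srl"], beam_width=3)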
# ----------------------- 9. BLEU evaluation -------------------------------
def evaluate_bleu(split_ratio=0.1, beam=False):
"""Compute corpus BLEU4 on holdout split."""
n_total = len(samples)
n_val = int(n_total * split_ratio)
idxs = np.random.choice(n_total, n_val, replace=False)
refs_q, hyps_q = [], []
refs_a, hyps_a = [], []
for i in idxs:
s = samples[i]
question_pred, answer_pred, _ = (beam_decode if beam else greedy_decode)(
s["tokens"], s["ner"], s["srl"],
)
refs_q.append([s["q_toks"][:-1]]) # exclude <eos>
hyps_q.append(question_pred)
refs_a.append([s["a_toks"][:-1]])
hyps_a.append(answer_pred)
bleu_q = corpus_bleu(refs_q, hyps_q)
bleu_a = corpus_bleu(refs_a, hyps_a)
print(f"BLEU4 Question: {bleu_q:.3f}\nBLEU4 Answer : {bleu_a:.3f}")
# Example usage:
evaluate_bleu(beam=False)
evaluate_bleu(beam=True)