diff --git a/question_generation/qg_lstm.ipynb b/question_generation/qg_lstm.ipynb index 06e309a..04e4353 100644 --- a/question_generation/qg_lstm.ipynb +++ b/question_generation/qg_lstm.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 32, + "execution_count": null, "id": "fb283f23", "metadata": {}, "outputs": [ @@ -20,7 +20,7 @@ "from itertools import chain\n", "\n", "RAW = json.loads(\n", - " Path(\"../dataset/dev_dataset_test.json\").read_text()\n", + " Path(\"../dataset/dev_dataset_qg.json\").read_text()\n", ") # ← file contoh Anda\n", "\n", "samples = []\n", @@ -46,7 +46,7 @@ }, { "cell_type": "code", - "execution_count": 33, + "execution_count": 102, "id": "fa4f979d", "metadata": {}, "outputs": [ @@ -80,7 +80,7 @@ }, { "cell_type": "code", - "execution_count": 34, + "execution_count": 103, "id": "d1a5b324", "metadata": {}, "outputs": [], @@ -125,27 +125,27 @@ "dec_a_out = pad_sequences(\n", " [encode(s[\"a_toks\"], vocab_a) for s in samples], maxlen=MAX_A, padding=\"post\"\n", ")\n", + "y_type = np.array([vocab_typ[s[\"q_type\"]] for s in samples])\n", "\n", "MAX_SENT = max(len(s[\"tokens\"]) for s in samples)\n", "MAX_Q = max(len(s[\"q_toks\"]) for s in samples)\n", - "MAX_A = max(len(s[\"a_toks\"]) for s in samples)\n", - "y_type = np.array([vocab_typ[s[\"q_type\"]] for s in samples])" + "MAX_A = max(len(s[\"a_toks\"]) for s in samples)" ] }, { "cell_type": "code", - "execution_count": 35, + "execution_count": 104, "id": "ff5bd85f", "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
Model: \"functional_3\"\n",
+ "Model: \"functional_12\"\n",
"
\n"
],
"text/plain": [
- "\u001b[1mModel: \"functional_3\"\u001b[0m\n"
+ "\u001b[1mModel: \"functional_12\"\u001b[0m\n"
]
},
"metadata": {},
@@ -163,56 +163,56 @@
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ srl_in (InputLayer) │ (None, 9) │ 0 │ - │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
- "│ embedding_tok │ (None, 9, 128) │ 31,232 │ tok_in[0][0] │\n",
+ "│ embedding_tok │ (None, 9, 32) │ 7,808 │ tok_in[0][0] │\n",
"│ (Embedding) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
- "│ embedding_ner │ (None, 9, 32) │ 288 │ ner_in[0][0] │\n",
+ "│ embedding_ner │ (None, 9, 16) │ 144 │ ner_in[0][0] │\n",
"│ (Embedding) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
- "│ embedding_srl │ (None, 9, 32) │ 224 │ srl_in[0][0] │\n",
+ "│ embedding_srl │ (None, 9, 16) │ 112 │ srl_in[0][0] │\n",
"│ (Embedding) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ dec_q_in │ (None, 11) │ 0 │ - │\n",
"│ (InputLayer) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
- "│ concatenate_3 │ (None, 9, 192) │ 0 │ embedding_tok[0]… │\n",
+ "│ concatenate_14 │ (None, 9, 64) │ 0 │ embedding_tok[0]… │\n",
"│ (Concatenate) │ │ │ embedding_ner[0]… │\n",
"│ │ │ │ embedding_srl[0]… │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ dec_a_in │ (None, 4) │ 0 │ - │\n",
"│ (InputLayer) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
- "│ embedding_q_decoder │ (None, 11, 128) │ 27,008 │ dec_q_in[0][0] │\n",
+ "│ embedding_q_decoder │ (None, 11, 32) │ 6,752 │ dec_q_in[0][0] │\n",
"│ (Embedding) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
- "│ encoder_lstm (LSTM) │ [(None, 256), │ 459,776 │ concatenate_3[0]… │\n",
- "│ │ (None, 256), │ │ │\n",
- "│ │ (None, 256)] │ │ │\n",
+ "│ encoder_lstm (LSTM) │ [(None, 64), │ 33,024 │ concatenate_14[0… │\n",
+ "│ │ (None, 64), │ │ │\n",
+ "│ │ (None, 64)] │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
- "│ embedding_a_decoder │ (None, 4, 128) │ 14,336 │ dec_a_in[0][0] │\n",
+ "│ embedding_a_decoder │ (None, 4, 32) │ 3,584 │ dec_a_in[0][0] │\n",
"│ (Embedding) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
- "│ lstm_q_decoder │ [(None, 11, 256), │ 394,240 │ embedding_q_deco… │\n",
- "│ (LSTM) │ (None, 256), │ │ encoder_lstm[0][… │\n",
- "│ │ (None, 256)] │ │ encoder_lstm[0][… │\n",
+ "│ lstm_q_decoder │ [(None, 11, 64), │ 24,832 │ embedding_q_deco… │\n",
+ "│ (LSTM) │ (None, 64), │ │ encoder_lstm[0][… │\n",
+ "│ │ (None, 64)] │ │ encoder_lstm[0][… │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
- "│ not_equal_12 │ (None, 11) │ 0 │ dec_q_in[0][0] │\n",
+ "│ not_equal_51 │ (None, 11) │ 0 │ dec_q_in[0][0] │\n",
"│ (NotEqual) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
- "│ lstm_a_decoder │ [(None, 4, 256), │ 394,240 │ embedding_a_deco… │\n",
- "│ (LSTM) │ (None, 256), │ │ encoder_lstm[0][… │\n",
- "│ │ (None, 256)] │ │ encoder_lstm[0][… │\n",
+ "│ lstm_a_decoder │ [(None, 4, 64), │ 24,832 │ embedding_a_deco… │\n",
+ "│ (LSTM) │ (None, 64), │ │ encoder_lstm[0][… │\n",
+ "│ │ (None, 64)] │ │ encoder_lstm[0][… │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
- "│ not_equal_13 │ (None, 4) │ 0 │ dec_a_in[0][0] │\n",
+ "│ not_equal_52 │ (None, 4) │ 0 │ dec_a_in[0][0] │\n",
"│ (NotEqual) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
- "│ q_output │ (None, 11, 211) │ 54,227 │ lstm_q_decoder[0… │\n",
- "│ (TimeDistributed) │ │ │ not_equal_12[0][… │\n",
+ "│ q_output │ (None, 11, 211) │ 13,715 │ lstm_q_decoder[0… │\n",
+ "│ (TimeDistributed) │ │ │ not_equal_51[0][… │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
- "│ a_output │ (None, 4, 112) │ 28,784 │ lstm_a_decoder[0… │\n",
- "│ (TimeDistributed) │ │ │ not_equal_13[0][… │\n",
+ "│ a_output │ (None, 4, 112) │ 7,280 │ lstm_a_decoder[0… │\n",
+ "│ (TimeDistributed) │ │ │ not_equal_52[0][… │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
- "│ type_output (Dense) │ (None, 3) │ 771 │ encoder_lstm[0][… │\n",
+ "│ type_output (Dense) │ (None, 3) │ 195 │ encoder_lstm[0][… │\n",
"└─────────────────────┴───────────────────┴────────────┴───────────────────┘\n",
"
\n"
],
@@ -226,56 +226,56 @@
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ srl_in (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m9\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
- "│ embedding_tok │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m9\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m31,232\u001b[0m │ tok_in[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
+ "│ embedding_tok │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m9\u001b[0m, \u001b[38;5;34m32\u001b[0m) │ \u001b[38;5;34m7,808\u001b[0m │ tok_in[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"│ (\u001b[38;5;33mEmbedding\u001b[0m) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
- "│ embedding_ner │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m9\u001b[0m, \u001b[38;5;34m32\u001b[0m) │ \u001b[38;5;34m288\u001b[0m │ ner_in[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
+ "│ embedding_ner │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m9\u001b[0m, \u001b[38;5;34m16\u001b[0m) │ \u001b[38;5;34m144\u001b[0m │ ner_in[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"│ (\u001b[38;5;33mEmbedding\u001b[0m) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
- "│ embedding_srl │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m9\u001b[0m, \u001b[38;5;34m32\u001b[0m) │ \u001b[38;5;34m224\u001b[0m │ srl_in[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
+ "│ embedding_srl │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m9\u001b[0m, \u001b[38;5;34m16\u001b[0m) │ \u001b[38;5;34m112\u001b[0m │ srl_in[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"│ (\u001b[38;5;33mEmbedding\u001b[0m) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ dec_q_in │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m11\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n",
"│ (\u001b[38;5;33mInputLayer\u001b[0m) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
- "│ concatenate_3 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m9\u001b[0m, \u001b[38;5;34m192\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ embedding_tok[\u001b[38;5;34m0\u001b[0m]… │\n",
+ "│ concatenate_14 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m9\u001b[0m, \u001b[38;5;34m64\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ embedding_tok[\u001b[38;5;34m0\u001b[0m]… │\n",
"│ (\u001b[38;5;33mConcatenate\u001b[0m) │ │ │ embedding_ner[\u001b[38;5;34m0\u001b[0m]… │\n",
"│ │ │ │ embedding_srl[\u001b[38;5;34m0\u001b[0m]… │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ dec_a_in │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m4\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n",
"│ (\u001b[38;5;33mInputLayer\u001b[0m) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
- "│ embedding_q_decoder │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m11\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m27,008\u001b[0m │ dec_q_in[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
+ "│ embedding_q_decoder │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m11\u001b[0m, \u001b[38;5;34m32\u001b[0m) │ \u001b[38;5;34m6,752\u001b[0m │ dec_q_in[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"│ (\u001b[38;5;33mEmbedding\u001b[0m) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
- "│ encoder_lstm (\u001b[38;5;33mLSTM\u001b[0m) │ [(\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256\u001b[0m), │ \u001b[38;5;34m459,776\u001b[0m │ concatenate_3[\u001b[38;5;34m0\u001b[0m]… │\n",
- "│ │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256\u001b[0m), │ │ │\n",
- "│ │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256\u001b[0m)] │ │ │\n",
+ "│ encoder_lstm (\u001b[38;5;33mLSTM\u001b[0m) │ [(\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m64\u001b[0m), │ \u001b[38;5;34m33,024\u001b[0m │ concatenate_14[\u001b[38;5;34m0\u001b[0m… │\n",
+ "│ │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m64\u001b[0m), │ │ │\n",
+ "│ │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m64\u001b[0m)] │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
- "│ embedding_a_decoder │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m4\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m14,336\u001b[0m │ dec_a_in[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
+ "│ embedding_a_decoder │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m4\u001b[0m, \u001b[38;5;34m32\u001b[0m) │ \u001b[38;5;34m3,584\u001b[0m │ dec_a_in[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"│ (\u001b[38;5;33mEmbedding\u001b[0m) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
- "│ lstm_q_decoder │ [(\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m11\u001b[0m, \u001b[38;5;34m256\u001b[0m), │ \u001b[38;5;34m394,240\u001b[0m │ embedding_q_deco… │\n",
- "│ (\u001b[38;5;33mLSTM\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256\u001b[0m), │ │ encoder_lstm[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m…\u001b[0m │\n",
- "│ │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256\u001b[0m)] │ │ encoder_lstm[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m…\u001b[0m │\n",
+ "│ lstm_q_decoder │ [(\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m11\u001b[0m, \u001b[38;5;34m64\u001b[0m), │ \u001b[38;5;34m24,832\u001b[0m │ embedding_q_deco… │\n",
+ "│ (\u001b[38;5;33mLSTM\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m64\u001b[0m), │ │ encoder_lstm[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m…\u001b[0m │\n",
+ "│ │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m64\u001b[0m)] │ │ encoder_lstm[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m…\u001b[0m │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
- "│ not_equal_12 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m11\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ dec_q_in[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
+ "│ not_equal_51 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m11\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ dec_q_in[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"│ (\u001b[38;5;33mNotEqual\u001b[0m) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
- "│ lstm_a_decoder │ [(\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m4\u001b[0m, \u001b[38;5;34m256\u001b[0m), │ \u001b[38;5;34m394,240\u001b[0m │ embedding_a_deco… │\n",
- "│ (\u001b[38;5;33mLSTM\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256\u001b[0m), │ │ encoder_lstm[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m…\u001b[0m │\n",
- "│ │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256\u001b[0m)] │ │ encoder_lstm[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m…\u001b[0m │\n",
+ "│ lstm_a_decoder │ [(\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m4\u001b[0m, \u001b[38;5;34m64\u001b[0m), │ \u001b[38;5;34m24,832\u001b[0m │ embedding_a_deco… │\n",
+ "│ (\u001b[38;5;33mLSTM\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m64\u001b[0m), │ │ encoder_lstm[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m…\u001b[0m │\n",
+ "│ │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m64\u001b[0m)] │ │ encoder_lstm[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m…\u001b[0m │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
- "│ not_equal_13 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m4\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ dec_a_in[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
+ "│ not_equal_52 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m4\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ dec_a_in[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"│ (\u001b[38;5;33mNotEqual\u001b[0m) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
- "│ q_output │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m11\u001b[0m, \u001b[38;5;34m211\u001b[0m) │ \u001b[38;5;34m54,227\u001b[0m │ lstm_q_decoder[\u001b[38;5;34m0\u001b[0m… │\n",
- "│ (\u001b[38;5;33mTimeDistributed\u001b[0m) │ │ │ not_equal_12[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m…\u001b[0m │\n",
+ "│ q_output │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m11\u001b[0m, \u001b[38;5;34m211\u001b[0m) │ \u001b[38;5;34m13,715\u001b[0m │ lstm_q_decoder[\u001b[38;5;34m0\u001b[0m… │\n",
+ "│ (\u001b[38;5;33mTimeDistributed\u001b[0m) │ │ │ not_equal_51[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m…\u001b[0m │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
- "│ a_output │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m4\u001b[0m, \u001b[38;5;34m112\u001b[0m) │ \u001b[38;5;34m28,784\u001b[0m │ lstm_a_decoder[\u001b[38;5;34m0\u001b[0m… │\n",
- "│ (\u001b[38;5;33mTimeDistributed\u001b[0m) │ │ │ not_equal_13[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m…\u001b[0m │\n",
+ "│ a_output │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m4\u001b[0m, \u001b[38;5;34m112\u001b[0m) │ \u001b[38;5;34m7,280\u001b[0m │ lstm_a_decoder[\u001b[38;5;34m0\u001b[0m… │\n",
+ "│ (\u001b[38;5;33mTimeDistributed\u001b[0m) │ │ │ not_equal_52[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m…\u001b[0m │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
- "│ type_output (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m3\u001b[0m) │ \u001b[38;5;34m771\u001b[0m │ encoder_lstm[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m…\u001b[0m │\n",
+ "│ type_output (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m3\u001b[0m) │ \u001b[38;5;34m195\u001b[0m │ encoder_lstm[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m…\u001b[0m │\n",
"└─────────────────────┴───────────────────┴────────────┴───────────────────┘\n"
]
},
@@ -285,11 +285,11 @@
{
"data": {
"text/html": [
- "Total params: 1,405,126 (5.36 MB)\n", + "Total params: 122,278 (477.65 KB)\n", "\n" ], "text/plain": [ - "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m1,405,126\u001b[0m (5.36 MB)\n" + "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m122,278\u001b[0m (477.65 KB)\n" ] }, "metadata": {}, @@ -298,11 +298,11 @@ { "data": { "text/html": [ - "Trainable params: 1,405,126 (5.36 MB)\n", + "Trainable params: 122,278 (477.65 KB)\n", "\n" ], "text/plain": [ - "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m1,405,126\u001b[0m (5.36 MB)\n" + "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m122,278\u001b[0m (477.65 KB)\n" ] }, "metadata": {}, @@ -335,9 +335,9 @@ "from tensorflow.keras.models import Model\n", "\n", "# ---- constants ---------------------------------------------------\n", - "d_tok = 128 # token embedding dim\n", - "d_tag = 32 # NER / SRL embedding dim\n", - "units = 256\n", + "d_tok = 32 # token embedding dim\n", + "d_tag = 16 # NER / SRL embedding dim\n", + "units = 64\n", "\n", "# ---- encoder -----------------------------------------------------\n", "inp_tok = Input((MAX_SENT,), name=\"tok_in\")\n", @@ -399,7 +399,7 @@ }, { "cell_type": "code", - "execution_count": 36, + "execution_count": 105, "id": "fece1ae9", "metadata": {}, "outputs": [ @@ -408,61 +408,65 @@ "output_type": "stream", "text": [ "Epoch 1/30\n", - "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 83ms/step - a_output_loss: 4.5185 - a_output_sparse_categorical_accuracy: 0.1853 - loss: 10.1289 - q_output_loss: 5.2751 - q_output_sparse_categorical_accuracy: 0.1679 - type_output_accuracy: 0.3966 - type_output_loss: 1.0344 - val_a_output_loss: 2.2993 - val_a_output_sparse_categorical_accuracy: 0.2500 - val_loss: 6.6556 - val_q_output_loss: 4.1554 - val_q_output_sparse_categorical_accuracy: 0.1606 - val_type_output_accuracy: 0.6167 - val_type_output_loss: 0.6698\n", + "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 69ms/step - a_output_loss: 4.6957 - a_output_sparse_categorical_accuracy: 0.1413 - loss: 10.3613 - q_output_loss: 5.3441 - q_output_sparse_categorical_accuracy: 0.0670 - type_output_accuracy: 0.4939 - type_output_loss: 1.0668 - val_a_output_loss: 4.5668 - val_a_output_sparse_categorical_accuracy: 0.2500 - val_loss: 10.1527 - val_q_output_loss: 5.3034 - val_q_output_sparse_categorical_accuracy: 0.1182 - val_type_output_accuracy: 0.7000 - val_type_output_loss: 0.9419\n", "Epoch 2/30\n", - "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 28ms/step - a_output_loss: 2.7148 - a_output_sparse_categorical_accuracy: 0.2674 - loss: 6.7993 - q_output_loss: 3.8706 - q_output_sparse_categorical_accuracy: 0.1625 - type_output_accuracy: 0.5096 - type_output_loss: 0.7111 - val_a_output_loss: 1.9687 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 5.6731 - val_q_output_loss: 3.4983 - val_q_output_sparse_categorical_accuracy: 0.2848 - val_type_output_accuracy: 0.8167 - val_type_output_loss: 0.6870\n", + "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - a_output_loss: 4.4732 - a_output_sparse_categorical_accuracy: 0.2500 - loss: 10.0049 - q_output_loss: 5.2579 - q_output_sparse_categorical_accuracy: 0.1001 - type_output_accuracy: 0.5591 - type_output_loss: 0.8840 - val_a_output_loss: 3.6835 - val_a_output_sparse_categorical_accuracy: 0.2500 - val_loss: 8.8571 - val_q_output_loss: 4.9562 - val_q_output_sparse_categorical_accuracy: 0.1000 - val_type_output_accuracy: 0.3833 - val_type_output_loss: 0.7245\n", "Epoch 3/30\n", - "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 29ms/step - a_output_loss: 2.4494 - a_output_sparse_categorical_accuracy: 0.3225 - loss: 6.0211 - q_output_loss: 3.3610 - q_output_sparse_categorical_accuracy: 0.2195 - type_output_accuracy: 0.6316 - type_output_loss: 0.6890 - val_a_output_loss: 1.8200 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 5.1733 - val_q_output_loss: 3.1541 - val_q_output_sparse_categorical_accuracy: 0.2136 - val_type_output_accuracy: 0.6167 - val_type_output_loss: 0.6640\n", + "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - a_output_loss: 3.4959 - a_output_sparse_categorical_accuracy: 0.2500 - loss: 8.4809 - q_output_loss: 4.7656 - q_output_sparse_categorical_accuracy: 0.1299 - type_output_accuracy: 0.5308 - type_output_loss: 0.7076 - val_a_output_loss: 2.4516 - val_a_output_sparse_categorical_accuracy: 0.2500 - val_loss: 6.9437 - val_q_output_loss: 4.2829 - val_q_output_sparse_categorical_accuracy: 0.1758 - val_type_output_accuracy: 0.3833 - val_type_output_loss: 0.6973\n", "Epoch 4/30\n", - "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 29ms/step - a_output_loss: 2.2031 - a_output_sparse_categorical_accuracy: 0.3726 - loss: 5.4427 - q_output_loss: 3.0396 - q_output_sparse_categorical_accuracy: 0.2869 - type_output_accuracy: 0.5478 - type_output_loss: 0.6786 - val_a_output_loss: 1.5951 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 4.6514 - val_q_output_loss: 2.8462 - val_q_output_sparse_categorical_accuracy: 0.3758 - val_type_output_accuracy: 0.3833 - val_type_output_loss: 0.7003\n", + "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - a_output_loss: 2.6025 - a_output_sparse_categorical_accuracy: 0.2620 - loss: 6.9572 - q_output_loss: 4.1438 - q_output_sparse_categorical_accuracy: 0.1763 - type_output_accuracy: 0.5010 - type_output_loss: 0.6949 - val_a_output_loss: 1.9285 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 5.9705 - val_q_output_loss: 3.8322 - val_q_output_sparse_categorical_accuracy: 0.1682 - val_type_output_accuracy: 0.3833 - val_type_output_loss: 0.6996\n", "Epoch 5/30\n", - "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 30ms/step - a_output_loss: 2.0493 - a_output_sparse_categorical_accuracy: 0.3739 - loss: 4.9806 - q_output_loss: 2.7312 - q_output_sparse_categorical_accuracy: 0.3659 - type_output_accuracy: 0.5789 - type_output_loss: 0.6663 - val_a_output_loss: 1.5595 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 4.3330 - val_q_output_loss: 2.5741 - val_q_output_sparse_categorical_accuracy: 0.4197 - val_type_output_accuracy: 0.4833 - val_type_output_loss: 0.6650\n", + "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - a_output_loss: 2.3723 - a_output_sparse_categorical_accuracy: 0.3693 - loss: 6.3040 - q_output_loss: 3.7185 - q_output_sparse_categorical_accuracy: 0.1705 - type_output_accuracy: 0.5228 - type_output_loss: 0.6954 - val_a_output_loss: 1.7395 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 5.5526 - val_q_output_loss: 3.6050 - val_q_output_sparse_categorical_accuracy: 0.1818 - val_type_output_accuracy: 0.4333 - val_type_output_loss: 0.6936\n", "Epoch 6/30\n", - "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 28ms/step - a_output_loss: 1.9310 - a_output_sparse_categorical_accuracy: 0.3769 - loss: 4.5682 - q_output_loss: 2.4562 - q_output_sparse_categorical_accuracy: 0.4147 - type_output_accuracy: 0.6767 - type_output_loss: 0.6249 - val_a_output_loss: 1.4318 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 3.9148 - val_q_output_loss: 2.3074 - val_q_output_sparse_categorical_accuracy: 0.4333 - val_type_output_accuracy: 0.8167 - val_type_output_loss: 0.5853\n", + "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - a_output_loss: 2.2466 - a_output_sparse_categorical_accuracy: 0.3660 - loss: 5.9522 - q_output_loss: 3.4987 - q_output_sparse_categorical_accuracy: 0.1843 - type_output_accuracy: 0.5311 - type_output_loss: 0.6928 - val_a_output_loss: 1.6680 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 5.3542 - val_q_output_loss: 3.4792 - val_q_output_sparse_categorical_accuracy: 0.2591 - val_type_output_accuracy: 0.6167 - val_type_output_loss: 0.6898\n", "Epoch 7/30\n", - "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 28ms/step - a_output_loss: 1.8407 - a_output_sparse_categorical_accuracy: 0.3722 - loss: 4.1857 - q_output_loss: 2.1853 - q_output_sparse_categorical_accuracy: 0.4229 - type_output_accuracy: 0.7938 - type_output_loss: 0.5382 - val_a_output_loss: 1.3413 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 3.5705 - val_q_output_loss: 2.0979 - val_q_output_sparse_categorical_accuracy: 0.4515 - val_type_output_accuracy: 0.8500 - val_type_output_loss: 0.4377\n", + "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - a_output_loss: 2.1706 - a_output_sparse_categorical_accuracy: 0.3748 - loss: 5.7833 - q_output_loss: 3.3972 - q_output_sparse_categorical_accuracy: 0.2581 - type_output_accuracy: 0.4983 - type_output_loss: 0.6943 - val_a_output_loss: 1.6312 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 5.2140 - val_q_output_loss: 3.3728 - val_q_output_sparse_categorical_accuracy: 0.2591 - val_type_output_accuracy: 0.3833 - val_type_output_loss: 0.7001\n", "Epoch 8/30\n", - "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 28ms/step - a_output_loss: 1.8257 - a_output_sparse_categorical_accuracy: 0.3606 - loss: 3.9578 - q_output_loss: 1.9912 - q_output_sparse_categorical_accuracy: 0.4426 - type_output_accuracy: 0.7711 - type_output_loss: 0.4701 - val_a_output_loss: 1.2847 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 3.3498 - val_q_output_loss: 1.9433 - val_q_output_sparse_categorical_accuracy: 0.4712 - val_type_output_accuracy: 0.8500 - val_type_output_loss: 0.4062\n", + "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - a_output_loss: 2.0769 - a_output_sparse_categorical_accuracy: 0.3709 - loss: 5.5707 - q_output_loss: 3.2835 - q_output_sparse_categorical_accuracy: 0.2579 - type_output_accuracy: 0.5165 - type_output_loss: 0.6925 - val_a_output_loss: 1.5953 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 5.0738 - val_q_output_loss: 3.2679 - val_q_output_sparse_categorical_accuracy: 0.2591 - val_type_output_accuracy: 0.3833 - val_type_output_loss: 0.7022\n", "Epoch 9/30\n", - "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 29ms/step - a_output_loss: 1.6802 - a_output_sparse_categorical_accuracy: 0.3683 - loss: 3.6282 - q_output_loss: 1.8210 - q_output_sparse_categorical_accuracy: 0.4542 - type_output_accuracy: 0.7996 - type_output_loss: 0.4131 - val_a_output_loss: 1.2484 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 3.1973 - val_q_output_loss: 1.8318 - val_q_output_sparse_categorical_accuracy: 0.4667 - val_type_output_accuracy: 0.8500 - val_type_output_loss: 0.3905\n", + "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - a_output_loss: 2.0811 - a_output_sparse_categorical_accuracy: 0.3700 - loss: 5.4626 - q_output_loss: 3.1783 - q_output_sparse_categorical_accuracy: 0.2586 - type_output_accuracy: 0.5294 - type_output_loss: 0.6897 - val_a_output_loss: 1.5771 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 4.9584 - val_q_output_loss: 3.1740 - val_q_output_sparse_categorical_accuracy: 0.2591 - val_type_output_accuracy: 0.5667 - val_type_output_loss: 0.6911\n", "Epoch 10/30\n", - "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 29ms/step - a_output_loss: 1.6225 - a_output_sparse_categorical_accuracy: 0.3738 - loss: 3.4468 - q_output_loss: 1.7095 - q_output_sparse_categorical_accuracy: 0.4516 - type_output_accuracy: 0.8104 - type_output_loss: 0.3915 - val_a_output_loss: 1.2075 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 3.0751 - val_q_output_loss: 1.7371 - val_q_output_sparse_categorical_accuracy: 0.4758 - val_type_output_accuracy: 0.8167 - val_type_output_loss: 0.4349\n", + "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - a_output_loss: 2.1436 - a_output_sparse_categorical_accuracy: 0.3635 - loss: 5.4602 - q_output_loss: 3.0972 - q_output_sparse_categorical_accuracy: 0.2603 - type_output_accuracy: 0.5954 - type_output_loss: 0.6906 - val_a_output_loss: 1.5554 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 4.8437 - val_q_output_loss: 3.0796 - val_q_output_sparse_categorical_accuracy: 0.2742 - val_type_output_accuracy: 0.4333 - val_type_output_loss: 0.6955\n", "Epoch 11/30\n", - "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 28ms/step - a_output_loss: 1.5862 - a_output_sparse_categorical_accuracy: 0.3721 - loss: 3.3171 - q_output_loss: 1.6153 - q_output_sparse_categorical_accuracy: 0.4538 - type_output_accuracy: 0.8102 - type_output_loss: 0.3785 - val_a_output_loss: 1.1730 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 2.9582 - val_q_output_loss: 1.6659 - val_q_output_sparse_categorical_accuracy: 0.4697 - val_type_output_accuracy: 0.8500 - val_type_output_loss: 0.3979\n", + "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - a_output_loss: 2.0414 - a_output_sparse_categorical_accuracy: 0.3698 - loss: 5.2346 - q_output_loss: 2.9916 - q_output_sparse_categorical_accuracy: 0.2940 - type_output_accuracy: 0.5497 - type_output_loss: 0.6878 - val_a_output_loss: 1.5271 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 4.7076 - val_q_output_loss: 2.9767 - val_q_output_sparse_categorical_accuracy: 0.3788 - val_type_output_accuracy: 0.6167 - val_type_output_loss: 0.6791\n", "Epoch 12/30\n", - "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 29ms/step - a_output_loss: 1.5335 - a_output_sparse_categorical_accuracy: 0.3664 - loss: 3.1939 - q_output_loss: 1.5335 - q_output_sparse_categorical_accuracy: 0.4518 - type_output_accuracy: 0.7775 - type_output_loss: 0.4105 - val_a_output_loss: 1.1304 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 2.8616 - val_q_output_loss: 1.6143 - val_q_output_sparse_categorical_accuracy: 0.4727 - val_type_output_accuracy: 0.8500 - val_type_output_loss: 0.3897\n", + "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - a_output_loss: 1.9689 - a_output_sparse_categorical_accuracy: 0.3747 - loss: 5.0683 - q_output_loss: 2.8947 - q_output_sparse_categorical_accuracy: 0.3778 - type_output_accuracy: 0.5652 - type_output_loss: 0.6839 - val_a_output_loss: 1.5182 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 4.6171 - val_q_output_loss: 2.8857 - val_q_output_sparse_categorical_accuracy: 0.3985 - val_type_output_accuracy: 0.3833 - val_type_output_loss: 0.7105\n", "Epoch 13/30\n", - "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 28ms/step - a_output_loss: 1.4389 - a_output_sparse_categorical_accuracy: 0.3754 - loss: 3.0570 - q_output_loss: 1.5013 - q_output_sparse_categorical_accuracy: 0.4627 - type_output_accuracy: 0.8116 - type_output_loss: 0.3873 - val_a_output_loss: 1.1223 - val_a_output_sparse_categorical_accuracy: 0.4083 - val_loss: 2.8335 - val_q_output_loss: 1.5881 - val_q_output_sparse_categorical_accuracy: 0.4515 - val_type_output_accuracy: 0.8333 - val_type_output_loss: 0.4107\n", + "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - a_output_loss: 1.9573 - a_output_sparse_categorical_accuracy: 0.3675 - loss: 4.9542 - q_output_loss: 2.7911 - q_output_sparse_categorical_accuracy: 0.3790 - type_output_accuracy: 0.5301 - type_output_loss: 0.6814 - val_a_output_loss: 1.5033 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 4.5030 - val_q_output_loss: 2.7936 - val_q_output_sparse_categorical_accuracy: 0.3985 - val_type_output_accuracy: 0.5333 - val_type_output_loss: 0.6871\n", "Epoch 14/30\n", - "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 30ms/step - a_output_loss: 1.4312 - a_output_sparse_categorical_accuracy: 0.3714 - loss: 2.9962 - q_output_loss: 1.4529 - q_output_sparse_categorical_accuracy: 0.4608 - type_output_accuracy: 0.7943 - type_output_loss: 0.3844 - val_a_output_loss: 1.1235 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 2.8020 - val_q_output_loss: 1.5603 - val_q_output_sparse_categorical_accuracy: 0.4652 - val_type_output_accuracy: 0.8500 - val_type_output_loss: 0.3939\n", + "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - a_output_loss: 1.9229 - a_output_sparse_categorical_accuracy: 0.3745 - loss: 4.8469 - q_output_loss: 2.7145 - q_output_sparse_categorical_accuracy: 0.3810 - type_output_accuracy: 0.5883 - type_output_loss: 0.6817 - val_a_output_loss: 1.4761 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 4.3857 - val_q_output_loss: 2.7074 - val_q_output_sparse_categorical_accuracy: 0.3985 - val_type_output_accuracy: 0.6833 - val_type_output_loss: 0.6741\n", "Epoch 15/30\n", - "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 28ms/step - a_output_loss: 1.3670 - a_output_sparse_categorical_accuracy: 0.3787 - loss: 2.9339 - q_output_loss: 1.4391 - q_output_sparse_categorical_accuracy: 0.4609 - type_output_accuracy: 0.7846 - type_output_loss: 0.4245 - val_a_output_loss: 1.1046 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 2.7626 - val_q_output_loss: 1.5316 - val_q_output_sparse_categorical_accuracy: 0.4697 - val_type_output_accuracy: 0.8000 - val_type_output_loss: 0.4212\n", + "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - a_output_loss: 1.9625 - a_output_sparse_categorical_accuracy: 0.3703 - loss: 4.7880 - q_output_loss: 2.6232 - q_output_sparse_categorical_accuracy: 0.3790 - type_output_accuracy: 0.6521 - type_output_loss: 0.6739 - val_a_output_loss: 1.4591 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 4.2924 - val_q_output_loss: 2.6260 - val_q_output_sparse_categorical_accuracy: 0.3985 - val_type_output_accuracy: 0.5167 - val_type_output_loss: 0.6910\n", "Epoch 16/30\n", - "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 28ms/step - a_output_loss: 1.3007 - a_output_sparse_categorical_accuracy: 0.3830 - loss: 2.8316 - q_output_loss: 1.4239 - q_output_sparse_categorical_accuracy: 0.4647 - type_output_accuracy: 0.8005 - type_output_loss: 0.3726 - val_a_output_loss: 1.1026 - val_a_output_sparse_categorical_accuracy: 0.4083 - val_loss: 2.7498 - val_q_output_loss: 1.5221 - val_q_output_sparse_categorical_accuracy: 0.4652 - val_type_output_accuracy: 0.8000 - val_type_output_loss: 0.4171\n", + "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - a_output_loss: 1.9763 - a_output_sparse_categorical_accuracy: 0.3661 - loss: 4.7331 - q_output_loss: 2.5511 - q_output_sparse_categorical_accuracy: 0.3784 - type_output_accuracy: 0.6424 - type_output_loss: 0.6679 - val_a_output_loss: 1.4432 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 4.1898 - val_q_output_loss: 2.5484 - val_q_output_sparse_categorical_accuracy: 0.4061 - val_type_output_accuracy: 0.6500 - val_type_output_loss: 0.6605\n", "Epoch 17/30\n", - "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 31ms/step - a_output_loss: 1.3868 - a_output_sparse_categorical_accuracy: 0.3768 - loss: 2.8855 - q_output_loss: 1.3850 - q_output_sparse_categorical_accuracy: 0.4683 - type_output_accuracy: 0.8190 - type_output_loss: 0.3552 - val_a_output_loss: 1.0983 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 2.7381 - val_q_output_loss: 1.5079 - val_q_output_sparse_categorical_accuracy: 0.4636 - val_type_output_accuracy: 0.8000 - val_type_output_loss: 0.4401\n", + "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - a_output_loss: 1.8243 - a_output_sparse_categorical_accuracy: 0.3712 - loss: 4.4878 - q_output_loss: 2.4615 - q_output_sparse_categorical_accuracy: 0.3942 - type_output_accuracy: 0.6414 - type_output_loss: 0.6622 - val_a_output_loss: 1.4167 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 4.0918 - val_q_output_loss: 2.4764 - val_q_output_sparse_categorical_accuracy: 0.4061 - val_type_output_accuracy: 0.6667 - val_type_output_loss: 0.6624\n", "Epoch 18/30\n", - "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 28ms/step - a_output_loss: 1.2897 - a_output_sparse_categorical_accuracy: 0.3837 - loss: 2.7760 - q_output_loss: 1.3759 - q_output_sparse_categorical_accuracy: 0.4749 - type_output_accuracy: 0.8087 - type_output_loss: 0.3802 - val_a_output_loss: 1.1044 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 2.7490 - val_q_output_loss: 1.5155 - val_q_output_sparse_categorical_accuracy: 0.4636 - val_type_output_accuracy: 0.8000 - val_type_output_loss: 0.4305\n", + "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 16ms/step - a_output_loss: 1.7940 - a_output_sparse_categorical_accuracy: 0.3750 - loss: 4.3987 - q_output_loss: 2.4076 - q_output_sparse_categorical_accuracy: 0.3949 - type_output_accuracy: 0.6488 - type_output_loss: 0.6552 - val_a_output_loss: 1.4059 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 4.0153 - val_q_output_loss: 2.4094 - val_q_output_sparse_categorical_accuracy: 0.4197 - val_type_output_accuracy: 0.6167 - val_type_output_loss: 0.6666\n", "Epoch 19/30\n", - "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 30ms/step - a_output_loss: 1.3054 - a_output_sparse_categorical_accuracy: 0.3802 - loss: 2.7626 - q_output_loss: 1.3517 - q_output_sparse_categorical_accuracy: 0.4699 - type_output_accuracy: 0.8228 - type_output_loss: 0.3645 - val_a_output_loss: 1.0990 - val_a_output_sparse_categorical_accuracy: 0.4125 - val_loss: 2.7261 - val_q_output_loss: 1.4992 - val_q_output_sparse_categorical_accuracy: 0.4667 - val_type_output_accuracy: 0.8000 - val_type_output_loss: 0.4264\n", + "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - a_output_loss: 1.8070 - a_output_sparse_categorical_accuracy: 0.3722 - loss: 4.3404 - q_output_loss: 2.3366 - q_output_sparse_categorical_accuracy: 0.4081 - type_output_accuracy: 0.6783 - type_output_loss: 0.6439 - val_a_output_loss: 1.3923 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 3.9348 - val_q_output_loss: 2.3484 - val_q_output_sparse_categorical_accuracy: 0.4197 - val_type_output_accuracy: 0.6667 - val_type_output_loss: 0.6468\n", "Epoch 20/30\n", - "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 28ms/step - a_output_loss: 1.2454 - a_output_sparse_categorical_accuracy: 0.3856 - loss: 2.7195 - q_output_loss: 1.3633 - q_output_sparse_categorical_accuracy: 0.4751 - type_output_accuracy: 0.8284 - type_output_loss: 0.3665 - val_a_output_loss: 1.1154 - val_a_output_sparse_categorical_accuracy: 0.4125 - val_loss: 2.7285 - val_q_output_loss: 1.4769 - val_q_output_sparse_categorical_accuracy: 0.4652 - val_type_output_accuracy: 0.7833 - val_type_output_loss: 0.4540\n", + "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - a_output_loss: 1.7701 - a_output_sparse_categorical_accuracy: 0.3707 - loss: 4.2359 - q_output_loss: 2.2700 - q_output_sparse_categorical_accuracy: 0.4075 - type_output_accuracy: 0.7000 - type_output_loss: 0.6303 - val_a_output_loss: 1.3730 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 3.8547 - val_q_output_loss: 2.2902 - val_q_output_sparse_categorical_accuracy: 0.4197 - val_type_output_accuracy: 0.6667 - val_type_output_loss: 0.6385\n", "Epoch 21/30\n", - "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 29ms/step - a_output_loss: 1.2583 - a_output_sparse_categorical_accuracy: 0.3896 - loss: 2.7032 - q_output_loss: 1.3383 - q_output_sparse_categorical_accuracy: 0.4719 - type_output_accuracy: 0.8246 - type_output_loss: 0.3594 - val_a_output_loss: 1.1381 - val_a_output_sparse_categorical_accuracy: 0.4125 - val_loss: 2.8064 - val_q_output_loss: 1.4643 - val_q_output_sparse_categorical_accuracy: 0.4788 - val_type_output_accuracy: 0.6500 - val_type_output_loss: 0.6802\n", + "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - a_output_loss: 1.7757 - a_output_sparse_categorical_accuracy: 0.3722 - loss: 4.1755 - q_output_loss: 2.2117 - q_output_sparse_categorical_accuracy: 0.4081 - type_output_accuracy: 0.6587 - type_output_loss: 0.6294 - val_a_output_loss: 1.3681 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 3.7954 - val_q_output_loss: 2.2381 - val_q_output_sparse_categorical_accuracy: 0.4212 - val_type_output_accuracy: 0.6667 - val_type_output_loss: 0.6310\n", "Epoch 22/30\n", - "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 28ms/step - a_output_loss: 1.3214 - a_output_sparse_categorical_accuracy: 0.3743 - loss: 2.7428 - q_output_loss: 1.3119 - q_output_sparse_categorical_accuracy: 0.4669 - type_output_accuracy: 0.8213 - type_output_loss: 0.3563 - val_a_output_loss: 1.1002 - val_a_output_sparse_categorical_accuracy: 0.4125 - val_loss: 2.7029 - val_q_output_loss: 1.4733 - val_q_output_sparse_categorical_accuracy: 0.4727 - val_type_output_accuracy: 0.8333 - val_type_output_loss: 0.4315\n", + "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - a_output_loss: 1.7072 - a_output_sparse_categorical_accuracy: 0.3720 - loss: 4.0450 - q_output_loss: 2.1546 - q_output_sparse_categorical_accuracy: 0.4168 - type_output_accuracy: 0.6542 - type_output_loss: 0.6179 - val_a_output_loss: 1.3568 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 3.7286 - val_q_output_loss: 2.1905 - val_q_output_sparse_categorical_accuracy: 0.4333 - val_type_output_accuracy: 0.7167 - val_type_output_loss: 0.6042\n", "Epoch 23/30\n", - "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 29ms/step - a_output_loss: 1.2769 - a_output_sparse_categorical_accuracy: 0.3907 - loss: 2.7114 - q_output_loss: 1.3188 - q_output_sparse_categorical_accuracy: 0.4772 - type_output_accuracy: 0.8281 - type_output_loss: 0.3824 - val_a_output_loss: 1.1039 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 2.7224 - val_q_output_loss: 1.4924 - val_q_output_sparse_categorical_accuracy: 0.4727 - val_type_output_accuracy: 0.8167 - val_type_output_loss: 0.4205\n", + "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - a_output_loss: 1.7266 - a_output_sparse_categorical_accuracy: 0.3729 - loss: 4.0193 - q_output_loss: 2.1160 - q_output_sparse_categorical_accuracy: 0.4216 - type_output_accuracy: 0.6948 - type_output_loss: 0.5943 - val_a_output_loss: 1.3405 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 3.6712 - val_q_output_loss: 2.1447 - val_q_output_sparse_categorical_accuracy: 0.4333 - val_type_output_accuracy: 0.6667 - val_type_output_loss: 0.6197\n", "Epoch 24/30\n", - "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 29ms/step - a_output_loss: 1.1957 - a_output_sparse_categorical_accuracy: 0.3969 - loss: 2.6335 - q_output_loss: 1.3299 - q_output_sparse_categorical_accuracy: 0.4776 - type_output_accuracy: 0.8342 - type_output_loss: 0.3683 - val_a_output_loss: 1.0974 - val_a_output_sparse_categorical_accuracy: 0.4167 - val_loss: 2.7003 - val_q_output_loss: 1.4677 - val_q_output_sparse_categorical_accuracy: 0.4667 - val_type_output_accuracy: 0.8000 - val_type_output_loss: 0.4504\n", + "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - a_output_loss: 1.6801 - a_output_sparse_categorical_accuracy: 0.3773 - loss: 3.9227 - q_output_loss: 2.0804 - q_output_sparse_categorical_accuracy: 0.4252 - type_output_accuracy: 0.7457 - type_output_loss: 0.5462 - val_a_output_loss: 1.3245 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 3.6140 - val_q_output_loss: 2.1041 - val_q_output_sparse_categorical_accuracy: 0.4333 - val_type_output_accuracy: 0.6667 - val_type_output_loss: 0.6180\n", "Epoch 25/30\n", - "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 28ms/step - a_output_loss: 1.2393 - a_output_sparse_categorical_accuracy: 0.3898 - loss: 2.6442 - q_output_loss: 1.2937 - q_output_sparse_categorical_accuracy: 0.4804 - type_output_accuracy: 0.8273 - type_output_loss: 0.3564 - val_a_output_loss: 1.1525 - val_a_output_sparse_categorical_accuracy: 0.4125 - val_loss: 2.7743 - val_q_output_loss: 1.4611 - val_q_output_sparse_categorical_accuracy: 0.4682 - val_type_output_accuracy: 0.7833 - val_type_output_loss: 0.5356\n", + "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - a_output_loss: 1.6800 - a_output_sparse_categorical_accuracy: 0.3685 - loss: 3.8625 - q_output_loss: 2.0179 - q_output_sparse_categorical_accuracy: 0.4229 - type_output_accuracy: 0.7322 - type_output_loss: 0.5264 - val_a_output_loss: 1.2944 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 3.5249 - val_q_output_loss: 2.0656 - val_q_output_sparse_categorical_accuracy: 0.4333 - val_type_output_accuracy: 0.7000 - val_type_output_loss: 0.5497\n", "Epoch 26/30\n", - "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 29ms/step - a_output_loss: 1.1872 - a_output_sparse_categorical_accuracy: 0.4001 - loss: 2.5917 - q_output_loss: 1.2930 - q_output_sparse_categorical_accuracy: 0.4802 - type_output_accuracy: 0.8280 - type_output_loss: 0.3555 - val_a_output_loss: 1.1505 - val_a_output_sparse_categorical_accuracy: 0.4083 - val_loss: 2.7684 - val_q_output_loss: 1.4587 - val_q_output_sparse_categorical_accuracy: 0.4742 - val_type_output_accuracy: 0.8000 - val_type_output_loss: 0.5307\n", + "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - a_output_loss: 1.5786 - a_output_sparse_categorical_accuracy: 0.3724 - loss: 3.7165 - q_output_loss: 1.9876 - q_output_sparse_categorical_accuracy: 0.4253 - type_output_accuracy: 0.7853 - type_output_loss: 0.4956 - val_a_output_loss: 1.2680 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 3.4525 - val_q_output_loss: 2.0282 - val_q_output_sparse_categorical_accuracy: 0.4333 - val_type_output_accuracy: 0.7167 - val_type_output_loss: 0.5212\n", "Epoch 27/30\n", - "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 28ms/step - a_output_loss: 1.1337 - a_output_sparse_categorical_accuracy: 0.4077 - loss: 2.5274 - q_output_loss: 1.2917 - q_output_sparse_categorical_accuracy: 0.4856 - type_output_accuracy: 0.8328 - type_output_loss: 0.3424 - val_a_output_loss: 1.1274 - val_a_output_sparse_categorical_accuracy: 0.4208 - val_loss: 2.7139 - val_q_output_loss: 1.4500 - val_q_output_sparse_categorical_accuracy: 0.4788 - val_type_output_accuracy: 0.8000 - val_type_output_loss: 0.4548\n", + "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - a_output_loss: 1.6078 - a_output_sparse_categorical_accuracy: 0.3703 - loss: 3.7078 - q_output_loss: 1.9502 - q_output_sparse_categorical_accuracy: 0.4227 - type_output_accuracy: 0.7938 - type_output_loss: 0.4730 - val_a_output_loss: 1.2467 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 3.3875 - val_q_output_loss: 1.9954 - val_q_output_sparse_categorical_accuracy: 0.4364 - val_type_output_accuracy: 0.7833 - val_type_output_loss: 0.4848\n", "Epoch 28/30\n", - "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 30ms/step - a_output_loss: 1.1081 - a_output_sparse_categorical_accuracy: 0.4104 - loss: 2.4859 - q_output_loss: 1.2903 - q_output_sparse_categorical_accuracy: 0.4845 - type_output_accuracy: 0.8714 - type_output_loss: 0.2936 - val_a_output_loss: 1.1394 - val_a_output_sparse_categorical_accuracy: 0.4167 - val_loss: 2.7244 - val_q_output_loss: 1.4512 - val_q_output_sparse_categorical_accuracy: 0.4652 - val_type_output_accuracy: 0.8167 - val_type_output_loss: 0.4457\n" + "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 14ms/step - a_output_loss: 1.5627 - a_output_sparse_categorical_accuracy: 0.3749 - loss: 3.6136 - q_output_loss: 1.9109 - q_output_sparse_categorical_accuracy: 0.4312 - type_output_accuracy: 0.7908 - type_output_loss: 0.4706 - val_a_output_loss: 1.2410 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 3.3625 - val_q_output_loss: 1.9613 - val_q_output_sparse_categorical_accuracy: 0.4409 - val_type_output_accuracy: 0.7167 - val_type_output_loss: 0.5339\n", + "Epoch 29/30\n", + "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - a_output_loss: 1.6156 - a_output_sparse_categorical_accuracy: 0.3713 - loss: 3.6313 - q_output_loss: 1.8821 - q_output_sparse_categorical_accuracy: 0.4360 - type_output_accuracy: 0.7989 - type_output_loss: 0.4600 - val_a_output_loss: 1.2161 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 3.2887 - val_q_output_loss: 1.9327 - val_q_output_sparse_categorical_accuracy: 0.4424 - val_type_output_accuracy: 0.7833 - val_type_output_loss: 0.4663\n", + "Epoch 30/30\n", + "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - a_output_loss: 1.5910 - a_output_sparse_categorical_accuracy: 0.3707 - loss: 3.5723 - q_output_loss: 1.8495 - q_output_sparse_categorical_accuracy: 0.4395 - type_output_accuracy: 0.8028 - type_output_loss: 0.4485 - val_a_output_loss: 1.2147 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 3.2756 - val_q_output_loss: 1.9057 - val_q_output_sparse_categorical_accuracy: 0.4455 - val_type_output_accuracy: 0.7000 - val_type_output_loss: 0.5174\n" ] } ], @@ -529,7 +533,7 @@ }, { "cell_type": "code", - "execution_count": 37, + "execution_count": 106, "id": "3355c0c7", "metadata": {}, "outputs": [], @@ -589,8 +593,8 @@ "dec_q_inp = Input(shape=(1,), name=\"dec_q_in\")\n", "dec_emb_q = model.get_layer(\"embedding_q_decoder\").call(dec_q_inp)\n", "\n", - "state_h_dec = Input(shape=(256,), name=\"state_h_dec\")\n", - "state_c_dec = Input(shape=(256,), name=\"state_c_dec\")\n", + "state_h_dec = Input(shape=(units,), name=\"state_h_dec\")\n", + "state_c_dec = Input(shape=(units,), name=\"state_c_dec\")\n", "\n", "lstm_decoder_q = model.get_layer(\"lstm_q_decoder\")\n", "\n", @@ -612,8 +616,8 @@ "dec_a_inp = Input(shape=(1,), name=\"dec_a_in\")\n", "dec_emb_a = model.get_layer(\"embedding_a_decoder\").call(dec_a_inp)\n", "\n", - "state_h_a = Input(shape=(256,), name=\"state_h_a\")\n", - "state_c_a = Input(shape=(256,), name=\"state_c_a\")\n", + "state_h_a = Input(shape=(units,), name=\"state_h_a\")\n", + "state_c_a = Input(shape=(units,), name=\"state_c_a\")\n", "\n", "lstm_decoder_a = model.get_layer(\"lstm_a_decoder\")\n", "\n", @@ -642,7 +646,7 @@ }, { "cell_type": "code", - "execution_count": 38, + "execution_count": 107, "id": "d406e6ff", "metadata": {}, "outputs": [ @@ -650,7 +654,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Generated Question: gilang maulana lahir di semarang pada 26 november 1983 ___\n", + "Generated Question: dimana dimana lahir ___\n", "Generated Answer : true\n", "Question Type : true_false\n" ] @@ -769,7 +773,7 @@ }, { "cell_type": "code", - "execution_count": 39, + "execution_count": 108, "id": "5adde3c3", "metadata": {}, "outputs": [ @@ -777,8 +781,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "BLEU : 0.1185\n", - "ROUGE1: 0.3967 | ROUGE-L: 0.3967\n" + "BLEU : 0.0447\n", + "ROUGE1: 0.2281 | ROUGE-L: 0.2281\n" ] } ], diff --git a/question_generation/uji.py b/question_generation/uji.py new file mode 100644 index 0000000..b40936f --- /dev/null +++ b/question_generation/uji.py @@ -0,0 +1,389 @@ +# =============================================================== +# Seq2Seq‑LSTM + Luong Attention untuk Question‑Answer Generator +# + Greedy & Beam Search decoding + BLEU‑4 evaluation +# =============================================================== +# • Semua embedding mask_zero=True (padding di‑mask) +# • Encoder = Bidirectional LSTM (return_sequences=True) +# • Decoder = LSTM + Luong Attention (keras.layers.Attention). +# • Greedy & beam‑search inference sub‑model dibangun terpisah (encoder, +# decoder‑Q‑step, decoder‑A‑step). +# • BLEU score (nltk.corpus_bleu) untuk evaluasi pertanyaan & jawaban. +# --------------------------------------------------------------- +# PETUNJUK +# 1. pip install nltk +# 2. python seq2seq_qa_attention.py # train + simpan model +# 3. jalankan fungsi evaluate_bleu() # hitung BLEU di validation/test +# =============================================================== + +import json +from pathlib import Path +from itertools import chain +import numpy as np +import tensorflow as tf +from tensorflow.keras.preprocessing.sequence import pad_sequences +from tensorflow.keras.layers import ( + Input, Embedding, LSTM, Bidirectional, Concatenate, + Dense, TimeDistributed, Attention +) +from tensorflow.keras.models import Model +from nltk.translate.bleu_score import corpus_bleu # pip install nltk + +# ----------------------- 1. Load & flatten data ---------------------------- +RAW = json.loads(Path("../dataset/dev_dataset_test.json").read_text()) + +samples = [] +for item in RAW: + for qp in item["quiz_posibility"]: + samp = { + "tokens": [t.lower() for t in item["tokens"]], + "ner": item["ner"], + "srl": item["srl"], + "q_type": qp["type"], + "q_toks": [t.lower() for t in qp["question"]] + [""], + } + if isinstance(qp["answer"], list): + samp["a_toks"] = [t.lower() for t in qp["answer"]] + [" "] + else: + samp["a_toks"] = [qp["answer"].lower(), " "] + samples.append(samp) + +print("Total flattened samples:", len(samples)) + +# ----------------------- 2. Build vocabularies ----------------------------- + +def build_vocab(seq_iter, reserved=(" ", " ", " ", " ")): + vocab = {tok: idx for idx, tok in enumerate(reserved)} + for tok in chain.from_iterable(seq_iter): + if tok not in vocab: + vocab[tok] = len(vocab) + return vocab + +v_tok = build_vocab((s["tokens"] for s in samples)) +v_ner = build_vocab((s["ner"] for s in samples), reserved=(" ", " ")) +v_srl = build_vocab((s["srl"] for s in samples), reserved=(" ", " ")) +v_q = build_vocab((s["q_toks"] for s in samples)) +v_a = build_vocab((s["a_toks"] for s in samples)) +v_typ = {"isian": 0, "opsi": 1, "true_false": 2} + +iv_q = {i: t for t, i in v_q.items()} +iv_a = {i: t for t, i in v_a.items()} + +# ----------------------- 3. Vectorise + pad ------------------------------- + +def encode(seq, vmap): + return [vmap.get(tok, vmap[" "]) for tok in seq] + +MAX_SENT = max(len(s["tokens"]) for s in samples) +MAX_Q = max(len(s["q_toks"]) for s in samples) +MAX_A = max(len(s["a_toks"]) for s in samples) + +X_tok_ids = pad_sequences([encode(s["tokens"], v_tok) for s in samples], + maxlen=MAX_SENT, padding="post") +X_ner_ids = pad_sequences([encode(s["ner"], v_ner) for s in samples], + maxlen=MAX_SENT, padding="post") +X_srl_ids = pad_sequences([encode(s["srl"], v_srl) for s in samples], + maxlen=MAX_SENT, padding="post") + +q_in_ids = pad_sequences([[v_q[" "], *encode(s["q_toks"][:-1], v_q)] + for s in samples], maxlen=MAX_Q, padding="post") +q_out_ids = pad_sequences([encode(s["q_toks"], v_q) for s in samples], + maxlen=MAX_Q, padding="post") + +a_in_ids = pad_sequences([[v_a[" "], *encode(s["a_toks"][:-1], v_a)] + for s in samples], maxlen=MAX_A, padding="post") +a_out_ids = pad_sequences([encode(s["a_toks"], v_a) for s in samples], + maxlen=MAX_A, padding="post") + +y_type_ids = np.array([v_typ[s["q_type"]] for s in samples]) + +# ----------------------- 4. Hyper‑params ---------------------------------- +d_tok = 32 # token embedding dim +d_tag = 16 # NER / SRL embedding dim +units = 64 # per direction of BiLSTM +lat_dim = units * 2 + +# ----------------------- 5. Build model ----------------------------------- +# Encoder ---------------------------------------------------------- + +tok_in = Input((MAX_SENT,), dtype="int32", name="tok_in") +ner_in = Input((MAX_SENT,), dtype="int32", name="ner_in") +srl_in = Input((MAX_SENT,), dtype="int32", name="srl_in") + +emb_tok = Embedding(len(v_tok), d_tok, mask_zero=True, name="emb_tok")(tok_in) +emb_ner = Embedding(len(v_ner), d_tag, mask_zero=True, name="emb_ner")(ner_in) +emb_srl = Embedding(len(v_srl), d_tag, mask_zero=True, name="emb_srl")(srl_in) + +enc_concat = Concatenate(name="enc_concat")([emb_tok, emb_ner, emb_srl]) +bi_lstm = Bidirectional(LSTM(units, return_sequences=True, return_state=True), + name="encoder_bi_lstm") +enc_seq, f_h, f_c, b_h, b_c = bi_lstm(enc_concat) +enc_h = Concatenate()( [f_h, b_h] ) # (B, lat_dim) +enc_c = Concatenate()( [f_c, b_c] ) + +# Decoder – QUESTION ---------------------------------------------- +q_in = Input((MAX_Q,), dtype="int32", name="q_in") +# 💡 mask_zero=False supaya Attention tidak bentrok dengan mask encoder +q_emb = Embedding(len(v_q), d_tok, mask_zero=False, name="q_emb")(q_in) + +dec_q_lstm = LSTM(lat_dim, return_sequences=True, return_state=True, + name="decoder_q_lstm") +q_seq, q_h, q_c = dec_q_lstm(q_emb, initial_state=[enc_h, enc_c]) + +enc_proj_q = TimeDistributed(Dense(lat_dim), name="enc_proj_q")(enc_seq) +attn_q = Attention(name="attn_q")([q_seq, enc_proj_q]) +q_concat = Concatenate(name="q_concat")([q_seq, attn_q]) +q_out = TimeDistributed(Dense(len(v_q), activation="softmax"), name="q_out")(q_concat) + +# Decoder – ANSWER ------------------------------------------------- +a_in = Input((MAX_A,), dtype="int32", name="a_in") +# juga mask_zero=False +a_emb = Embedding(len(v_a), d_tok, mask_zero=False, name="a_emb")(a_in) + +dec_a_lstm = LSTM(lat_dim, return_sequences=True, return_state=True, + name="decoder_a_lstm") +a_seq, _, _ = dec_a_lstm(a_emb, initial_state=[q_h, q_c]) + +enc_proj_a = TimeDistributed(Dense(lat_dim), name="enc_proj_a")(enc_seq) +attn_a = Attention(name="attn_a")([a_seq, enc_proj_a]) +a_concat = Concatenate(name="a_concat")([a_seq, attn_a]) +a_out = TimeDistributed(Dense(len(v_a), activation="softmax"), name="a_out")(a_concat) + +# Classifier ------------------------------------------------------- +type_dense = Dense(len(v_typ), activation="softmax", name="type_out")(enc_h) + +model = Model(inputs=[tok_in, ner_in, srl_in, q_in, a_in], + outputs=[q_out, a_out, type_dense]) +model.summary() + +# ----------------------- 6. Compile & train ------------------------------ +losses = { + "q_out": "sparse_categorical_crossentropy", + "a_out": "sparse_categorical_crossentropy", + "type_out": "sparse_categorical_crossentropy", +} +loss_weights = {"q_out": 1.0, "a_out": 1.0, "type_out": 0.3} + +model.compile(optimizer="adam", loss=losses, loss_weights=loss_weights, + metrics={"q_out": "sparse_categorical_accuracy", + "a_out": "sparse_categorical_accuracy", + "type_out": "accuracy"}) + +history = model.fit( + [X_tok_ids, X_ner_ids, X_srl_ids, q_in_ids, a_in_ids], + [q_out_ids, a_out_ids, y_type_ids], + validation_split=0.1, + epochs=30, + batch_size=64, + callbacks=[tf.keras.callbacks.EarlyStopping(patience=4, restore_best_weights=True)], + verbose=1, +) + +model.save("seq2seq_attn.keras") +print("Model saved to seq2seq_attn.keras") + +# ----------------------- 7. Inference sub‑models -------------------------- +# Encoder model +encoder_model = Model([tok_in, ner_in, srl_in], [enc_seq, enc_h, enc_c]) + +# Question decoder step model ------------------------------------------------ +# Inputs +q_token_in = Input((1,), dtype="int32", name="q_token_in") +enc_seq_in = Input((MAX_SENT, lat_dim), name="enc_seq_in") +enc_proj_q_in = Input((MAX_SENT, lat_dim), name="enc_proj_q_in") +state_h_in = Input((lat_dim,), name="state_h_in") +state_c_in = Input((lat_dim,), name="state_c_in") + +# Embedding +q_emb_step = model.get_layer("q_emb")(q_token_in) + +# LSTM (reuse weights) +q_lstm_step, h_out, c_out = model.get_layer("decoder_q_lstm")(q_emb_step, + initial_state=[state_h_in, state_c_in]) +# Attention +attn_step = model.get_layer("attn_q")([q_lstm_step, enc_proj_q_in]) +q_concat_step = Concatenate()([q_lstm_step, attn_step]) +q_logits_step = model.get_layer("q_out")(q_concat_step) + +decoder_q_step = Model([q_token_in, enc_seq_in, enc_proj_q_in, state_h_in, state_c_in], + [q_logits_step, h_out, c_out]) + +# Answer decoder step model -------------------------------------------------- +a_token_in = Input((1,), dtype="int32", name="a_token_in") +enc_proj_a_in = Input((MAX_SENT, lat_dim), name="enc_proj_a_in") +state_h_a_in = Input((lat_dim,), name="state_h_a_in") +state_c_a_in = Input((lat_dim,), name="state_c_a_in") + +# Embedding reuse +a_emb_step = model.get_layer("a_emb")(a_token_in) + +# LSTM reuse +a_lstm_step, h_a_out, c_a_out = model.get_layer("decoder_a_lstm")(a_emb_step, + initial_state=[state_h_a_in, state_c_a_in]) +# Attention reuse +attn_a_step = model.get_layer("attn_a")([a_lstm_step, enc_proj_a_in]) +a_concat_step = Concatenate()([a_lstm_step, attn_a_step]) +a_logits_step = model.get_layer("a_out")(a_concat_step) + +decoder_a_step = Model([a_token_in, enc_proj_a_in, state_h_a_in, state_c_a_in], + [a_logits_step, h_a_out, c_a_out]) + +# ----------------------- 8. Decoding helpers ------------------------------ + +def encode_and_pad(seq, vmap, max_len): + ids = encode(seq, vmap) + return ids + [vmap[" "]] * (max_len - len(ids)) + + +def greedy_decode(tokens, ner, srl, max_q=20, max_a=10): + """Return generated (question_tokens, answer_tokens, q_type_str)""" + # --- encoder --------------------------------------------------------- + enc_tok = np.array([encode_and_pad(tokens, v_tok, MAX_SENT)]) + enc_ner = np.array([encode_and_pad(ner, v_ner, MAX_SENT)]) + enc_srl = np.array([encode_and_pad(srl, v_srl, MAX_SENT)]) + + enc_seq_val, h, c = encoder_model.predict([enc_tok, enc_ner, enc_srl], verbose=0) + enc_proj_q_val = model.get_layer("enc_proj_q")(enc_seq_val) + enc_proj_a_val = model.get_layer("enc_proj_a")(enc_seq_val) + + # --- greedy Question -------------------------------------------------- + q_ids = [] + tgt = np.array([[v_q[" "]]]) + for _ in range(max_q): + logits, h, c = decoder_q_step.predict([tgt, enc_seq_val, enc_proj_q_val, h, c], verbose=0) + next_id = int(logits[0, 0].argmax()) + if next_id == v_q[" "]: + break + q_ids.append(next_id) + tgt = np.array([[next_id]]) + + # --- reset state for Answer ------------------------------------------- + # Use last q_h, q_c (already in h,c) + a_ids = [] + tgt_a = np.array([[v_a[" "]]]) + for _ in range(max_a): + logits_a, h, c = decoder_a_step.predict([tgt_a, enc_proj_a_val, h, c], verbose=0) + next_a = int(logits_a[0, 0].argmax()) + if next_a == v_a[" "]: + break + a_ids.append(next_a) + tgt_a = np.array([[next_a]]) + + # Question type + typ_logits = model.predict([enc_tok, enc_ner, enc_srl, np.zeros((1, MAX_Q)), np.zeros((1, MAX_A))], verbose=0)[2] + typ_id = int(typ_logits.argmax()) + q_type = [k for k, v in v_typ.items() if v == typ_id][0] + + question = [iv_q.get(i, " ") for i in q_ids] + answer = [iv_a.get(i, " ") for i in a_ids] + return question, answer, q_type + + +def beam_decode(tokens, ner, srl, beam_width=5, max_q=20, max_a=10): + """Beam search decoding. Returns best (question_tokens, answer_tokens, q_type)""" + enc_tok = np.array([encode_and_pad(tokens, v_tok, MAX_SENT)]) + enc_ner = np.array([encode_and_pad(ner, v_ner, MAX_SENT)]) + enc_srl = np.array([encode_and_pad(srl, v_srl, MAX_SENT)]) + enc_seq_val, h0, c0 = encoder_model.predict([enc_tok, enc_ner, enc_srl], verbose=0) + enc_proj_q_val = model.get_layer("enc_proj_q")(enc_seq_val) + enc_proj_a_val = model.get_layer("enc_proj_a")(enc_seq_val) + + # ----- Beam for Question ---------------------------------------------- + Beam = [( [v_q[" "]], 0.0, h0, c0 )] # (sequence, logP, h, c) + completed_q = [] + for _ in range(max_q): + new_beam = [] + for seq, logp, h, c in Beam: + tgt = np.array([[seq[-1]]]) + logits, next_h, next_c = decoder_q_step.predict([tgt, enc_seq_val, enc_proj_q_val, h, c], verbose=0) + log_probs = np.log(logits[0, 0] + 1e-8) + top_ids = np.argsort(log_probs)[-beam_width:] + for nid in top_ids: + new_seq = seq + [int(nid)] + new_logp = logp + log_probs[nid] + new_beam.append( (new_seq, new_logp, next_h, next_c) ) + # keep best beam_width + Beam = sorted(new_beam, key=lambda x: x[1], reverse=True)[:beam_width] + # move completed + Beam, done = [], Beam # placeholder copy to modify + for seq, logp, h, c in done: + if seq[-1] == v_q[" "] or len(seq) >= max_q: + completed_q.append( (seq, logp, h, c) ) + else: + Beam.append( (seq, logp, h, c) ) + if not Beam: + break + if completed_q: + best_q = max(completed_q, key=lambda x: x[1]) + else: + best_q = max(Beam, key=lambda x: x[1]) + + q_seq_ids, _, h_q, c_q = best_q + q_ids = [i for i in q_seq_ids[1:] if i != v_q[" "]] + + # ----- Beam for Answer -------------------------------------------------- + Beam = [( [v_a[" "]], 0.0, h_q, c_q )] + completed_a = [] + for _ in range(max_a): + new_beam = [] + for seq, logp, h, c in Beam: + tgt = np.array([[seq[-1]]]) + logits, next_h, next_c = decoder_a_step.predict([tgt, enc_proj_a_val, h, c], verbose=0) + log_probs = np.log(logits[0, 0] + 1e-8) + top_ids = np.argsort(log_probs)[-beam_width:] + for nid in top_ids: + new_seq = seq + [int(nid)] + new_logp = logp + log_probs[nid] + new_beam.append( (new_seq, new_logp, next_h, next_c) ) + Beam = sorted(new_beam, key=lambda x: x[1], reverse=True)[:beam_width] + Beam, done = [], Beam + for seq, logp, h, c in done: + if seq[-1] == v_a[" "] or len(seq) >= max_a: + completed_a.append( (seq, logp) ) + else: + Beam.append( (seq, logp, h, c) ) + if not Beam: + break + if completed_a: + best_a_seq, _ = max(completed_a, key=lambda x: x[1]) + else: + best_a_seq, _ = max(Beam, key=lambda x: x[1]) + a_ids = [i for i in best_a_seq[1:] if i != v_a[" "]] + + # Question type classification + typ_logits = model.predict([enc_tok, enc_ner, enc_srl, np.zeros((1, MAX_Q)), np.zeros((1, MAX_A))], verbose=0)[2] + typ_id = int(typ_logits.argmax()) + q_type = [k for k, v in v_typ.items() if v == typ_id][0] + + question = [iv_q.get(i, " ") for i in q_ids] + answer = [iv_a.get(i, " ") for i in a_ids] + + return question, answer, q_type + +# ----------------------- 9. BLEU evaluation ------------------------------- + +def evaluate_bleu(split_ratio=0.1, beam=False): + """Compute corpus BLEU‑4 on hold‑out split.""" + n_total = len(samples) + n_val = int(n_total * split_ratio) + idxs = np.random.choice(n_total, n_val, replace=False) + + refs_q, hyps_q = [], [] + refs_a, hyps_a = [], [] + + for i in idxs: + s = samples[i] + question_pred, answer_pred, _ = (beam_decode if beam else greedy_decode)( + s["tokens"], s["ner"], s["srl"], + ) + refs_q.append([s["q_toks"][:-1]]) # exclude + hyps_q.append(question_pred) + refs_a.append([s["a_toks"][:-1]]) + hyps_a.append(answer_pred) + + bleu_q = corpus_bleu(refs_q, hyps_q) + bleu_a = corpus_bleu(refs_a, hyps_a) + print(f"BLEU‑4 Question: {bleu_q:.3f}\nBLEU‑4 Answer : {bleu_a:.3f}") + +# Example usage (uncomment): +evaluate_bleu(beam=False) +evaluate_bleu(beam=True)