feat: create new LSTM model to replace the old one and add datasets

This commit is contained in:
akhdanre 2025-05-12 02:58:18 +07:00
parent a0f68a3c1b
commit 3cf689159c
37 changed files with 34304 additions and 672 deletions

Binary file not shown (image changed: 66 KiB before, 61 KiB after).

File diff suppressed because one or more lines are too long

Binary file not shown (image changed: 66 KiB before, 63 KiB after).

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

dataset/dev_dataset_qg.json (new file, 14591 lines added)

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@ -7997,7 +7997,7 @@ negara O ARG2
Indonesia B-LOC ARG2 Indonesia B-LOC ARG2
. O O . O O
Pameran O EVENT ARG0 Pameran O ARG0
seni O ARG0 seni O ARG0
budaya O ARG0 budaya O ARG0
digelar O V digelar O V
@ -8037,7 +8037,7 @@ di O O
Bali B-LOC ARGM-LOC Bali B-LOC ARGM-LOC
. O O . O O
Pertunjukan O EVENT ARG0 Pertunjukan B-EVENT ARG0
wayang I-EVENT ARG0 wayang I-EVENT ARG0
kulit I-EVENT ARG0 kulit I-EVENT ARG0
menjadi O V menjadi O V

Can't render this file because it has a wrong number of fields in line 8040.
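
The hunk above edits one of the token / NER / SRL column files, and the viewer flags a wrong number of fields at line 8040. A minimal sketch for locating such rows, assuming three tab-separated columns per token (adjust the separator and expected count if the file differs) and using the committed dataset/test_dts.tsv only as an illustrative path:

from collections import Counter

path = "dataset/test_dts.tsv"   # illustrative; any of the *_dts.tsv files
expected_fields = 3             # token, NER tag, SRL tag (assumption)

counts = Counter()
with open(path, encoding="utf-8") as f:
    for lineno, line in enumerate(f, start=1):
        if not line.strip():
            continue                      # blank lines separate sentences
        n = len(line.rstrip("\n").split("\t"))
        counts[n] += 1
        if n != expected_fields:
            print(f"line {lineno}: {n} fields -> {line.rstrip()}")

print("field-count distribution:", dict(counts))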

dataset/test_dataset_qg.json (new file, 12002 lines added)

File diff suppressed because it is too large

dataset/test_dts.tsv (new file, 1007 lines added)

File diff suppressed because it is too large

old/QC/model_tr.ipynb (new file, 329 lines added)
View File

@ -0,0 +1,329 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "94d3889b",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2025-05-10 14:49:40.993078: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n",
"2025-05-10 14:49:40.996369: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.\n",
"2025-05-10 14:49:41.002001: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.\n",
"2025-05-10 14:49:41.015917: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
"WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n",
"E0000 00:00:1746863381.035097 166971 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
"E0000 00:00:1746863381.038978 166971 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
"W0000 00:00:1746863381.049265 166971 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n",
"W0000 00:00:1746863381.049288 166971 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n",
"W0000 00:00:1746863381.049289 166971 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n",
"W0000 00:00:1746863381.049290 166971 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n",
"2025-05-10 14:49:41.052642: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
"To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
]
}
],
"source": [
"# -------------------------------------------------\n",
"# 0. Import & Konfigurasi\n",
"# -------------------------------------------------\n",
"import json, pickle\n",
"import numpy as np\n",
"from pathlib import Path\n",
"from collections import Counter\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"import tensorflow as tf\n",
"from tensorflow.keras.preprocessing.text import Tokenizer\n",
"from tensorflow.keras.preprocessing.sequence import pad_sequences\n",
"from tensorflow.keras.utils import to_categorical\n",
"from tensorflow.keras.layers import (\n",
" Input, Embedding, LSTM, Bidirectional, Dense, Concatenate,\n",
" TimeDistributed\n",
")\n",
"from tensorflow.keras.models import Model\n",
"from tensorflow.keras.callbacks import EarlyStopping\n",
"\n",
"PAD_TOKEN = \"<PAD>\"\n",
"UNK_TOKEN = \"UNK\"\n",
"START_TOKEN = \"<START>\"\n",
"END_TOKEN = \"<END>\"\n",
"MAXLEN_SRC = 100 # Panjang paragraf maksimal\n",
"MAXLEN_TGT = 40 # Panjang pertanyaan/jawaban maksimal\n",
"BATCH = 32\n",
"EPOCHS = 30"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "b528b34e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Valid 325 / 325 (invalid index: [])\n"
]
}
],
"source": [
"raw = json.loads(Path(\"normalize_dataset.json\").read_text(encoding=\"utf-8\"))\n",
"\n",
"req = {\"tokens\",\"ner\",\"srl\",\"question\",\"answer\",\"type\"}\n",
"valid, bad = [], []\n",
"for i,item in enumerate(raw):\n",
" if (isinstance(item,dict) and not (req-item.keys())\n",
" and all(isinstance(item[k],list) for k in req-{\"type\"})\n",
" and isinstance(item[\"type\"],str)):\n",
" valid.append(item)\n",
" else:\n",
" bad.append(i)\n",
"\n",
"print(f\"Valid {len(valid)} / {len(raw)} (invalid index: {bad[:10]})\")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "b18e4617",
"metadata": {},
"outputs": [],
"source": [
"for ex in valid:\n",
" ex[\"question_in\"] = [START_TOKEN] + ex[\"question\"]\n",
" ex[\"question_out\"] = ex[\"question\"] + [END_TOKEN]\n",
"\n",
" ex[\"answer_in\"] = [START_TOKEN] + ex[\"answer\"]\n",
" ex[\"answer_out\"] = ex[\"answer\"] + [END_TOKEN]"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "faa30b82",
"metadata": {},
"outputs": [],
"source": [
"tok_token = Tokenizer(oov_token=UNK_TOKEN, filters=\"\")\n",
"tok_ner = Tokenizer(lower=False, filters=\"\")\n",
"tok_srl = Tokenizer(lower=False, filters=\"\")\n",
"tok_q = Tokenizer(oov_token=UNK_TOKEN, filters=\"\")\n",
"tok_a = Tokenizer(oov_token=UNK_TOKEN, filters=\"\")\n",
"tok_type = Tokenizer(lower=False, filters=\"\")\n",
"\n",
"tok_token.fit_on_texts([ex[\"tokens\"] for ex in valid])\n",
"tok_ner.fit_on_texts([ex[\"ner\"] for ex in valid])\n",
"tok_srl.fit_on_texts([ex[\"srl\"] for ex in valid])\n",
"tok_q.fit_on_texts([ex[\"question_in\"]+ex[\"question_out\"] for ex in valid])\n",
"tok_a.fit_on_texts([ex[\"answer_in\"]+ex[\"answer_out\"] for ex in valid])\n",
"tok_type.fit_on_texts([ex[\"type\"] for ex in valid])\n",
"\n",
"# +1 utk padding\n",
"vocab_token = len(tok_token.word_index)+1\n",
"vocab_ner = len(tok_ner.word_index)+1\n",
"vocab_srl = len(tok_srl.word_index)+1\n",
"vocab_q = len(tok_q.word_index)+1\n",
"vocab_a = len(tok_a.word_index)+1\n",
"vocab_type = len(tok_type.word_index)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "c83ce734",
"metadata": {},
"outputs": [],
"source": [
"def seqs(field, tok, maxlen):\n",
" return pad_sequences(\n",
" tok.texts_to_sequences([ex[field] for ex in valid]),\n",
" maxlen=maxlen, padding=\"post\"\n",
" )\n",
"\n",
"X_tok = seqs(\"tokens\", tok_token, MAXLEN_SRC)\n",
"X_ner = seqs(\"ner\", tok_ner, MAXLEN_SRC)\n",
"X_srl = seqs(\"srl\", tok_srl, MAXLEN_SRC)\n",
"\n",
"Q_in = seqs(\"question_in\", tok_q, MAXLEN_TGT)\n",
"Q_out = seqs(\"question_out\", tok_q, MAXLEN_TGT)\n",
"A_in = seqs(\"answer_in\", tok_a, MAXLEN_TGT)\n",
"A_out = seqs(\"answer_out\", tok_a, MAXLEN_TGT)\n",
"\n",
"y_type = to_categorical(\n",
" np.array([seq[0]-1 for seq in tok_type.texts_to_sequences([ex[\"type\"] for ex in valid])]),\n",
" num_classes=vocab_type\n",
")\n",
"\n",
"# Expand dims → (batch, seq, 1) agar cocok dgn sparse_cce\n",
"Q_out = np.expand_dims(Q_out, -1)\n",
"A_out = np.expand_dims(A_out, -1)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "ad3fe7f2",
"metadata": {},
"outputs": [],
"source": [
"(X_tok_tr, X_tok_te,\n",
" X_ner_tr, X_ner_te,\n",
" X_srl_tr, X_srl_te,\n",
" Q_in_tr, Q_in_te,\n",
" Q_out_tr, Q_out_te,\n",
" A_in_tr, A_in_te,\n",
" A_out_tr, A_out_te,\n",
" y_type_tr,y_type_te) = train_test_split(\n",
" X_tok, X_ner, X_srl, Q_in, Q_out, A_in, A_out, y_type,\n",
" test_size=0.2, random_state=42\n",
" )\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "f20abfb5",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2025-05-10 14:49:43.127764: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)\n"
]
},
{
"ename": "ValueError",
"evalue": "too many values to unpack (expected 3)",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[7], line 10\u001b[0m\n\u001b[1;32m 7\u001b[0m emb_srl \u001b[38;5;241m=\u001b[39m Embedding(vocab_srl, \u001b[38;5;241m16\u001b[39m, mask_zero\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)(enc_srl)\n\u001b[1;32m 9\u001b[0m enc_cat \u001b[38;5;241m=\u001b[39m Concatenate()([emb_tok, emb_ner, emb_srl])\n\u001b[0;32m---> 10\u001b[0m enc_out, state_h, state_c \u001b[38;5;241m=\u001b[39m Bidirectional(\n\u001b[1;32m 11\u001b[0m LSTM(\u001b[38;5;241m256\u001b[39m, return_state\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, return_sequences\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[1;32m 12\u001b[0m )(enc_cat)\n\u001b[1;32m 14\u001b[0m \u001b[38;5;66;03m# ---------- Klasifikasi tipe ----------\u001b[39;00m\n\u001b[1;32m 15\u001b[0m type_out \u001b[38;5;241m=\u001b[39m Dense(vocab_type, activation\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msoftmax\u001b[39m\u001b[38;5;124m\"\u001b[39m, name\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtype_output\u001b[39m\u001b[38;5;124m\"\u001b[39m)(enc_out)\n",
"\u001b[0;31mValueError\u001b[0m: too many values to unpack (expected 3)"
]
}
],
"source": [
"enc_tok = Input(shape=(None,), name=\"enc_tok\")\n",
"enc_ner = Input(shape=(None,), name=\"enc_ner\")\n",
"enc_srl = Input(shape=(None,), name=\"enc_srl\")\n",
"\n",
"emb_tok = Embedding(vocab_token, 128, mask_zero=True)(enc_tok)\n",
"emb_ner = Embedding(vocab_ner, 16, mask_zero=True)(enc_ner)\n",
"emb_srl = Embedding(vocab_srl, 16, mask_zero=True)(enc_srl)\n",
"\n",
"enc_cat = Concatenate()([emb_tok, emb_ner, emb_srl])\n",
"enc_out, state_h, state_c = Bidirectional(\n",
" LSTM(256, return_state=True, return_sequences=False)\n",
")(enc_cat)\n",
"\n",
"# ---------- Klasifikasi tipe ----------\n",
"type_out = Dense(vocab_type, activation=\"softmax\", name=\"type_output\")(enc_out)\n",
"\n",
"# ---------- Decoder QUESTION ----------\n",
"dec_q_in = Input(shape=(None,), name=\"dec_q_in\")\n",
"dec_q_emb = Embedding(vocab_q, 128, mask_zero=True)(dec_q_in)\n",
"dec_q_lstm = LSTM(256, return_sequences=True)\n",
"dec_q_out = dec_q_lstm(dec_q_emb, initial_state=[state_h, state_c])\n",
"q_out = TimeDistributed(Dense(vocab_q, activation=\"softmax\"), name=\"question_output\")(dec_q_out)\n",
"\n",
"# ---------- Decoder ANSWER ----------\n",
"dec_a_in = Input(shape=(None,), name=\"dec_a_in\")\n",
"dec_a_emb = Embedding(vocab_a, 128, mask_zero=True)(dec_a_in)\n",
"dec_a_lstm = LSTM(256, return_sequences=True)\n",
"dec_a_out = dec_a_lstm(dec_a_emb, initial_state=[state_h, state_c])\n",
"a_out = TimeDistributed(Dense(vocab_a, activation=\"softmax\"), name=\"answer_output\")(dec_a_out)\n",
"\n",
"# ---------- Build & compile ----------\n",
"model = Model(\n",
" inputs=[enc_tok, enc_ner, enc_srl, dec_q_in, dec_a_in],\n",
" outputs=[q_out, a_out, type_out]\n",
")\n",
"\n",
"model.compile(\n",
" optimizer=\"adam\",\n",
" loss={\n",
" \"question_output\": \"sparse_categorical_crossentropy\",\n",
" \"answer_output\" : \"sparse_categorical_crossentropy\",\n",
" \"type_output\" : \"categorical_crossentropy\"\n",
" },\n",
" loss_weights={\n",
" \"question_output\": 1.0,\n",
" \"answer_output\" : 1.0,\n",
" \"type_output\" : 0.3\n",
" },\n",
" metrics={\n",
" \"question_output\": \"accuracy\",\n",
" \"answer_output\" : \"accuracy\",\n",
" \"type_output\" : \"accuracy\"\n",
" }\n",
")\n",
"\n",
"model.summary()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c348406e",
"metadata": {},
"outputs": [],
"source": [
"early = EarlyStopping(patience=3, restore_best_weights=True)\n",
"\n",
"model.fit(\n",
" [X_tok_tr, X_ner_tr, X_srl_tr, Q_in_tr, A_in_tr],\n",
" {\"question_output\": Q_out_tr,\n",
" \"answer_output\" : A_out_tr,\n",
" \"type_output\" : y_type_tr},\n",
" batch_size=BATCH,\n",
" epochs=EPOCHS,\n",
" validation_split=0.1,\n",
" callbacks=[early]\n",
")\n",
"\n",
"# -------------------------------------------------\n",
"# 8. Simpan model & tokenizer\n",
"# -------------------------------------------------\n",
"model.save(\"qg_multitask.keras\")\n",
"with open(\"tokenizers.pkl\", \"wb\") as f:\n",
" pickle.dump({\n",
" \"token\": tok_token,\n",
" \"ner\" : tok_ner,\n",
" \"srl\" : tok_srl,\n",
" \"q\" : tok_q,\n",
" \"a\" : tok_a,\n",
" \"type\" : tok_type\n",
" }, f)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "myenv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
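
Cell 7 of model_tr.ipynb fails with "too many values to unpack (expected 3)": a Bidirectional LSTM built with return_state=True returns five tensors (the merged output plus forward and backward hidden/cell states), not three. A minimal sketch of the usual fix, unpacking all five and merging the directional states before handing them to the decoders; the names (enc_cat, dec_q_lstm, dec_a_lstm) follow the notebook, while the Concatenate merge and the 512-unit decoders are assumptions, not necessarily the author's intent:

from tensorflow.keras.layers import LSTM, Bidirectional, Concatenate

# Bidirectional(LSTM(..., return_state=True)) yields:
# merged output, forward_h, forward_c, backward_h, backward_c
enc_out, fwd_h, fwd_c, bwd_h, bwd_c = Bidirectional(
    LSTM(256, return_state=True, return_sequences=False)
)(enc_cat)

state_h = Concatenate()([fwd_h, bwd_h])   # (batch, 512)
state_c = Concatenate()([fwd_c, bwd_c])   # (batch, 512)

# The decoder LSTMs must match the 512-dim initial state
# (or project state_h / state_c back down to 256 with a Dense layer).
dec_q_lstm = LSTM(512, return_sequences=True)
dec_a_lstm = LSTM(512, return_sequences=True)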

View File

@ -13716,10 +13716,10 @@
"type": "tof" "type": "tof"
}, },
{ {
"tokens": ["Indonesia", "terletak", "di", "Benua", "Afrika", "."], "tokens": ["Indonesia", "terletak", "di", "Benua", "asia", "."],
"ner": ["B-LOC", "O", "O", "O", "B-LOC", "O"], "ner": ["B-LOC", "O", "O", "O", "B-LOC", "O"],
"srl": ["ARG1", "V", "ARGM-LOC", "ARGM-LOC", "ARGM-LOC", "O"], "srl": ["ARG1", "V", "ARGM-LOC", "ARGM-LOC", "ARGM-LOC", "O"],
"question": ["Indonesia", "terletak", "di", "Benua", "Afrika", "."], "question": ["Indonesia", "terletak", "di", "Benua", "asia", "."],
"answer": ["false"], "answer": ["false"],
"type": "tof" "type": "tof"
} }
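
The updated "tof" record above carries the same tokens / ner / srl / question / answer / type fields the new notebook consumes. A minimal sketch of encoding one such record into padded encoder inputs, assuming the tokenizers (tok_token, tok_ner, tok_srl) and MAXLEN_SRC defined in old/QC/model_tr.ipynb are in scope:

from tensorflow.keras.preprocessing.sequence import pad_sequences

record = {
    "tokens": ["Indonesia", "terletak", "di", "Benua", "asia", "."],
    "ner":    ["B-LOC", "O", "O", "O", "B-LOC", "O"],
    "srl":    ["ARG1", "V", "ARGM-LOC", "ARGM-LOC", "ARGM-LOC", "O"],
}

def encode(seq, tok):
    # Keras Tokenizer accepts pre-tokenized lists, matching how the notebook fits it.
    return pad_sequences(tok.texts_to_sequences([seq]),
                         maxlen=MAXLEN_SRC, padding="post")

x_tok = encode(record["tokens"], tok_token)
x_ner = encode(record["ner"], tok_ner)
x_srl = encode(record["srl"], tok_srl)
# x_tok, x_ner, x_srl line up with the model's enc_tok / enc_ner / enc_srl inputs.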

old/QC/qg_dataset.json (new file, 1914 lines added)

File diff suppressed because it is too large

View File

@ -2,30 +2,10 @@
"cells": [ "cells": [
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 4, "execution_count": 51,
"id": "9bf2159a", "id": "9bf2159a",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [],
{
"name": "stderr",
"output_type": "stream",
"text": [
"2025-05-02 15:16:40.916818: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n",
"2025-05-02 15:16:40.923426: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.\n",
"2025-05-02 15:16:40.983217: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.\n",
"2025-05-02 15:16:41.024477: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
"WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n",
"E0000 00:00:1746173801.069646 9825 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
"E0000 00:00:1746173801.081087 9825 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
"W0000 00:00:1746173801.169376 9825 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n",
"W0000 00:00:1746173801.169393 9825 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n",
"W0000 00:00:1746173801.169395 9825 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n",
"W0000 00:00:1746173801.169396 9825 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n",
"2025-05-02 15:16:41.179508: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
"To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
]
}
],
"source": [ "source": [
"import json\n", "import json\n",
"import numpy as np\n", "import numpy as np\n",
@ -51,7 +31,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 5, "execution_count": 52,
"id": "50118278", "id": "50118278",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -60,15 +40,15 @@
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"\n", "\n",
" Jumlah data valid: 321 / 321\n", " Jumlah data valid: 70 / 70\n",
" Jumlah data tidak valid: 0\n", " Jumlah data tidak valid: 0\n",
"Counter({'ftb': 235, 'tof': 45, 'none': 41})\n" "Counter({'tof': 30, 'isian': 30, 'opsi': 10})\n"
] ]
} }
], ],
"source": [ "source": [
"# Load raw data\n", "# Load raw data\n",
"with open(\"normalize_dataset.json\", encoding=\"utf-8\") as f:\n", "with open(\"qg_dataset.json\", encoding=\"utf-8\") as f:\n",
" raw_data = json.load(f)\n", " raw_data = json.load(f)\n",
"\n", "\n",
"# Validasi lengkap\n", "# Validasi lengkap\n",
@ -123,7 +103,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 3, "execution_count": 53,
"id": "4e3a0088", "id": "4e3a0088",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -149,7 +129,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 4, "execution_count": 54,
"id": "555f9e22", "id": "555f9e22",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -157,7 +137,7 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"{'ftb', 'tof', 'none'}\n" "{'isian', 'tof', 'opsi'}\n"
] ]
} }
], ],
@ -184,7 +164,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 5, "execution_count": 55,
"id": "f530cfe7", "id": "f530cfe7",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -200,25 +180,18 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 6, "execution_count": 56,
"id": "255e2a9a", "id": "255e2a9a",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2025-04-29 19:13:22.481835: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)\n"
]
},
{ {
"data": { "data": {
"text/html": [ "text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Model: \"functional\"</span>\n", "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Model: \"functional_5\"</span>\n",
"</pre>\n" "</pre>\n"
], ],
"text/plain": [ "text/plain": [
"\u001b[1mModel: \"functional\"\u001b[0m\n" "\u001b[1mModel: \"functional_5\"\u001b[0m\n"
] ]
}, },
"metadata": {}, "metadata": {},
@ -239,30 +212,31 @@
"│ srl_input │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ - │\n", "│ srl_input │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ - │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">InputLayer</span>) │ │ │ │\n", "│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">InputLayer</span>) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ embedding │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">126,080</span> │ tok_input[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n", "│ embedding_15 │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">41,600</span> │ tok_input[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Embedding</span>) │ │ │ │\n", "│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Embedding</span>) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ embedding_1 │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">16</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">352</span> │ ner_input[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n", "│ embedding_16 │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">16</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">272</span> │ ner_input[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Embedding</span>) │ │ │ │\n", "│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Embedding</span>) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ embedding_2 │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">16</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">432</span> │ srl_input[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n", "│ embedding_17 │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">16</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">272</span> │ srl_input[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Embedding</span>) │ │ │ │\n", "│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Embedding</span>) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ concatenate │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">160</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ embedding[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>], │\n", "│ concatenate_5 │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">160</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ embedding_15[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">…</span> │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Concatenate</span>) │ │ │ embedding_1[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>… │\n", "│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Concatenate</span>) │ │ │ embedding_16[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">…</span> │\n",
"│ │ │ │ embedding_2[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n", "│ │ │ │ embedding_17[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">…</span> │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ lstm (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">LSTM</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">256</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">427,008</span> │ concatenate[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n", "│ lstm_5 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">LSTM</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">256</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">427,008</span> │ concatenate_5[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ get_item (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">GetItem</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">256</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ lstm[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n", "│ get_item_5 │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">256</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ lstm_5[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">GetItem</span>) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ question_output │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">473</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">121,561</span> │ lstm[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n", "│ question_output │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">272</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">69,904</span> │ lstm_5[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">TimeDistributed</span>) │ │ │ │\n", "│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">TimeDistributed</span>) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ answer_output │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">383</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">98,431</span> │ lstm[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n", "│ answer_output │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">60</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">15,420</span> │ lstm_5[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">TimeDistributed</span>) │ │ │ │\n", "│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">TimeDistributed</span>) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ type_output (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">3</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">771</span> │ get_item[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n", "│ type_output (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">3</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">771</span> │ get_item_5[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"└─────────────────────┴───────────────────┴────────────┴───────────────────┘\n", "└─────────────────────┴───────────────────┴────────────┴───────────────────┘\n",
"</pre>\n" "</pre>\n"
], ],
@ -279,30 +253,31 @@
"│ srl_input │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", "│ srl_input │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n",
"│ (\u001b[38;5;33mInputLayer\u001b[0m) │ │ │ │\n", "│ (\u001b[38;5;33mInputLayer\u001b[0m) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ embedding │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m126,080\u001b[0m │ tok_input[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", "│ embedding_15 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m41,600\u001b[0m │ tok_input[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"│ (\u001b[38;5;33mEmbedding\u001b[0m) │ │ │ │\n", "│ (\u001b[38;5;33mEmbedding\u001b[0m) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ embedding_1 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m16\u001b[0m) │ \u001b[38;5;34m352\u001b[0m │ ner_input[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", "│ embedding_16 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m16\u001b[0m) │ \u001b[38;5;34m272\u001b[0m │ ner_input[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"│ (\u001b[38;5;33mEmbedding\u001b[0m) │ │ │ │\n", "│ (\u001b[38;5;33mEmbedding\u001b[0m) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ embedding_2 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m16\u001b[0m) │ \u001b[38;5;34m432\u001b[0m │ srl_input[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", "│ embedding_17 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m16\u001b[0m) │ \u001b[38;5;34m272\u001b[0m │ srl_input[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"│ (\u001b[38;5;33mEmbedding\u001b[0m) │ │ │ │\n", "│ (\u001b[38;5;33mEmbedding\u001b[0m) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ concatenate │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m160\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ embedding[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", "│ concatenate_5 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m160\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ embedding_15[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m…\u001b[0m │\n",
"│ (\u001b[38;5;33mConcatenate\u001b[0m) │ │ │ embedding_1[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m… │\n", "│ (\u001b[38;5;33mConcatenate\u001b[0m) │ │ │ embedding_16[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m…\u001b[0m │\n",
"│ │ │ │ embedding_2[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", "│ │ │ │ embedding_17[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m…\u001b[0m │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ lstm (\u001b[38;5;33mLSTM\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256\u001b[0m) │ \u001b[38;5;34m427,008\u001b[0m │ concatenate[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", "│ lstm_5 (\u001b[38;5;33mLSTM\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256\u001b[0m) │ \u001b[38;5;34m427,008\u001b[0m │ concatenate_5[\u001b[38;5;34m0\u001b[0m]… │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ get_item (\u001b[38;5;33mGetItem\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ lstm[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", "│ get_item_5 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ lstm_5[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"│ (\u001b[38;5;33mGetItem\u001b[0m) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ question_output │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m473\u001b[0m) │ \u001b[38;5;34m121,561\u001b[0m │ lstm[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", "│ question_output │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m272\u001b[0m) │ \u001b[38;5;34m69,904\u001b[0m │ lstm_5[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"│ (\u001b[38;5;33mTimeDistributed\u001b[0m) │ │ │ │\n", "│ (\u001b[38;5;33mTimeDistributed\u001b[0m) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ answer_output │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m383\u001b[0m) │ \u001b[38;5;34m98,431\u001b[0m │ lstm[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", "│ answer_output │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m60\u001b[0m) │ \u001b[38;5;34m15,420\u001b[0m │ lstm_5[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"│ (\u001b[38;5;33mTimeDistributed\u001b[0m) │ │ │ │\n", "│ (\u001b[38;5;33mTimeDistributed\u001b[0m) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ type_output (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m3\u001b[0m) │ \u001b[38;5;34m771\u001b[0m │ get_item[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", "│ type_output (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m3\u001b[0m) │ \u001b[38;5;34m771\u001b[0m │ get_item_5[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"└─────────────────────┴───────────────────┴────────────┴───────────────────┘\n" "└─────────────────────┴───────────────────┴────────────┴───────────────────┘\n"
] ]
}, },
@ -312,11 +287,11 @@
{ {
"data": { "data": {
"text/html": [ "text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Total params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">774,635</span> (2.95 MB)\n", "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Total params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">555,247</span> (2.12 MB)\n",
"</pre>\n" "</pre>\n"
], ],
"text/plain": [ "text/plain": [
"\u001b[1m Total params: \u001b[0m\u001b[38;5;34m774,635\u001b[0m (2.95 MB)\n" "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m555,247\u001b[0m (2.12 MB)\n"
] ]
}, },
"metadata": {}, "metadata": {},
@ -325,11 +300,11 @@
{ {
"data": { "data": {
"text/html": [ "text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">774,635</span> (2.95 MB)\n", "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">555,247</span> (2.12 MB)\n",
"</pre>\n" "</pre>\n"
], ],
"text/plain": [ "text/plain": [
"\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m774,635\u001b[0m (2.95 MB)\n" "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m555,247\u001b[0m (2.12 MB)\n"
] ]
}, },
"metadata": {}, "metadata": {},
@ -353,23 +328,61 @@
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"Epoch 1/30\n", "Epoch 1/30\n",
"\u001b[1m7/7\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 101ms/step - answer_output_accuracy: 0.5626 - answer_output_loss: 5.7629 - loss: 12.9112 - question_output_accuracy: 0.3867 - question_output_loss: 6.0185 - type_output_accuracy: 0.5290 - type_output_loss: 1.0943 - val_answer_output_accuracy: 0.9261 - val_answer_output_loss: 3.9036 - val_loss: 9.5865 - val_question_output_accuracy: 0.7500 - val_question_output_loss: 4.5947 - val_type_output_accuracy: 0.5652 - val_type_output_loss: 1.0883\n", "\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 3s/step - answer_output_accuracy: 0.0030 - answer_output_loss: 4.1163 - loss: 10.8193 - question_output_accuracy: 0.0030 - question_output_loss: 5.6031 - type_output_accuracy: 0.2000 - type_output_loss: 1.0999 - val_answer_output_accuracy: 0.8833 - val_answer_output_loss: 4.0123 - val_loss: 10.6706 - val_question_output_accuracy: 0.6000 - val_question_output_loss: 5.5595 - val_type_output_accuracy: 0.1667 - val_type_output_loss: 1.0987\n",
"Epoch 2/30\n", "Epoch 2/30\n",
"\u001b[1m7/7\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 32ms/step - answer_output_accuracy: 0.8791 - answer_output_loss: 2.9526 - loss: 7.7800 - question_output_accuracy: 0.6837 - question_output_loss: 3.7162 - type_output_accuracy: 0.7148 - type_output_loss: 1.0672 - val_answer_output_accuracy: 0.9261 - val_answer_output_loss: 1.1139 - val_loss: 4.1230 - val_question_output_accuracy: 0.7500 - val_question_output_loss: 1.9489 - val_type_output_accuracy: 0.5652 - val_type_output_loss: 1.0601\n", "\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 95ms/step - answer_output_accuracy: 0.8800 - answer_output_loss: 4.0174 - loss: 10.6778 - question_output_accuracy: 0.5640 - question_output_loss: 5.5631 - type_output_accuracy: 0.4200 - type_output_loss: 1.0973 - val_answer_output_accuracy: 0.9250 - val_answer_output_loss: 3.8939 - val_loss: 10.4860 - val_question_output_accuracy: 0.6250 - val_question_output_loss: 5.4945 - val_type_output_accuracy: 0.3333 - val_type_output_loss: 1.0976\n",
"Epoch 3/30\n", "Epoch 3/30\n",
"\u001b[1m7/7\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 34ms/step - answer_output_accuracy: 0.8726 - answer_output_loss: 1.2047 - loss: 4.4213 - question_output_accuracy: 0.6797 - question_output_loss: 2.2016 - type_output_accuracy: 0.7251 - type_output_loss: 1.0092 - val_answer_output_accuracy: 0.9261 - val_answer_output_loss: 0.7679 - val_loss: 3.7423 - val_question_output_accuracy: 0.7500 - val_question_output_loss: 1.9604 - val_type_output_accuracy: 0.5652 - val_type_output_loss: 1.0140\n", "\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 91ms/step - answer_output_accuracy: 0.9370 - answer_output_loss: 3.9075 - loss: 10.5064 - question_output_accuracy: 0.5870 - question_output_loss: 5.5043 - type_output_accuracy: 0.6200 - type_output_loss: 1.0946 - val_answer_output_accuracy: 0.9250 - val_answer_output_loss: 3.7157 - val_loss: 10.1938 - val_question_output_accuracy: 0.6250 - val_question_output_loss: 5.3815 - val_type_output_accuracy: 0.3333 - val_type_output_loss: 1.0965\n",
"Epoch 4/30\n", "Epoch 4/30\n",
"\u001b[1m7/7\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 32ms/step - answer_output_accuracy: 0.8633 - answer_output_loss: 1.1478 - loss: 4.4374 - question_output_accuracy: 0.6639 - question_output_loss: 2.3671 - type_output_accuracy: 0.7490 - type_output_loss: 0.9088 - val_answer_output_accuracy: 0.9261 - val_answer_output_loss: 0.7059 - val_loss: 3.6255 - val_question_output_accuracy: 0.7500 - val_question_output_loss: 1.9356 - val_type_output_accuracy: 0.5652 - val_type_output_loss: 0.9840\n", "\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 90ms/step - answer_output_accuracy: 0.9380 - answer_output_loss: 3.7435 - loss: 10.2381 - question_output_accuracy: 0.5890 - question_output_loss: 5.4027 - type_output_accuracy: 0.6200 - type_output_loss: 1.0919 - val_answer_output_accuracy: 0.9250 - val_answer_output_loss: 3.4257 - val_loss: 9.7085 - val_question_output_accuracy: 0.6250 - val_question_output_loss: 5.1873 - val_type_output_accuracy: 0.3333 - val_type_output_loss: 1.0955\n",
"Epoch 5/30\n", "Epoch 5/30\n",
"\u001b[1m7/7\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 33ms/step - answer_output_accuracy: 0.8783 - answer_output_loss: 1.0187 - loss: 4.0230 - question_output_accuracy: 0.6760 - question_output_loss: 2.1959 - type_output_accuracy: 0.7563 - type_output_loss: 0.8131 - val_answer_output_accuracy: 0.9261 - val_answer_output_loss: 0.6848 - val_loss: 3.5743 - val_question_output_accuracy: 0.7500 - val_question_output_loss: 1.9039 - val_type_output_accuracy: 0.5652 - val_type_output_loss: 0.9857\n", "\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 90ms/step - answer_output_accuracy: 0.9380 - answer_output_loss: 3.4788 - loss: 9.7970 - question_output_accuracy: 0.5850 - question_output_loss: 5.2288 - type_output_accuracy: 0.6600 - type_output_loss: 1.0894 - val_answer_output_accuracy: 0.9250 - val_answer_output_loss: 2.9617 - val_loss: 8.9146 - val_question_output_accuracy: 0.6250 - val_question_output_loss: 4.8585 - val_type_output_accuracy: 0.3333 - val_type_output_loss: 1.0944\n",
"Epoch 6/30\n", "Epoch 6/30\n",
"\u001b[1m7/7\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 33ms/step - answer_output_accuracy: 0.8800 - answer_output_loss: 0.9845 - loss: 3.8171 - question_output_accuracy: 0.6878 - question_output_loss: 2.0357 - type_output_accuracy: 0.7328 - type_output_loss: 0.7942 - val_answer_output_accuracy: 0.9261 - val_answer_output_loss: 0.6742 - val_loss: 3.5592 - val_question_output_accuracy: 0.7500 - val_question_output_loss: 1.8777 - val_type_output_accuracy: 0.5652 - val_type_output_loss: 1.0074\n", "\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 95ms/step - answer_output_accuracy: 0.9380 - answer_output_loss: 3.0565 - loss: 9.0790 - question_output_accuracy: 0.5850 - question_output_loss: 4.9355 - type_output_accuracy: 0.6600 - type_output_loss: 1.0869 - val_answer_output_accuracy: 0.9250 - val_answer_output_loss: 2.3649 - val_loss: 7.8024 - val_question_output_accuracy: 0.6250 - val_question_output_loss: 4.3441 - val_type_output_accuracy: 0.3333 - val_type_output_loss: 1.0933\n",
"Epoch 7/30\n", "Epoch 7/30\n",
"\u001b[1m7/7\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 32ms/step - answer_output_accuracy: 0.8768 - answer_output_loss: 0.9756 - loss: 3.8569 - question_output_accuracy: 0.6743 - question_output_loss: 2.0795 - type_output_accuracy: 0.7030 - type_output_loss: 0.8039 - val_answer_output_accuracy: 0.9261 - val_answer_output_loss: 0.6769 - val_loss: 3.5671 - val_question_output_accuracy: 0.7500 - val_question_output_loss: 1.8631 - val_type_output_accuracy: 0.5652 - val_type_output_loss: 1.0272\n", "\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 95ms/step - answer_output_accuracy: 0.9380 - answer_output_loss: 2.5004 - loss: 8.0585 - question_output_accuracy: 0.5850 - question_output_loss: 4.4735 - type_output_accuracy: 0.6600 - type_output_loss: 1.0845 - val_answer_output_accuracy: 0.9250 - val_answer_output_loss: 1.8898 - val_loss: 6.6823 - val_question_output_accuracy: 0.6250 - val_question_output_loss: 3.7005 - val_type_output_accuracy: 0.3333 - val_type_output_loss: 1.0920\n",
"Epoch 8/30\n", "Epoch 8/30\n",
"\u001b[1m7/7\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 32ms/step - answer_output_accuracy: 0.8814 - answer_output_loss: 0.9217 - loss: 3.7726 - question_output_accuracy: 0.6798 - question_output_loss: 2.0253 - type_output_accuracy: 0.6785 - type_output_loss: 0.8194 - val_answer_output_accuracy: 0.9261 - val_answer_output_loss: 0.6900 - val_loss: 3.5722 - val_question_output_accuracy: 0.7500 - val_question_output_loss: 1.8469 - val_type_output_accuracy: 0.5652 - val_type_output_loss: 1.0354\n", "\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 88ms/step - answer_output_accuracy: 0.9380 - answer_output_loss: 2.0239 - loss: 6.9823 - question_output_accuracy: 0.5850 - question_output_loss: 3.8764 - type_output_accuracy: 0.6600 - type_output_loss: 1.0821 - val_answer_output_accuracy: 0.9250 - val_answer_output_loss: 1.5873 - val_loss: 5.7713 - val_question_output_accuracy: 0.6250 - val_question_output_loss: 3.0934 - val_type_output_accuracy: 0.3333 - val_type_output_loss: 1.0906\n",
"Epoch 9/30\n", "Epoch 9/30\n",
"\u001b[1m7/7\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 31ms/step - answer_output_accuracy: 0.8703 - answer_output_loss: 0.9799 - loss: 3.6985 - question_output_accuracy: 0.6843 - question_output_loss: 1.9755 - type_output_accuracy: 0.7160 - type_output_loss: 0.7474 - val_answer_output_accuracy: 0.9261 - val_answer_output_loss: 0.6958 - val_loss: 3.5849 - val_question_output_accuracy: 0.7500 - val_question_output_loss: 1.8401 - val_type_output_accuracy: 0.5652 - val_type_output_loss: 1.0490\n" "\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 93ms/step - answer_output_accuracy: 0.9380 - answer_output_loss: 1.6939 - loss: 6.0594 - question_output_accuracy: 0.5850 - question_output_loss: 3.2857 - type_output_accuracy: 0.6600 - type_output_loss: 1.0798 - val_answer_output_accuracy: 0.9250 - val_answer_output_loss: 1.3585 - val_loss: 5.0778 - val_question_output_accuracy: 0.6250 - val_question_output_loss: 2.6303 - val_type_output_accuracy: 0.3333 - val_type_output_loss: 1.0890\n",
"Epoch 10/30\n",
"\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 97ms/step - answer_output_accuracy: 0.9380 - answer_output_loss: 1.4268 - loss: 5.3244 - question_output_accuracy: 0.5850 - question_output_loss: 2.8203 - type_output_accuracy: 0.6600 - type_output_loss: 1.0774 - val_answer_output_accuracy: 0.9250 - val_answer_output_loss: 1.1559 - val_loss: 4.5630 - val_question_output_accuracy: 0.6250 - val_question_output_loss: 2.3200 - val_type_output_accuracy: 0.3333 - val_type_output_loss: 1.0871\n",
"Epoch 11/30\n",
"\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 93ms/step - answer_output_accuracy: 0.9380 - answer_output_loss: 1.1880 - loss: 4.7795 - question_output_accuracy: 0.5850 - question_output_loss: 2.5167 - type_output_accuracy: 0.6600 - type_output_loss: 1.0748 - val_answer_output_accuracy: 0.9250 - val_answer_output_loss: 0.9716 - val_loss: 4.2001 - val_question_output_accuracy: 0.6250 - val_question_output_loss: 2.1437 - val_type_output_accuracy: 0.3333 - val_type_output_loss: 1.0848\n",
"Epoch 12/30\n",
"\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 95ms/step - answer_output_accuracy: 0.9380 - answer_output_loss: 0.9857 - loss: 4.4356 - question_output_accuracy: 0.5850 - question_output_loss: 2.3778 - type_output_accuracy: 0.6600 - type_output_loss: 1.0721 - val_answer_output_accuracy: 0.9250 - val_answer_output_loss: 0.8171 - val_loss: 3.9799 - val_question_output_accuracy: 0.6250 - val_question_output_loss: 2.0807 - val_type_output_accuracy: 0.3333 - val_type_output_loss: 1.0822\n",
"Epoch 13/30\n",
"\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 92ms/step - answer_output_accuracy: 0.9380 - answer_output_loss: 0.8280 - loss: 4.2715 - question_output_accuracy: 0.5850 - question_output_loss: 2.3745 - type_output_accuracy: 0.6600 - type_output_loss: 1.0690 - val_answer_output_accuracy: 0.9250 - val_answer_output_loss: 0.6995 - val_loss: 3.8760 - val_question_output_accuracy: 0.6250 - val_question_output_loss: 2.0974 - val_type_output_accuracy: 0.3333 - val_type_output_loss: 1.0790\n",
"Epoch 14/30\n",
"\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 91ms/step - answer_output_accuracy: 0.9380 - answer_output_loss: 0.7110 - loss: 4.2264 - question_output_accuracy: 0.5850 - question_output_loss: 2.4498 - type_output_accuracy: 0.6400 - type_output_loss: 1.0656 - val_answer_output_accuracy: 0.9250 - val_answer_output_loss: 0.6143 - val_loss: 3.8415 - val_question_output_accuracy: 0.6250 - val_question_output_loss: 2.1518 - val_type_output_accuracy: 0.5000 - val_type_output_loss: 1.0754\n",
"Epoch 15/30\n",
"\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 95ms/step - answer_output_accuracy: 0.9380 - answer_output_loss: 0.6254 - loss: 4.2353 - question_output_accuracy: 0.5850 - question_output_loss: 2.5482 - type_output_accuracy: 0.6000 - type_output_loss: 1.0617 - val_answer_output_accuracy: 0.9250 - val_answer_output_loss: 0.5529 - val_loss: 3.8335 - val_question_output_accuracy: 0.6250 - val_question_output_loss: 2.2091 - val_type_output_accuracy: 0.5000 - val_type_output_loss: 1.0714\n",
"Epoch 16/30\n",
"\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 94ms/step - answer_output_accuracy: 0.9380 - answer_output_loss: 0.5623 - loss: 4.2530 - question_output_accuracy: 0.5850 - question_output_loss: 2.6334 - type_output_accuracy: 0.6000 - type_output_loss: 1.0573 - val_answer_output_accuracy: 0.9250 - val_answer_output_loss: 0.5083 - val_loss: 3.8255 - val_question_output_accuracy: 0.6250 - val_question_output_loss: 2.2502 - val_type_output_accuracy: 0.5000 - val_type_output_loss: 1.0670\n",
"Epoch 17/30\n",
"\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 96ms/step - answer_output_accuracy: 0.9380 - answer_output_loss: 0.5150 - loss: 4.2561 - question_output_accuracy: 0.5850 - question_output_loss: 2.6886 - type_output_accuracy: 0.6000 - type_output_loss: 1.0525 - val_answer_output_accuracy: 0.9250 - val_answer_output_loss: 0.4752 - val_loss: 3.8053 - val_question_output_accuracy: 0.6250 - val_question_output_loss: 2.2678 - val_type_output_accuracy: 0.5000 - val_type_output_loss: 1.0623\n",
"Epoch 18/30\n",
"\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 104ms/step - answer_output_accuracy: 0.9380 - answer_output_loss: 0.4790 - loss: 4.2357 - question_output_accuracy: 0.5850 - question_output_loss: 2.7094 - type_output_accuracy: 0.6000 - type_output_loss: 1.0473 - val_answer_output_accuracy: 0.9250 - val_answer_output_loss: 0.4503 - val_loss: 3.7689 - val_question_output_accuracy: 0.6250 - val_question_output_loss: 2.2612 - val_type_output_accuracy: 0.5000 - val_type_output_loss: 1.0573\n",
"Epoch 19/30\n",
"\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 94ms/step - answer_output_accuracy: 0.9380 - answer_output_loss: 0.4511 - loss: 4.1904 - question_output_accuracy: 0.5850 - question_output_loss: 2.6974 - type_output_accuracy: 0.5600 - type_output_loss: 1.0419 - val_answer_output_accuracy: 0.9250 - val_answer_output_loss: 0.4313 - val_loss: 3.7162 - val_question_output_accuracy: 0.6250 - val_question_output_loss: 2.2327 - val_type_output_accuracy: 0.5000 - val_type_output_loss: 1.0523\n",
"Epoch 20/30\n",
"\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 94ms/step - answer_output_accuracy: 0.9380 - answer_output_loss: 0.4293 - loss: 4.1218 - question_output_accuracy: 0.5850 - question_output_loss: 2.6564 - type_output_accuracy: 0.5600 - type_output_loss: 1.0361 - val_answer_output_accuracy: 0.9250 - val_answer_output_loss: 0.4164 - val_loss: 3.6494 - val_question_output_accuracy: 0.6250 - val_question_output_loss: 2.1859 - val_type_output_accuracy: 0.5000 - val_type_output_loss: 1.0471\n",
"Epoch 21/30\n",
"\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 93ms/step - answer_output_accuracy: 0.9380 - answer_output_loss: 0.4118 - loss: 4.0332 - question_output_accuracy: 0.5850 - question_output_loss: 2.5912 - type_output_accuracy: 0.5600 - type_output_loss: 1.0302 - val_answer_output_accuracy: 0.9250 - val_answer_output_loss: 0.4046 - val_loss: 3.5722 - val_question_output_accuracy: 0.6250 - val_question_output_loss: 2.1256 - val_type_output_accuracy: 0.5000 - val_type_output_loss: 1.0420\n",
"Epoch 22/30\n",
"\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 95ms/step - answer_output_accuracy: 0.9380 - answer_output_loss: 0.3975 - loss: 3.9297 - question_output_accuracy: 0.5850 - question_output_loss: 2.5080 - type_output_accuracy: 0.5600 - type_output_loss: 1.0242 - val_answer_output_accuracy: 0.9250 - val_answer_output_loss: 0.3951 - val_loss: 3.4909 - val_question_output_accuracy: 0.6250 - val_question_output_loss: 2.0587 - val_type_output_accuracy: 0.5000 - val_type_output_loss: 1.0370\n",
"Epoch 23/30\n",
"\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 94ms/step - answer_output_accuracy: 0.9380 - answer_output_loss: 0.3858 - loss: 3.8184 - question_output_accuracy: 0.5850 - question_output_loss: 2.4147 - type_output_accuracy: 0.5600 - type_output_loss: 1.0180 - val_answer_output_accuracy: 0.9250 - val_answer_output_loss: 0.3875 - val_loss: 3.4143 - val_question_output_accuracy: 0.6250 - val_question_output_loss: 1.9948 - val_type_output_accuracy: 0.5000 - val_type_output_loss: 1.0321\n",
"Epoch 24/30\n",
"\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 91ms/step - answer_output_accuracy: 0.9380 - answer_output_loss: 0.3759 - loss: 3.7097 - question_output_accuracy: 0.5850 - question_output_loss: 2.3222 - type_output_accuracy: 0.5600 - type_output_loss: 1.0116 - val_answer_output_accuracy: 0.9250 - val_answer_output_loss: 0.3812 - val_loss: 3.3557 - val_question_output_accuracy: 0.6250 - val_question_output_loss: 1.9473 - val_type_output_accuracy: 0.5000 - val_type_output_loss: 1.0273\n",
"Epoch 25/30\n",
"\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 95ms/step - answer_output_accuracy: 0.9380 - answer_output_loss: 0.3674 - loss: 3.6180 - question_output_accuracy: 0.5850 - question_output_loss: 2.2455 - type_output_accuracy: 0.5600 - type_output_loss: 1.0051 - val_answer_output_accuracy: 0.9250 - val_answer_output_loss: 0.3759 - val_loss: 3.3316 - val_question_output_accuracy: 0.6250 - val_question_output_loss: 1.9330 - val_type_output_accuracy: 0.5000 - val_type_output_loss: 1.0227\n",
"Epoch 26/30\n",
"\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 96ms/step - answer_output_accuracy: 0.9380 - answer_output_loss: 0.3600 - loss: 3.5615 - question_output_accuracy: 0.5850 - question_output_loss: 2.2030 - type_output_accuracy: 0.5400 - type_output_loss: 0.9985 - val_answer_output_accuracy: 0.9250 - val_answer_output_loss: 0.3715 - val_loss: 3.3519 - val_question_output_accuracy: 0.6250 - val_question_output_loss: 1.9622 - val_type_output_accuracy: 0.5000 - val_type_output_loss: 1.0183\n",
"Epoch 27/30\n",
"\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 90ms/step - answer_output_accuracy: 0.9380 - answer_output_loss: 0.3534 - loss: 3.5516 - question_output_accuracy: 0.5850 - question_output_loss: 2.2064 - type_output_accuracy: 0.5400 - type_output_loss: 0.9917 - val_answer_output_accuracy: 0.9250 - val_answer_output_loss: 0.3677 - val_loss: 3.4014 - val_question_output_accuracy: 0.6250 - val_question_output_loss: 2.0195 - val_type_output_accuracy: 0.5000 - val_type_output_loss: 1.0141\n",
"Epoch 28/30\n",
"\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 96ms/step - answer_output_accuracy: 0.9380 - answer_output_loss: 0.3474 - loss: 3.5737 - question_output_accuracy: 0.5850 - question_output_loss: 2.2414 - type_output_accuracy: 0.5400 - type_output_loss: 0.9848 - val_answer_output_accuracy: 0.9250 - val_answer_output_loss: 0.3645 - val_loss: 3.4429 - val_question_output_accuracy: 0.6250 - val_question_output_loss: 2.0682 - val_type_output_accuracy: 0.5000 - val_type_output_loss: 1.0102\n"
] ]
} }
], ],
@ -426,7 +439,7 @@
" \"answer_output\": np.expand_dims(y_a_train, -1),\n", " \"answer_output\": np.expand_dims(y_a_train, -1),\n",
" \"type_output\": y_type_train,\n", " \"type_output\": y_type_train,\n",
" },\n", " },\n",
" batch_size=32,\n", " batch_size=64,\n",
" epochs=30,\n", " epochs=30,\n",
" validation_split=0.1,\n", " validation_split=0.1,\n",
" callbacks=[EarlyStopping(patience=3, restore_best_weights=True)],\n", " callbacks=[EarlyStopping(patience=3, restore_best_weights=True)],\n",
@ -450,7 +463,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 57,
"id": "06fd86c7", "id": "06fd86c7",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -458,12 +471,12 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"\u001b[1m2/2\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 236ms/step\n", "\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 239ms/step\n",
"\n", "\n",
"=== Akurasi Detail ===\n", "=== Akurasi Detail ===\n",
"Question Accuracy (Token-level): 0.0000\n", "Question Accuracy (Token-level): 0.0000\n",
"Answer Accuracy (Token-level) : 0.0000\n", "Answer Accuracy (Token-level) : 0.0000\n",
"Type Accuracy (Class-level) : 0.68\n" "Type Accuracy (Class-level) : 0.29\n"
] ]
} }
], ],
@ -506,7 +519,36 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 8, "execution_count": 58,
"id": "b17b6470",
"metadata": {},
"outputs": [],
"source": [
"# import sacrebleu\n",
"# from sacrebleu.metrics import BLEU # optional kalau mau smoothing/effective_order\n",
"\n",
"# idx2tok = {v:k for k,v in word2idx.items()}\n",
"# PAD_ID = word2idx[\"PAD\"]\n",
"# SOS_ID = word2idx.get(\"SOS\", None)\n",
"# EOS_ID = word2idx.get(\"EOS\", None)\n",
"\n",
"# def seq2str(seq):\n",
"# \"\"\"Konversi list index -> kalimat string, sambil buang token spesial.\"\"\"\n",
"# toks = [idx2tok[i] for i in seq\n",
"# if i not in {PAD_ID, SOS_ID, EOS_ID}]\n",
"# return \" \".join(toks).strip().lower()\n",
"\n",
"# bleu_metric = BLEU(effective_order=True) # lebih stabil utk kalimat pendek\n",
"\n",
"# def bleu_corpus(pred_seqs, true_seqs):\n",
"# preds = [seq2str(p) for p in pred_seqs]\n",
"# refs = [[seq2str(t)] for t in true_seqs] # listoflist, satu ref/kalimat\n",
"# return bleu_metric.corpus_score(preds, refs).score\n"
]
},
{
"cell_type": "code",
"execution_count": 59,
"id": "d5ed106c", "id": "d5ed106c",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -519,7 +561,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 9, "execution_count": 60,
"id": "aa3860de", "id": "aa3860de",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],

View File

@ -62,45 +62,9 @@ if __name__ == "__main__":
# Example input # Example input
input_data = { input_data = {
"tokens": [ "tokens": ["Nama", "lengkap", "saya", "adalah", "Bayu", "Prabowo", "."],
"Ki", "ner": ["O", "O", "O", "O", "B-PER", "I-PER", "O"],
"Hajar", "srl": ["ARG1", "ARG1", "ARG2", "V", "ARG0", "ARG0", "O"],
"Dewantara",
"lahir",
"pada",
"2",
"Mei",
"1889",
"di",
"Yogyakarta",
".",
],
"ner": [
"B-PER",
"I-PER",
"I-PER",
"O",
"O",
"B-DATE",
"I-DATE",
"I-DATE",
"O",
"B-LOC",
"O",
],
"srl": [
"ARG0",
"ARG0",
"ARG0",
"V",
"O",
"ARGM-TMP",
"ARGM-TMP",
"ARGM-TMP",
"O",
"ARGM-LOC",
"O",
],
} }
# input_data = { # input_data = {

308
question_generation/qg.py Normal file
View File

@ -0,0 +1,308 @@
#!/usr/bin/env python3
# ===============================================================
# QuestionGeneration seqtoseq (tokens + NER + SRL → Q/A/type)
# revised version 20250511
# ===============================================================
import json, pickle, random
from pathlib import Path
from itertools import chain
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import (
Input, Embedding, LSTM, Concatenate,
Dense, TimeDistributed
)
from tensorflow.keras.models import Model
from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction
from rouge_score import rouge_scorer, scoring
# -----------------------------------------------------------------
# 0. LOAD & FLATTEN DATA
# -----------------------------------------------------------------
RAW = json.loads(Path("../dataset/dev_dataset_qg.json").read_text())
samples = []
for item in RAW:
for qp in item["quiz_posibility"]:
samples.append({
"tokens" : item["tokens"],
"ner" : item["ner"],
"srl" : item["srl"],
"q_type" : qp["type"], # isian / opsi / benar_salah
"q_toks" : qp["question"] + ["<eos>"],
"a_toks" : (qp["answer"] if isinstance(qp["answer"], list)
else [qp["answer"]]) + ["<eos>"]
})
print("flattened samples :", len(samples))
# -----------------------------------------------------------------
# 1. VOCABULARIES
# -----------------------------------------------------------------
def build_vocab(seq_iter, reserved=("<pad>", "<unk>", "<sos>", "<eos>")):
vocab = {tok: idx for idx, tok in enumerate(reserved)}
for tok in chain.from_iterable(seq_iter):
vocab.setdefault(tok, len(vocab))
return vocab
vocab_tok = build_vocab((s["tokens"] for s in samples))
vocab_ner = build_vocab((s["ner"] for s in samples), reserved=("<pad>","<unk>"))
vocab_srl = build_vocab((s["srl"] for s in samples), reserved=("<pad>","<unk>"))
vocab_q = build_vocab((s["q_toks"] for s in samples))
vocab_a = build_vocab((s["a_toks"] for s in samples))
vocab_typ = {"isian":0, "opsi":1, "benar_salah":2}
# -----------------------------------------------------------------
# 2. ENCODING & PADDING
# -----------------------------------------------------------------
def enc(seq, v): return [v.get(t, v["<unk>"]) for t in seq]
MAX_SENT = max(len(s["tokens"]) for s in samples)
MAX_Q = max(len(s["q_toks"]) for s in samples)
MAX_A = max(len(s["a_toks"]) for s in samples)
def pad_batch(seqs, vmap, maxlen):
return tf.keras.preprocessing.sequence.pad_sequences(
[enc(s, vmap) for s in seqs], maxlen=maxlen, padding="post"
)
X_tok = pad_batch((s["tokens"] for s in samples), vocab_tok, MAX_SENT)
X_ner = pad_batch((s["ner"] for s in samples), vocab_ner, MAX_SENT)
X_srl = pad_batch((s["srl"] for s in samples), vocab_srl, MAX_SENT)
dec_q_in = pad_batch(
([["<sos>"]+s["q_toks"][:-1] for s in samples]), vocab_q, MAX_Q)
dec_q_out = pad_batch((s["q_toks"] for s in samples), vocab_q, MAX_Q)
dec_a_in = pad_batch(
([["<sos>"]+s["a_toks"][:-1] for s in samples]), vocab_a, MAX_A)
dec_a_out = pad_batch((s["a_toks"] for s in samples), vocab_a, MAX_A)
y_type = np.array([vocab_typ[s["q_type"]] for s in samples])
# -----------------------------------------------------------------
# 3. MODEL
# -----------------------------------------------------------------
d_tok, d_tag, units = 128, 32, 256
pad_tok, pad_q, pad_a = vocab_tok["<pad>"], vocab_q["<pad>"], vocab_a["<pad>"]
# ---- Encoder ----------------------------------------------------
inp_tok = Input((MAX_SENT,), name="tok_in")
inp_ner = Input((MAX_SENT,), name="ner_in")
inp_srl = Input((MAX_SENT,), name="srl_in")
emb_tok = Embedding(len(vocab_tok), d_tok, mask_zero=True, name="emb_tok")(inp_tok)
emb_ner = Embedding(len(vocab_ner), d_tag, mask_zero=True, name="emb_ner")(inp_ner)
emb_srl = Embedding(len(vocab_srl), d_tag, mask_zero=True, name="emb_srl")(inp_srl)
enc_concat = Concatenate()([emb_tok, emb_ner, emb_srl])
enc_out, state_h, state_c = LSTM(units, return_state=True, name="enc_lstm")(enc_concat)
# ---- Decoder : Question ----------------------------------------
dec_q_inp = Input((MAX_Q,), name="dec_q_in")
dec_emb_q = Embedding(len(vocab_q), d_tok, mask_zero=True, name="emb_q")(dec_q_inp)
dec_q_seq, _, _ = LSTM(units, return_sequences=True, return_state=True,
name="lstm_q")(dec_emb_q, initial_state=[state_h, state_c])
q_out = TimeDistributed(Dense(len(vocab_q), activation="softmax"), name="q_out")(dec_q_seq)
# ---- Decoder : Answer ------------------------------------------
dec_a_inp = Input((MAX_A,), name="dec_a_in")
dec_emb_a = Embedding(len(vocab_a), d_tok, mask_zero=True, name="emb_a")(dec_a_inp)
dec_a_seq, _, _ = LSTM(units, return_sequences=True, return_state=True,
name="lstm_a")(dec_emb_a, initial_state=[state_h, state_c])
a_out = TimeDistributed(Dense(len(vocab_a), activation="softmax"), name="a_out")(dec_a_seq)
# ---- Classifier -------------------------------------------------
type_out = Dense(len(vocab_typ), activation="softmax", name="type_out")(enc_out)
model = Model(
[inp_tok, inp_ner, inp_srl, dec_q_inp, dec_a_inp],
[q_out, a_out, type_out]
)
# ---- Masked loss helpers ---------------------------------------
scce = tf.keras.losses.SparseCategoricalCrossentropy(reduction="none")
def masked_loss_factory(pad_id):
def loss(y_true, y_pred):
l = scce(y_true, y_pred)
mask = tf.cast(tf.not_equal(y_true, pad_id), tf.float32)
return tf.reduce_sum(l*mask) / tf.reduce_sum(mask)
return loss
model.compile(
optimizer="adam",
loss = {"q_out":masked_loss_factory(pad_q),
"a_out":masked_loss_factory(pad_a),
"type_out":"sparse_categorical_crossentropy"},
loss_weights={"q_out":1.0, "a_out":1.0, "type_out":0.3},
metrics={"q_out":"sparse_categorical_accuracy",
"a_out":"sparse_categorical_accuracy",
"type_out":tf.keras.metrics.SparseCategoricalAccuracy(name="type_acc")}
)
model.summary()
# -----------------------------------------------------------------
# 4. TRAIN
# -----------------------------------------------------------------
history = model.fit(
[X_tok, X_ner, X_srl, dec_q_in, dec_a_in],
[dec_q_out, dec_a_out, y_type],
validation_split=0.1,
epochs=30,
batch_size=64,
callbacks=[tf.keras.callbacks.EarlyStopping(patience=4, restore_best_weights=True)],
verbose=2
)
model.save("full_seq2seq.keras")
# -----------------------------------------------------------------
# 5. SAVE VOCABS (.pkl keeps python dict intact)
# -----------------------------------------------------------------
def save_vocab(v, name): pickle.dump(v, open(name,"wb"))
save_vocab(vocab_tok,"vocab_tok.pkl"); save_vocab(vocab_ner,"vocab_ner.pkl")
save_vocab(vocab_srl,"vocab_srl.pkl"); save_vocab(vocab_q, "vocab_q.pkl")
save_vocab(vocab_a, "vocab_a.pkl"); save_vocab(vocab_typ,"vocab_typ.pkl")
# -----------------------------------------------------------------
# 6. INFERENCE MODELS (encoder & decoders)
# -----------------------------------------------------------------
def build_inference_models(trained):
# encoder
t_in = Input((MAX_SENT,), name="t_in")
n_in = Input((MAX_SENT,), name="n_in")
s_in = Input((MAX_SENT,), name="s_in")
e_t = trained.get_layer("emb_tok")(t_in)
e_n = trained.get_layer("emb_ner")(n_in)
e_s = trained.get_layer("emb_srl")(s_in)
concat = Concatenate()([e_t,e_n,e_s])
_, h, c = trained.get_layer("enc_lstm")(concat)
enc_model = Model([t_in,n_in,s_in],[h,c])
# questiondecoder
dq_in = Input((1,), name="dq_tok")
dh = Input((units,), name="dh"); dc = Input((units,), name="dc")
dq_emb = trained.get_layer("emb_q")(dq_in)
dq_lstm, nh, nc = trained.get_layer("lstm_q")(dq_emb, initial_state=[dh,dc])
dq_out = trained.get_layer("q_out").layer(dq_lstm)
dec_q_model = Model([dq_in, dh, dc], [dq_out, nh, nc])
# answerdecoder
da_in = Input((1,), name="da_tok")
ah = Input((units,), name="ah"); ac = Input((units,), name="ac")
da_emb = trained.get_layer("emb_a")(da_in)
da_lstm, nh2, nc2 = trained.get_layer("lstm_a")(da_emb, initial_state=[ah,ac])
da_out = trained.get_layer("a_out").layer(da_lstm)
dec_a_model = Model([da_in, ah, ac], [da_out, nh2, nc2])
# type classifier
type_dense = trained.get_layer("type_out")
type_model = Model([t_in,n_in,s_in], type_dense(_)) # use _ = enc_lstm output
return enc_model, dec_q_model, dec_a_model, type_model
encoder_model, decoder_q, decoder_a, classifier_model = build_inference_models(model)
inv_q = {v:k for k,v in vocab_q.items()}
inv_a = {v:k for k,v in vocab_a.items()}
def enc_pad(seq, vmap, maxlen):
x = [vmap.get(t, vmap["<unk>"]) for t in seq]
return x + [vmap["<pad>"]] * (maxlen-len(x))
def greedy_decode(tokens, ner, srl, max_q=20, max_a=10):
et = np.array([enc_pad(tokens, vocab_tok, MAX_SENT)])
en = np.array([enc_pad(ner, vocab_ner, MAX_SENT)])
es = np.array([enc_pad(srl, vocab_srl, MAX_SENT)])
h,c = encoder_model.predict([et,en,es], verbose=0)
# --- question
q_ids = []
tgt = np.array([[vocab_q["<sos>"]]])
for _ in range(max_q):
logits,h,c = decoder_q.predict([tgt,h,c], verbose=0)
nxt = int(logits[0,-1].argmax())
if nxt==vocab_q["<eos>"]: break
q_ids.append(nxt)
tgt = np.array([[nxt]])
# --- answer (reuse fresh h,c)
h,c = encoder_model.predict([et,en,es], verbose=0)
a_ids = []
tgt = np.array([[vocab_a["<sos>"]]])
for _ in range(max_a):
logits,h,c = decoder_a.predict([tgt,h,c], verbose=0)
nxt = int(logits[0,-1].argmax())
if nxt==vocab_a["<eos>"]: break
a_ids.append(nxt)
tgt = np.array([[nxt]])
# --- type
t_id = int(classifier_model.predict([et,en,es], verbose=0).argmax())
return [inv_q[i] for i in q_ids], [inv_a[i] for i in a_ids], \
[k for k,v in vocab_typ.items() if v==t_id][0]
# -----------------------------------------------------------------
# 7. QUICK DEMO
# -----------------------------------------------------------------
test_tokens = ["soekarno","membacakan","teks","proklamasi","pada",
"17","agustus","1945"]
test_ner = ["B-PER","O","O","O","O","B-DATE","I-DATE","I-DATE"]
test_srl = ["ARG0","V","ARG1","ARG1","O","ARGM-TMP","ARGM-TMP","ARGM-TMP"]
q,a,t = greedy_decode(test_tokens,test_ner,test_srl,max_q=MAX_Q,max_a=MAX_A)
print("\nDEMO\n----")
print("Q :", " ".join(q))
print("A :", " ".join(a))
print("T :", t)
# -----------------------------------------------------------------
# 8. EVALUATION (corpuslevel BLEU + ROUGE1/L)
# -----------------------------------------------------------------
smooth = SmoothingFunction().method4
r_scorer = rouge_scorer.RougeScorer(["rouge1","rougeL"], use_stemmer=True)
def strip_special(seq, pad_id, eos_id):
return [x for x in seq if x not in (pad_id, eos_id)]
def ids_to_text(ids, inv):
return " ".join(inv[i] for i in ids)
def evaluate(n=200):
idxs = random.sample(range(len(samples)), n)
refs, hyps = [], []
agg = scoring.BootstrapAggregator()
for i in idxs:
gt_ids = strip_special(dec_q_out[i], pad_q, vocab_q["<eos>"])
ref = ids_to_text(gt_ids, inv_q)
pred = " ".join(greedy_decode(
samples[i]["tokens"],
samples[i]["ner"],
samples[i]["srl"]
)[0])
refs.append([ref.split()])
hyps.append(pred.split())
agg.add_scores(r_scorer.score(ref, pred))
bleu = corpus_bleu(refs, hyps, smoothing_function=smooth)
r1 = agg.aggregate()["rouge1"].mid
rL = agg.aggregate()["rougeL"].mid
print(f"\nEVAL (n={n})")
print(f"BLEU4 : {bleu:.4f}")
print(f"ROUGE1 : P={r1.precision:.3f} R={r1.recall:.3f} F1={r1.fmeasure:.3f}")
print(f"ROUGEL : P={rL.precision:.3f} R={rL.recall:.3f} F1={rL.fmeasure:.3f}")
evaluate(2) # run on 150 random samples

View File

@ -0,0 +1,798 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 93,
"id": "fb283f23",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Total flattened samples: 342\n"
]
}
],
"source": [
"import json\n",
"from pathlib import Path\n",
"from itertools import chain\n",
"\n",
"RAW = json.loads(\n",
" Path(\"../dataset/dev_dataset_qg.json\").read_text()\n",
") # ← file contoh Anda\n",
"\n",
"samples = []\n",
"for item in RAW:\n",
" for qp in item[\"quiz_posibility\"]:\n",
" samp = {\n",
" \"tokens\": [tok.lower() for tok in item[\"tokens\"]],\n",
" \"ner\": item[\"ner\"],\n",
" \"srl\": item[\"srl\"],\n",
" \"q_type\": qp[\"type\"], # isian / opsi / benar_salah\n",
" \"q_toks\": [tok.lower() for tok in qp[\"question\"]]\n",
" + [\"<eos>\"], # tambahkan <eos>\n",
" }\n",
" # Jawaban bisa multi token\n",
" if isinstance(qp[\"answer\"], list):\n",
" samp[\"a_toks\"] = [tok.lower() for tok in qp[\"answer\"]] + [\"<eos>\"]\n",
" else:\n",
" samp[\"a_toks\"] = [qp[\"answer\"].lower(), \"<eos>\"]\n",
" samples.append(samp)\n",
"\n",
"print(\"Total flattened samples:\", len(samples))"
]
},
{
"cell_type": "code",
"execution_count": 94,
"id": "fa4f979d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'<pad>': 0, '<unk>': 1, '<sos>': 2, '<eos>': 3, 'jepara': 4, 'false': 5, 'trowulan': 6, '17': 7, 'agustus': 8, '1945': 9, 'soekarno': 10, 'mohammad hatta': 11, '365': 12, 'hari': 13, 'merkurius': 14, 'true': 15, 'mars': 16, 'jupiter': 17, 'saturnus': 18, 'uranus': 19, 'neptunus': 20, '5': 21, 'januari': 22, '2020': 23, '12': 24, 'februari': 25, '2019': 26, '23': 27, 'maret': 28, '2021': 29, '1': 30, 'april': 31, '2022': 32, '15': 33, 'mei': 34, '2023': 35, 'gunung': 36, 'everest': 37, 'amazon': 38, 'piramida': 39, 'giza': 40, 'benua': 41, 'asia': 42, 'colosseum': 43, 'taj': 44, 'mahal': 45, 'petra': 46, 'tembok': 47, 'cina': 48, 'chichen': 49, 'itza': 50, 'patung': 51, 'yesus': 52, 'penebus': 53, 'machu': 54, 'picchu': 55, 'stonehenge': 56, 'menara': 57, 'pisa': 58, 'angkot': 59, 'wat': 60, '8848': 61, 'meter': 62, '17 agustus 1945': 63, 'albert': 64, 'einstein': 65, 'jantung': 66, 'memompa darah': 67, 'tokyo': 68, '100': 69, 'derajat': 70, 'celsius': 71, 'thomas': 72, 'alva': 73, 'edison': 74, '1879': 75, 'ketiga': 76, 'leonardo': 77, 'da': 78, 'vinci': 79, 'leonardo da vinci': 80, '9,46': 81, 'triliun': 82, 'kilometer': 83, 'mahatma': 84, 'gandhi': 85, '1958': 86, 'kornea': 87, 'waterloo': 88, '1815': 89, 'indonesia': 90, 'marie': 91, 'curie': 92, 'fisika dan kimia': 93, 'inka': 94, 'oksigen': 95, 'karbon dioksida dan air': 96, 'vincent': 97, 'van': 98, 'gogh': 99, 'double': 100, 'helix': 101, 'double helix': 102, 'alexander': 103, 'fleming': 104, 'jeruk': 105, 'dan': 106, 'kiwi': 107, 'vitamin c': 108, 'nikola': 109, 'tesla': 110, 'sungai': 111, 'nil': 112, '6650 kilometer': 113, 'paus': 114, 'biru': 115, 'pankreas': 116, 'mengatur gula darah': 117, 'charles': 118, 'darwin': 119, 'shah': 120, 'jahan': 121, 'mumtaz mahal': 122, '44.58 juta km²': 123, '54': 124, 'di selatan laut mediterania': 125, 'eropa': 126, '10.18 juta km²': 127, 'atlantik': 128, 'pasifik': 129, 'hutan amazon': 130, 'australia': 131, 'belahan bumi selatan': 132, 'antartika': 133, 'kutub selatan': 134, '4.7 miliar': 135, 'kilimanjaro': 136, '5,895 meter': 137, 'sahara': 138, 'afrika': 139, 'alpen': 140, '8': 141, 'superior': 142, 'danau superior': 143, 'amerika selatan': 144, 'ali': 145, 'turnamen': 146, 'nina': 147, 'rapat': 148, 'farhan': 149, 'andi': 150, 'workshop': 151, 'lina': 152, 'pameran': 153, 'iqbal': 154, 'siti': 155, 'perlombaan': 156, 'konser': 157, 'fajar': 158, 'dina': 159, 'festival': 160, 'rian': 161, 'bazar': 162, 'tari': 163, 'seminar': 164, 'kompetisi': 165, 'rudi': 166, 'putri': 167, 'budi': 168, 'hana': 169, 'raka': 170, 'dewi': 171, 'surabaya': 172, 'yogyakarta': 173, 'kota': 174, 'jakarta': 175, 'bandung': 176, 'malang': 177, 'bali': 178, 'padang': 179, 'ibukota': 180, 'makassar': 181, 'medan': 182}\n"
]
}
],
"source": [
"def build_vocab(seq_iter, reserved=[\"<pad>\", \"<unk>\", \"<sos>\", \"<eos>\"]):\n",
" vocab = {tok: idx for idx, tok in enumerate(reserved)}\n",
" for tok in chain.from_iterable(seq_iter):\n",
" if tok not in vocab:\n",
" vocab[tok] = len(vocab)\n",
" return vocab\n",
"\n",
"\n",
"vocab_tok = build_vocab((s[\"tokens\"] for s in samples))\n",
"vocab_ner = build_vocab((s[\"ner\"] for s in samples), reserved=[\"<pad>\", \"<unk>\"])\n",
"vocab_srl = build_vocab((s[\"srl\"] for s in samples), reserved=[\"<pad>\", \"<unk>\"])\n",
"vocab_q = build_vocab((s[\"q_toks\"] for s in samples))\n",
"vocab_a = build_vocab((s[\"a_toks\"] for s in samples))\n",
"\n",
"vocab_typ = {\"isian\": 0, \"opsi\": 1, \"true_false\": 2}\n",
"\n",
"print(vocab_a)"
]
},
{
"cell_type": "code",
"execution_count": 95,
"id": "d1a5b324",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from tensorflow.keras.preprocessing.sequence import pad_sequences\n",
"\n",
"\n",
"def encode(seq, vmap): # token  id\n",
" return [vmap.get(t, vmap[\"<unk>\"]) for t in seq]\n",
"\n",
"\n",
"MAX_SENT = max(len(s[\"tokens\"]) for s in samples)\n",
"MAX_Q = max(len(s[\"q_toks\"]) for s in samples)\n",
"MAX_A = max(len(s[\"a_toks\"]) for s in samples)\n",
"\n",
"X_tok = pad_sequences(\n",
" [encode(s[\"tokens\"], vocab_tok) for s in samples], maxlen=MAX_SENT, padding=\"post\"\n",
")\n",
"X_ner = pad_sequences(\n",
" [encode(s[\"ner\"], vocab_ner) for s in samples], maxlen=MAX_SENT, padding=\"post\"\n",
")\n",
"X_srl = pad_sequences(\n",
" [encode(s[\"srl\"], vocab_srl) for s in samples], maxlen=MAX_SENT, padding=\"post\"\n",
")\n",
"\n",
"# Decoder input = <sos> + target[:-1]\n",
"dec_q_in = pad_sequences(\n",
" [[vocab_q[\"<sos>\"], *encode(s[\"q_toks\"][:-1], vocab_q)] for s in samples],\n",
" maxlen=MAX_Q,\n",
" padding=\"post\",\n",
")\n",
"dec_q_out = pad_sequences(\n",
" [encode(s[\"q_toks\"], vocab_q) for s in samples], maxlen=MAX_Q, padding=\"post\"\n",
")\n",
"\n",
"dec_a_in = pad_sequences(\n",
" [[vocab_a[\"<sos>\"], *encode(s[\"a_toks\"][:-1], vocab_a)] for s in samples],\n",
" maxlen=MAX_A,\n",
" padding=\"post\",\n",
")\n",
"dec_a_out = pad_sequences(\n",
" [encode(s[\"a_toks\"], vocab_a) for s in samples], maxlen=MAX_A, padding=\"post\"\n",
")\n",
"\n",
"MAX_SENT = max(len(s[\"tokens\"]) for s in samples)\n",
"MAX_Q = max(len(s[\"q_toks\"]) for s in samples)\n",
"MAX_A = max(len(s[\"a_toks\"]) for s in samples)\n",
"y_type = np.array([vocab_typ[s[\"q_type\"]] for s in samples])"
]
},
{
"cell_type": "code",
"execution_count": 96,
"id": "ff5bd85f",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Model: \"functional_8\"</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1mModel: \"functional_8\"\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓\n",
"┃<span style=\"font-weight: bold\"> Layer (type) </span>┃<span style=\"font-weight: bold\"> Output Shape </span>┃<span style=\"font-weight: bold\"> Param # </span>┃<span style=\"font-weight: bold\"> Connected to </span>┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩\n",
"│ tok_in (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">InputLayer</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">16</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ - │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ ner_in (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">InputLayer</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">16</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ - │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ srl_in (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">InputLayer</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">16</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ - │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ embedding_tok │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">16</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">57,856</span> │ tok_in[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Embedding</span>) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ embedding_ner │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">16</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">32</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">1,248</span> │ ner_in[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Embedding</span>) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ embedding_srl │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">16</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">32</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">448</span> │ srl_in[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Embedding</span>) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ dec_q_in │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">13</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ - │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">InputLayer</span>) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ concatenate_8 │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">16</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">192</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ embedding_tok[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>]… │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Concatenate</span>) │ │ │ embedding_ner[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>]… │\n",
"│ │ │ │ embedding_srl[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>]… │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ dec_a_in │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">4</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ - │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">InputLayer</span>) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ embedding_q_decoder │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">13</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">52,096</span> │ dec_q_in[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Embedding</span>) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ encoder_lstm (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">LSTM</span>) │ [(<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">256</span>), │ <span style=\"color: #00af00; text-decoration-color: #00af00\">459,776</span> │ concatenate_8[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>]… │\n",
"│ │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">256</span>), │ │ │\n",
"│ │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">256</span>)] │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ embedding_a_decoder │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">4</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">23,424</span> │ dec_a_in[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Embedding</span>) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ lstm_q_decoder │ [(<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">13</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">256</span>), │ <span style=\"color: #00af00; text-decoration-color: #00af00\">394,240</span> │ embedding_q_deco… │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">LSTM</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">256</span>), │ │ encoder_lstm[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">…</span> │\n",
"│ │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">256</span>)] │ │ encoder_lstm[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">…</span> │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ not_equal_32 │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">13</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ dec_q_in[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">NotEqual</span>) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ lstm_a_decoder │ [(<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">4</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">256</span>), │ <span style=\"color: #00af00; text-decoration-color: #00af00\">394,240</span> │ embedding_a_deco… │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">LSTM</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">256</span>), │ │ encoder_lstm[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">…</span> │\n",
"│ │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">256</span>)] │ │ encoder_lstm[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">…</span> │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ not_equal_33 │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">4</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ dec_a_in[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">NotEqual</span>) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ q_output │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">13</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">407</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">104,599</span> │ lstm_q_decoder[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>… │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">TimeDistributed</span>) │ │ │ not_equal_32[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">…</span> │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ a_output │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">4</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">183</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">47,031</span> │ lstm_a_decoder[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>… │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">TimeDistributed</span>) │ │ │ not_equal_33[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">…</span> │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ type_output (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">3</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">771</span> │ encoder_lstm[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">…</span> │\n",
"└─────────────────────┴───────────────────┴────────────┴───────────────────┘\n",
"</pre>\n"
],
"text/plain": [
"┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓\n",
"┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mConnected to \u001b[0m\u001b[1m \u001b[0m┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩\n",
"│ tok_in (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m16\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ ner_in (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m16\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ srl_in (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m16\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ embedding_tok │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m16\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m57,856\u001b[0m │ tok_in[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"│ (\u001b[38;5;33mEmbedding\u001b[0m) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ embedding_ner │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m16\u001b[0m, \u001b[38;5;34m32\u001b[0m) │ \u001b[38;5;34m1,248\u001b[0m │ ner_in[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"│ (\u001b[38;5;33mEmbedding\u001b[0m) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ embedding_srl │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m16\u001b[0m, \u001b[38;5;34m32\u001b[0m) │ \u001b[38;5;34m448\u001b[0m │ srl_in[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"│ (\u001b[38;5;33mEmbedding\u001b[0m) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ dec_q_in │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m13\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n",
"│ (\u001b[38;5;33mInputLayer\u001b[0m) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ concatenate_8 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m16\u001b[0m, \u001b[38;5;34m192\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ embedding_tok[\u001b[38;5;34m0\u001b[0m]… │\n",
"│ (\u001b[38;5;33mConcatenate\u001b[0m) │ │ │ embedding_ner[\u001b[38;5;34m0\u001b[0m]… │\n",
"│ │ │ │ embedding_srl[\u001b[38;5;34m0\u001b[0m]… │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ dec_a_in │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m4\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n",
"│ (\u001b[38;5;33mInputLayer\u001b[0m) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ embedding_q_decoder │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m13\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m52,096\u001b[0m │ dec_q_in[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"│ (\u001b[38;5;33mEmbedding\u001b[0m) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ encoder_lstm (\u001b[38;5;33mLSTM\u001b[0m) │ [(\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256\u001b[0m), │ \u001b[38;5;34m459,776\u001b[0m │ concatenate_8[\u001b[38;5;34m0\u001b[0m]… │\n",
"│ │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256\u001b[0m), │ │ │\n",
"│ │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256\u001b[0m)] │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ embedding_a_decoder │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m4\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m23,424\u001b[0m │ dec_a_in[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"│ (\u001b[38;5;33mEmbedding\u001b[0m) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ lstm_q_decoder │ [(\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m13\u001b[0m, \u001b[38;5;34m256\u001b[0m), │ \u001b[38;5;34m394,240\u001b[0m │ embedding_q_deco… │\n",
"│ (\u001b[38;5;33mLSTM\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256\u001b[0m), │ │ encoder_lstm[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m…\u001b[0m │\n",
"│ │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256\u001b[0m)] │ │ encoder_lstm[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m…\u001b[0m │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ not_equal_32 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m13\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ dec_q_in[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"│ (\u001b[38;5;33mNotEqual\u001b[0m) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ lstm_a_decoder │ [(\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m4\u001b[0m, \u001b[38;5;34m256\u001b[0m), │ \u001b[38;5;34m394,240\u001b[0m │ embedding_a_deco… │\n",
"│ (\u001b[38;5;33mLSTM\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256\u001b[0m), │ │ encoder_lstm[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m…\u001b[0m │\n",
"│ │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256\u001b[0m)] │ │ encoder_lstm[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m…\u001b[0m │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ not_equal_33 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m4\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ dec_a_in[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"│ (\u001b[38;5;33mNotEqual\u001b[0m) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ q_output │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m13\u001b[0m, \u001b[38;5;34m407\u001b[0m) │ \u001b[38;5;34m104,599\u001b[0m │ lstm_q_decoder[\u001b[38;5;34m0\u001b[0m… │\n",
"│ (\u001b[38;5;33mTimeDistributed\u001b[0m) │ │ │ not_equal_32[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m…\u001b[0m │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ a_output │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m4\u001b[0m, \u001b[38;5;34m183\u001b[0m) │ \u001b[38;5;34m47,031\u001b[0m │ lstm_a_decoder[\u001b[38;5;34m0\u001b[0m… │\n",
"│ (\u001b[38;5;33mTimeDistributed\u001b[0m) │ │ │ not_equal_33[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m…\u001b[0m │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ type_output (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m3\u001b[0m) │ \u001b[38;5;34m771\u001b[0m │ encoder_lstm[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m…\u001b[0m │\n",
"└─────────────────────┴───────────────────┴────────────┴───────────────────┘\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Total params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">1,535,729</span> (5.86 MB)\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m Total params: \u001b[0m\u001b[38;5;34m1,535,729\u001b[0m (5.86 MB)\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">1,535,729</span> (5.86 MB)\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m1,535,729\u001b[0m (5.86 MB)\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Non-trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> (0.00 B)\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import tensorflow as tf\n",
"from tensorflow.keras.layers import (\n",
" Input,\n",
" Embedding,\n",
" LSTM,\n",
" Concatenate,\n",
" Dense,\n",
" TimeDistributed,\n",
")\n",
"from tensorflow.keras.models import Model\n",
"\n",
"# ---- constants ---------------------------------------------------\n",
"d_tok = 128 # token embedding dim\n",
"d_tag = 32 # NER / SRL embedding dim\n",
"units = 256\n",
"\n",
"# ---- encoder -----------------------------------------------------\n",
"inp_tok = Input((MAX_SENT,), name=\"tok_in\")\n",
"inp_ner = Input((MAX_SENT,), name=\"ner_in\")\n",
"inp_srl = Input((MAX_SENT,), name=\"srl_in\")\n",
"\n",
"# make ALL streams mask the same way (here: no masking,\n",
"# we'll just pad with 0s and let the LSTM ignore them)\n",
"emb_tok = Embedding(len(vocab_tok), d_tok, mask_zero=False, name=\"embedding_tok\")(\n",
" inp_tok\n",
")\n",
"emb_ner = Embedding(len(vocab_ner), d_tag, mask_zero=False, name=\"embedding_ner\")(\n",
" inp_ner\n",
")\n",
"emb_srl = Embedding(len(vocab_srl), d_tag, mask_zero=False, name=\"embedding_srl\")(\n",
" inp_srl\n",
")\n",
"\n",
"enc_concat = Concatenate()([emb_tok, emb_ner, emb_srl])\n",
"enc_out, state_h, state_c = LSTM(units, return_state=True, name=\"encoder_lstm\")(\n",
" enc_concat\n",
")\n",
"\n",
"\n",
"# ---------- DECODER : Question ----------\n",
"dec_q_inp = Input(shape=(MAX_Q,), name=\"dec_q_in\")\n",
"dec_emb_q = Embedding(len(vocab_q), d_tok, mask_zero=True, name=\"embedding_q_decoder\")(\n",
" dec_q_inp\n",
")\n",
"dec_q, _, _ = LSTM(\n",
" units, return_state=True, return_sequences=True, name=\"lstm_q_decoder\"\n",
")(dec_emb_q, initial_state=[state_h, state_c])\n",
"q_out = TimeDistributed(\n",
" Dense(len(vocab_q), activation=\"softmax\", name=\"dense_q_output\"), name=\"q_output\"\n",
")(dec_q)\n",
"\n",
"# ---------- DECODER : Answer ----------\n",
"dec_a_inp = Input(shape=(MAX_A,), name=\"dec_a_in\")\n",
"dec_emb_a = Embedding(len(vocab_a), d_tok, mask_zero=True, name=\"embedding_a_decoder\")(\n",
" dec_a_inp\n",
")\n",
"dec_a, _, _ = LSTM(\n",
" units, return_state=True, return_sequences=True, name=\"lstm_a_decoder\"\n",
")(dec_emb_a, initial_state=[state_h, state_c])\n",
"a_out = TimeDistributed(\n",
" Dense(len(vocab_a), activation=\"softmax\", name=\"dense_a_output\"), name=\"a_output\"\n",
")(dec_a)\n",
"\n",
"# ---------- CLASSIFIER : Question Type ----------\n",
"type_out = Dense(len(vocab_typ), activation=\"softmax\", name=\"type_output\")(enc_out)\n",
"\n",
"model = Model(\n",
" inputs=[inp_tok, inp_ner, inp_srl, dec_q_inp, dec_a_inp],\n",
" outputs=[q_out, a_out, type_out],\n",
")\n",
"\n",
"model.summary()"
]
},
{
"cell_type": "code",
"execution_count": 97,
"id": "fece1ae9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/30\n",
"\u001b[1m5/5\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 161ms/step - a_output_loss: 5.1540 - a_output_sparse_categorical_accuracy: 0.1507 - loss: 11.4761 - q_output_loss: 5.9970 - q_output_sparse_categorical_accuracy: 0.0600 - type_output_accuracy: 0.4506 - type_output_loss: 1.0728 - val_a_output_loss: 4.5900 - val_a_output_sparse_categorical_accuracy: 0.2500 - val_loss: 10.8292 - val_q_output_loss: 5.9316 - val_q_output_sparse_categorical_accuracy: 0.0769 - val_type_output_accuracy: 0.5143 - val_type_output_loss: 1.0253\n",
"Epoch 2/30\n",
"\u001b[1m5/5\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 48ms/step - a_output_loss: 4.2365 - a_output_sparse_categorical_accuracy: 0.2500 - loss: 10.2493 - q_output_loss: 5.6397 - q_output_sparse_categorical_accuracy: 0.1183 - type_output_accuracy: 0.5209 - type_output_loss: 1.2188 - val_a_output_loss: 3.2588 - val_a_output_sparse_categorical_accuracy: 0.2500 - val_loss: 9.0808 - val_q_output_loss: 5.4082 - val_q_output_sparse_categorical_accuracy: 0.0923 - val_type_output_accuracy: 0.5143 - val_type_output_loss: 1.3791\n",
"Epoch 3/30\n",
"\u001b[1m5/5\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 48ms/step - a_output_loss: 3.5259 - a_output_sparse_categorical_accuracy: 0.2500 - loss: 8.4974 - q_output_loss: 4.6444 - q_output_sparse_categorical_accuracy: 0.1174 - type_output_accuracy: 0.5233 - type_output_loss: 1.0788 - val_a_output_loss: 3.3879 - val_a_output_sparse_categorical_accuracy: 0.2500 - val_loss: 9.5209 - val_q_output_loss: 5.7546 - val_q_output_sparse_categorical_accuracy: 0.0769 - val_type_output_accuracy: 0.2000 - val_type_output_loss: 1.2615\n",
"Epoch 4/30\n",
"\u001b[1m5/5\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 48ms/step - a_output_loss: 3.3147 - a_output_sparse_categorical_accuracy: 0.2500 - loss: 8.1027 - q_output_loss: 4.4209 - q_output_sparse_categorical_accuracy: 0.1099 - type_output_accuracy: 0.3256 - type_output_loss: 1.2069 - val_a_output_loss: 3.0792 - val_a_output_sparse_categorical_accuracy: 0.2500 - val_loss: 9.2232 - val_q_output_loss: 5.8382 - val_q_output_sparse_categorical_accuracy: 0.0769 - val_type_output_accuracy: 0.5143 - val_type_output_loss: 1.0193\n",
"Epoch 5/30\n",
"\u001b[1m5/5\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 48ms/step - a_output_loss: 3.1559 - a_output_sparse_categorical_accuracy: 0.2500 - loss: 7.7733 - q_output_loss: 4.3048 - q_output_sparse_categorical_accuracy: 0.1120 - type_output_accuracy: 0.5160 - type_output_loss: 1.0414 - val_a_output_loss: 3.0450 - val_a_output_sparse_categorical_accuracy: 0.2500 - val_loss: 9.1657 - val_q_output_loss: 5.7943 - val_q_output_sparse_categorical_accuracy: 0.0923 - val_type_output_accuracy: 0.5143 - val_type_output_loss: 1.0881\n",
"Epoch 6/30\n",
"\u001b[1m5/5\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 48ms/step - a_output_loss: 3.0962 - a_output_sparse_categorical_accuracy: 0.2569 - loss: 7.6096 - q_output_loss: 4.1973 - q_output_sparse_categorical_accuracy: 0.1121 - type_output_accuracy: 0.5318 - type_output_loss: 1.0492 - val_a_output_loss: 3.1428 - val_a_output_sparse_categorical_accuracy: 0.3214 - val_loss: 9.2982 - val_q_output_loss: 5.8475 - val_q_output_sparse_categorical_accuracy: 0.0769 - val_type_output_accuracy: 0.5143 - val_type_output_loss: 1.0265\n"
]
}
],
"source": [
"losses = {\n",
" \"q_output\": \"sparse_categorical_crossentropy\",\n",
" \"a_output\": \"sparse_categorical_crossentropy\",\n",
" \"type_output\": \"sparse_categorical_crossentropy\",\n",
"}\n",
"loss_weights = {\"q_output\": 1.0, \"a_output\": 1.0, \"type_output\": 0.3}\n",
"\n",
"model.compile(\n",
" optimizer=\"adam\",\n",
" loss=losses,\n",
" loss_weights=loss_weights,\n",
" metrics={\n",
" \"q_output\": \"sparse_categorical_accuracy\",\n",
" \"a_output\": \"sparse_categorical_accuracy\",\n",
" \"type_output\": \"accuracy\",\n",
" },\n",
")\n",
"\n",
"history = model.fit(\n",
" [X_tok, X_ner, X_srl, dec_q_in, dec_a_in],\n",
" [dec_q_out, dec_a_out, y_type],\n",
" validation_split=0.1,\n",
" epochs=30,\n",
" batch_size=64,\n",
" callbacks=[tf.keras.callbacks.EarlyStopping(patience=4, restore_best_weights=True)],\n",
" verbose=1,\n",
")\n",
"\n",
"model.save(\"full_seq2seq.keras\")\n",
"\n",
"import json\n",
"import pickle\n",
"\n",
"# def save_vocab(vocab, path):\n",
"# with open(path, \"w\", encoding=\"utf-8\") as f:\n",
"# json.dump(vocab, f, ensure_ascii=False, indent=2)\n",
"\n",
"# # Simpan semua vocab\n",
"# save_vocab(vocab_tok, \"vocab_tok.json\")\n",
"# save_vocab(vocab_ner, \"vocab_ner.json\")\n",
"# save_vocab(vocab_srl, \"vocab_srl.json\")\n",
"# save_vocab(vocab_q, \"vocab_q.json\")\n",
"# save_vocab(vocab_a, \"vocab_a.json\")\n",
"# save_vocab(vocab_typ, \"vocab_typ.json\")\n",
"\n",
"\n",
"def save_vocab_pkl(vocab, path):\n",
" with open(path, \"wb\") as f:\n",
" pickle.dump(vocab, f)\n",
"\n",
"\n",
"# Simpan semua vocab\n",
"save_vocab_pkl(vocab_tok, \"vocab_tok.pkl\")\n",
"save_vocab_pkl(vocab_ner, \"vocab_ner.pkl\")\n",
"save_vocab_pkl(vocab_srl, \"vocab_srl.pkl\")\n",
"save_vocab_pkl(vocab_q, \"vocab_q.pkl\")\n",
"save_vocab_pkl(vocab_a, \"vocab_a.pkl\")\n",
"save_vocab_pkl(vocab_typ, \"vocab_typ.pkl\")"
]
},
{
"cell_type": "code",
"execution_count": 98,
"id": "3355c0c7",
"metadata": {},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"import numpy as np\n",
"import pickle\n",
"from tensorflow.keras.models import load_model, Model\n",
"from tensorflow.keras.layers import Input, Concatenate\n",
"\n",
"# === Load Model Utama ===\n",
"model = load_model(\"full_seq2seq.keras\")\n",
"\n",
"\n",
"# === Load Vocabulary dari .pkl ===\n",
"def load_vocab(path):\n",
" with open(path, \"rb\") as f:\n",
" return pickle.load(f)\n",
"\n",
"\n",
"vocab_tok = load_vocab(\"vocab_tok.pkl\")\n",
"vocab_ner = load_vocab(\"vocab_ner.pkl\")\n",
"vocab_srl = load_vocab(\"vocab_srl.pkl\")\n",
"vocab_q = load_vocab(\"vocab_q.pkl\")\n",
"vocab_a = load_vocab(\"vocab_a.pkl\")\n",
"vocab_typ = load_vocab(\"vocab_typ.pkl\")\n",
"\n",
"inv_vocab_q = {v: k for k, v in vocab_q.items()}\n",
"inv_vocab_a = {v: k for k, v in vocab_a.items()}\n",
"\n",
"# === Build Encoder Model ===\n",
"MAX_SENT = model.input_shape[0][1] # Ambil shape dari model yang diload\n",
"MAX_Q = model.input_shape[3][1] # Max length for question\n",
"MAX_A = model.input_shape[4][1] # Max length for answer\n",
"\n",
"inp_tok_g = Input(shape=(MAX_SENT,), name=\"tok_in_g\")\n",
"inp_ner_g = Input(shape=(MAX_SENT,), name=\"ner_in_g\")\n",
"inp_srl_g = Input(shape=(MAX_SENT,), name=\"srl_in_g\")\n",
"\n",
"emb_tok = model.get_layer(\"embedding_tok\").call(inp_tok_g)\n",
"emb_ner = model.get_layer(\"embedding_ner\").call(inp_ner_g)\n",
"emb_srl = model.get_layer(\"embedding_srl\").call(inp_srl_g)\n",
"\n",
"enc_concat = Concatenate(name=\"concat_encoder\")([emb_tok, emb_ner, emb_srl])\n",
"\n",
"encoder_lstm = model.get_layer(\"encoder_lstm\")\n",
"enc_out, state_h, state_c = encoder_lstm(enc_concat)\n",
"\n",
"# Create encoder model with full output including enc_out\n",
"encoder_model = Model(\n",
" inputs=[inp_tok_g, inp_ner_g, inp_srl_g],\n",
" outputs=[enc_out, state_h, state_c],\n",
" name=\"encoder_model\",\n",
")\n",
"\n",
"# === Build Decoder for Question ===\n",
"dec_q_inp = Input(shape=(1,), name=\"dec_q_in\")\n",
"dec_emb_q = model.get_layer(\"embedding_q_decoder\").call(dec_q_inp)\n",
"\n",
"state_h_dec = Input(shape=(256,), name=\"state_h_dec\")\n",
"state_c_dec = Input(shape=(256,), name=\"state_c_dec\")\n",
"\n",
"lstm_decoder_q = model.get_layer(\"lstm_q_decoder\")\n",
"\n",
"dec_out_q, state_h_q, state_c_q = lstm_decoder_q(\n",
" dec_emb_q, initial_state=[state_h_dec, state_c_dec]\n",
")\n",
"\n",
"q_time_dist_layer = model.get_layer(\"q_output\")\n",
"dense_q = q_time_dist_layer.layer\n",
"q_output = dense_q(dec_out_q)\n",
"\n",
"decoder_q = Model(\n",
" inputs=[dec_q_inp, state_h_dec, state_c_dec],\n",
" outputs=[q_output, state_h_q, state_c_q],\n",
" name=\"decoder_question_model\",\n",
")\n",
"\n",
"# === Build Decoder for Answer ===\n",
"dec_a_inp = Input(shape=(1,), name=\"dec_a_in\")\n",
"dec_emb_a = model.get_layer(\"embedding_a_decoder\").call(dec_a_inp)\n",
"\n",
"state_h_a = Input(shape=(256,), name=\"state_h_a\")\n",
"state_c_a = Input(shape=(256,), name=\"state_c_a\")\n",
"\n",
"lstm_decoder_a = model.get_layer(\"lstm_a_decoder\")\n",
"\n",
"dec_out_a, state_h_a_out, state_c_a_out = lstm_decoder_a(\n",
" dec_emb_a, initial_state=[state_h_a, state_c_a]\n",
")\n",
"\n",
"a_time_dist_layer = model.get_layer(\"a_output\")\n",
"dense_a = a_time_dist_layer.layer\n",
"a_output = dense_a(dec_out_a)\n",
"\n",
"decoder_a = Model(\n",
" inputs=[dec_a_inp, state_h_a, state_c_a],\n",
" outputs=[a_output, state_h_a_out, state_c_a_out],\n",
" name=\"decoder_answer_model\",\n",
")\n",
"\n",
"# === Build Classifier for Question Type ===\n",
"type_dense = model.get_layer(\"type_output\")\n",
"type_out = type_dense(enc_out)\n",
"\n",
"classifier_model = Model(\n",
" inputs=[inp_tok_g, inp_ner_g, inp_srl_g], outputs=type_out, name=\"classifier_model\"\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 99,
"id": "d406e6ff",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Generated Question: menghadiri menghadiri ___ ___\n",
"Generated Answer : \n",
"Question Type : isian\n"
]
}
],
"source": [
"def encode(seq, vmap):\n",
" return [vmap.get(tok, vmap[\"<unk>\"]) for tok in seq]\n",
"\n",
"\n",
"def encode_and_pad(seq, vmap, max_len=MAX_SENT):\n",
" encoded = [vmap.get(tok, vmap[\"<unk>\"]) for tok in seq]\n",
" # Pad with vocab[\"<pad>\"] to the right if sequence is shorter than max_len\n",
" padded = encoded + [vmap[\"<pad>\"]] * (max_len - len(encoded))\n",
" return padded[:max_len] # Ensure it doesn't exceed max_len\n",
"\n",
"\n",
"def greedy_decode(tokens, ner, srl, max_q=20, max_a=10):\n",
" # --- encode encoder inputs -------------------------------------------\n",
" if isinstance(tokens, np.ndarray):\n",
" enc_tok = tokens\n",
" enc_ner = ner\n",
" enc_srl = srl\n",
" else:\n",
" enc_tok = np.array([encode_and_pad(tokens, vocab_tok, MAX_SENT)])\n",
" enc_ner = np.array([encode_and_pad(ner, vocab_ner, MAX_SENT)])\n",
" enc_srl = np.array([encode_and_pad(srl, vocab_srl, MAX_SENT)])\n",
"\n",
" # --- Get encoder outputs ---\n",
" enc_out, h, c = encoder_model.predict([enc_tok, enc_ner, enc_srl], verbose=0)\n",
"\n",
" # QUESTION Decoding\n",
" tgt = np.array([[vocab_q[\"<sos>\"]]])\n",
" question_ids = []\n",
" for _ in range(max_q):\n",
" logits, h, c = decoder_q.predict([tgt, h, c], verbose=0)\n",
" next_id = int(logits[0, 0].argmax()) # Get the predicted token ID\n",
" if next_id == vocab_q[\"<eos>\"]:\n",
" break\n",
" question_ids.append(next_id)\n",
" tgt = np.array([[next_id]]) # Feed the predicted token back as input\n",
"\n",
" # ANSWER Decoding - use encoder outputs again for fresh state\n",
" _, h, c = encoder_model.predict([enc_tok, enc_ner, enc_srl], verbose=0)\n",
" tgt = np.array([[vocab_a[\"<sos>\"]]])\n",
" answer_ids = []\n",
" for _ in range(max_a):\n",
" logits, h, c = decoder_a.predict([tgt, h, c], verbose=0)\n",
" next_id = int(logits[0, 0].argmax())\n",
" if next_id == vocab_a[\"<eos>\"]:\n",
" break\n",
" answer_ids.append(next_id)\n",
" tgt = np.array([[next_id]])\n",
"\n",
" # Question Type\n",
" qtype_logits = classifier_model.predict([enc_tok, enc_ner, enc_srl], verbose=0)\n",
" qtype_id = int(qtype_logits.argmax())\n",
"\n",
" # Final output\n",
" question = [inv_vocab_q.get(i, \"<unk>\") for i in question_ids]\n",
" answer = [inv_vocab_a.get(i, \"<unk>\") for i in answer_ids]\n",
" q_type = [k for k, v in vocab_typ.items() if v == qtype_id][0]\n",
"\n",
" return question, answer, q_type\n",
"\n",
"\n",
"def test_model():\n",
" test_data = {\n",
" \"tokens\": [\"nama\", \"lengkap\", \"saya\", \"Maya\", \"Maya\"],\n",
" \"ner\": [\"O\", \"O\", \"O\", \"B-PER\", \"B-PER\"],\n",
" \"srl\": [\"O\", \"O\", \"ARG0\", \"ARG0\", \"ARG0\"],\n",
" }\n",
" # tokens = [\n",
" # \"soekarno\",\n",
" # \"membacakan\",\n",
" # \"teks\",\n",
" # \"proklamasi\",\n",
" # \"pada\",\n",
" # \"17\",\n",
" # \"agustus\",\n",
" # \"1945\",\n",
" # ]\n",
" # ner_tags = [\"B-PER\", \"O\", \"O\", \"O\", \"O\", \"B-DATE\", \"I-DATE\", \"I-DATE\"]\n",
" # srl_tags = [\"ARG0\", \"V\", \"ARG1\", \"ARG1\", \"O\", \"ARGM-TMP\", \"ARGM-TMP\", \"ARGM-TMP\"]\n",
"\n",
" question, answer, q_type = greedy_decode(\n",
" test_data[\"tokens\"], test_data[\"ner\"], test_data[\"srl\"]\n",
" )\n",
" print(f\"Generated Question: {' '.join(question)}\")\n",
" print(f\"Generated Answer : {' '.join(answer)}\")\n",
" print(f\"Question Type : {q_type}\")\n",
"\n",
"\n",
"test_model()"
]
},
{
"cell_type": "code",
"execution_count": 100,
"id": "5adde3c3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"BLEU : 0.0385\n",
"ROUGE1: 0.1052 | ROUGE-L: 0.1052\n"
]
}
],
"source": [
"from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction\n",
"from rouge_score import rouge_scorer\n",
"\n",
"smoothie = SmoothingFunction().method4\n",
"scorer = rouge_scorer.RougeScorer([\"rouge1\", \"rougeL\"], use_stemmer=True)\n",
"\n",
"\n",
"# Helper to strip special ids\n",
"def strip_special(ids, vocab):\n",
" pad = vocab[\"<pad>\"] if \"<pad>\" in vocab else None\n",
" eos = vocab[\"<eos>\"]\n",
" return [i for i in ids if i not in (pad, eos)]\n",
"\n",
"\n",
"def ids_to_text(ids, inv_vocab):\n",
" return \" \".join(inv_vocab[i] for i in ids)\n",
"\n",
"\n",
"# ---- evaluation over a set of indices ----\n",
"import random\n",
"\n",
"\n",
"def evaluate(indices=None):\n",
" if indices is None:\n",
" indices = random.sample(range(len(X_tok)), k=min(100, len(X_tok)))\n",
"\n",
" bleu_scores, rou1, rouL = [], [], []\n",
" for idx in indices:\n",
" # Ground truth\n",
" gt_q = strip_special(dec_q_out[idx], vocab_q)\n",
" gt_a = strip_special(dec_a_out[idx], vocab_a)\n",
" # Prediction\n",
" q_pred, a_pred, _ = greedy_decode(\n",
" X_tok[idx : idx + 1], X_ner[idx : idx + 1], X_srl[idx : idx + 1]\n",
" )\n",
"\n",
" # BLEU on question tokens\n",
" bleu_scores.append(\n",
" sentence_bleu(\n",
" [[inv_vocab_q[i] for i in gt_q]], q_pred, smoothing_function=smoothie\n",
" )\n",
" )\n",
" # ROUGE on question strings\n",
" r = scorer.score(ids_to_text(gt_q, inv_vocab_q), \" \".join(q_pred))\n",
" rou1.append(r[\"rouge1\"].fmeasure)\n",
" rouL.append(r[\"rougeL\"].fmeasure)\n",
"\n",
" print(f\"BLEU : {np.mean(bleu_scores):.4f}\")\n",
" print(f\"ROUGE1: {np.mean(rou1):.4f} | ROUGE-L: {np.mean(rouL):.4f}\")\n",
"\n",
"\n",
"evaluate()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "myenv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@ -0,0 +1,615 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 13,
"id": "58e41ccb",
"metadata": {},
"outputs": [],
"source": [
"import json, pickle, random\n",
"from pathlib import Path\n",
"from itertools import chain\n",
"\n",
"import numpy as np\n",
"import tensorflow as tf\n",
"from tensorflow.keras.layers import (\n",
" Input, Embedding, LSTM, Concatenate,\n",
" Dense, TimeDistributed\n",
")\n",
"from tensorflow.keras.models import Model\n",
"from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction\n",
"from rouge_score import rouge_scorer, scoring\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "a94dd46a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"flattened samples : 8\n"
]
}
],
"source": [
"RAW = json.loads(Path(\"../dataset/dev_dataset_qg.json\").read_text())\n",
"\n",
"samples = []\n",
"for item in RAW:\n",
" for qp in item[\"quiz_posibility\"]:\n",
" samples.append({\n",
" \"tokens\" : item[\"tokens\"],\n",
" \"ner\" : item[\"ner\"],\n",
" \"srl\" : item[\"srl\"],\n",
" \"q_type\" : qp[\"type\"], # isian / opsi / benar_salah\n",
" \"q_toks\" : qp[\"question\"] + [\"<eos>\"],\n",
" \"a_toks\" : (qp[\"answer\"] if isinstance(qp[\"answer\"], list)\n",
" else [qp[\"answer\"]]) + [\"<eos>\"]\n",
" })\n",
"\n",
"print(\"flattened samples :\", len(samples))\n"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "852fb9a8",
"metadata": {},
"outputs": [],
"source": [
"def build_vocab(seq_iter, reserved=(\"<pad>\", \"<unk>\", \"<sos>\", \"<eos>\")):\n",
" vocab = {tok: idx for idx, tok in enumerate(reserved)}\n",
" for tok in chain.from_iterable(seq_iter):\n",
" vocab.setdefault(tok, len(vocab))\n",
" return vocab\n",
"\n",
"vocab_tok = build_vocab((s[\"tokens\"] for s in samples))\n",
"vocab_ner = build_vocab((s[\"ner\"] for s in samples), reserved=(\"<pad>\",\"<unk>\"))\n",
"vocab_srl = build_vocab((s[\"srl\"] for s in samples), reserved=(\"<pad>\",\"<unk>\"))\n",
"vocab_q = build_vocab((s[\"q_toks\"] for s in samples))\n",
"vocab_a = build_vocab((s[\"a_toks\"] for s in samples))\n",
"vocab_typ = {\"isian\":0, \"opsi\":1, \"benar_salah\":2}"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "fdf696cf",
"metadata": {},
"outputs": [],
"source": [
"def enc(seq, v): return [v.get(t, v[\"<unk>\"]) for t in seq]\n",
"\n",
"MAX_SENT = max(len(s[\"tokens\"]) for s in samples)\n",
"MAX_Q = max(len(s[\"q_toks\"]) for s in samples)\n",
"MAX_A = max(len(s[\"a_toks\"]) for s in samples)\n",
"\n",
"def pad_batch(seqs, vmap, maxlen):\n",
" return tf.keras.preprocessing.sequence.pad_sequences(\n",
" [enc(s, vmap) for s in seqs], maxlen=maxlen, padding=\"post\"\n",
" )\n",
"\n",
"X_tok = pad_batch((s[\"tokens\"] for s in samples), vocab_tok, MAX_SENT)\n",
"X_ner = pad_batch((s[\"ner\"] for s in samples), vocab_ner, MAX_SENT)\n",
"X_srl = pad_batch((s[\"srl\"] for s in samples), vocab_srl, MAX_SENT)\n",
"\n",
"dec_q_in = pad_batch(\n",
" ([[\"<sos>\"]+s[\"q_toks\"][:-1] for s in samples]), vocab_q, MAX_Q)\n",
"dec_q_out = pad_batch((s[\"q_toks\"] for s in samples), vocab_q, MAX_Q)\n",
"\n",
"dec_a_in = pad_batch(\n",
" ([[\"<sos>\"]+s[\"a_toks\"][:-1] for s in samples]), vocab_a, MAX_A)\n",
"dec_a_out = pad_batch((s[\"a_toks\"] for s in samples), vocab_a, MAX_A)\n",
"\n",
"y_type = np.array([vocab_typ[s[\"q_type\"]] for s in samples])\n"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "33074619",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Model: \"functional_2\"</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1mModel: \"functional_2\"\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓\n",
"┃<span style=\"font-weight: bold\"> Layer (type) </span>┃<span style=\"font-weight: bold\"> Output Shape </span>┃<span style=\"font-weight: bold\"> Param # </span>┃<span style=\"font-weight: bold\"> Connected to </span>┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩\n",
"│ tok_in (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">InputLayer</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">11</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ - │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ ner_in (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">InputLayer</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">11</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ - │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ srl_in (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">InputLayer</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">11</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ - │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ not_equal_8 │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">11</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ tok_in[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">NotEqual</span>) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ emb_ner (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Embedding</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">11</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">32</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">352</span> │ ner_in[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ emb_srl (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Embedding</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">11</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">32</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">288</span> │ srl_in[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ expand_dims_4 │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">11</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">1</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ not_equal_8[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">ExpandDims</span>) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ broadcast_to_4 │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">11</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ expand_dims_4[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>]… │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">BroadcastTo</span>) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ ones_like_2 │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">11</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">32</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ emb_ner[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">OnesLike</span>) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ ones_like_3 │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">11</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">32</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ emb_srl[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">OnesLike</span>) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ emb_tok (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Embedding</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">11</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">4,992</span> │ tok_in[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ concatenate_5 │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">11</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">192</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ broadcast_to_4[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>… │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Concatenate</span>) │ │ │ ones_like_2[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>… │\n",
"│ │ │ │ ones_like_3[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ dec_q_in │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">9</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ - │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">InputLayer</span>) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ concatenate_4 │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">11</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">192</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ emb_tok[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>], │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Concatenate</span>) │ │ │ emb_ner[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>], │\n",
"│ │ │ │ emb_srl[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ any_2 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Any</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">11</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ concatenate_5[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>]… │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ dec_a_in │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">4</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ - │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">InputLayer</span>) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ emb_q (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Embedding</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">9</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">3,968</span> │ dec_q_in[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ enc_lstm (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">LSTM</span>) │ [(<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">256</span>), │ <span style=\"color: #00af00; text-decoration-color: #00af00\">459,776</span> │ concatenate_4[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>]… │\n",
"│ │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">256</span>), │ │ any_2[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"│ │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">256</span>)] │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ emb_a (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Embedding</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">4</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">1,792</span> │ dec_a_in[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ lstm_q (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">LSTM</span>) │ [(<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">9</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">256</span>), │ <span style=\"color: #00af00; text-decoration-color: #00af00\">394,240</span> │ emb_q[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>], │\n",
"│ │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">256</span>), │ │ enc_lstm[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">1</span>], │\n",
"│ │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">256</span>)] │ │ enc_lstm[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">2</span>] │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ not_equal_9 │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">9</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ dec_q_in[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">NotEqual</span>) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ lstm_a (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">LSTM</span>) │ [(<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">4</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">256</span>), │ <span style=\"color: #00af00; text-decoration-color: #00af00\">394,240</span> │ emb_a[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>], │\n",
"│ │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">256</span>), │ │ enc_lstm[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">1</span>], │\n",
"│ │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">256</span>)] │ │ enc_lstm[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">2</span>] │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ not_equal_10 │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">4</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ dec_a_in[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">NotEqual</span>) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ q_out │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">9</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">31</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">7,967</span> │ lstm_q[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>], │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">TimeDistributed</span>) │ │ │ not_equal_9[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ a_out │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">4</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">14</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">3,598</span> │ lstm_a[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>], │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">TimeDistributed</span>) │ │ │ not_equal_10[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">…</span> │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ type_out (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">3</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">771</span> │ enc_lstm[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"└─────────────────────┴───────────────────┴────────────┴───────────────────┘\n",
"</pre>\n"
],
"text/plain": [
"┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓\n",
"┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mConnected to \u001b[0m\u001b[1m \u001b[0m┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩\n",
"│ tok_in (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m11\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ ner_in (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m11\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ srl_in (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m11\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ not_equal_8 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m11\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ tok_in[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"│ (\u001b[38;5;33mNotEqual\u001b[0m) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ emb_ner (\u001b[38;5;33mEmbedding\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m11\u001b[0m, \u001b[38;5;34m32\u001b[0m) │ \u001b[38;5;34m352\u001b[0m │ ner_in[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ emb_srl (\u001b[38;5;33mEmbedding\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m11\u001b[0m, \u001b[38;5;34m32\u001b[0m) │ \u001b[38;5;34m288\u001b[0m │ srl_in[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ expand_dims_4 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m11\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ not_equal_8[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"│ (\u001b[38;5;33mExpandDims\u001b[0m) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ broadcast_to_4 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m11\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ expand_dims_4[\u001b[38;5;34m0\u001b[0m]… │\n",
"│ (\u001b[38;5;33mBroadcastTo\u001b[0m) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ ones_like_2 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m11\u001b[0m, \u001b[38;5;34m32\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ emb_ner[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"│ (\u001b[38;5;33mOnesLike\u001b[0m) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ ones_like_3 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m11\u001b[0m, \u001b[38;5;34m32\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ emb_srl[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"│ (\u001b[38;5;33mOnesLike\u001b[0m) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ emb_tok (\u001b[38;5;33mEmbedding\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m11\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m4,992\u001b[0m │ tok_in[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ concatenate_5 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m11\u001b[0m, \u001b[38;5;34m192\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ broadcast_to_4[\u001b[38;5;34m0\u001b[0m… │\n",
"│ (\u001b[38;5;33mConcatenate\u001b[0m) │ │ │ ones_like_2[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m… │\n",
"│ │ │ │ ones_like_3[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ dec_q_in │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m9\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n",
"│ (\u001b[38;5;33mInputLayer\u001b[0m) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ concatenate_4 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m11\u001b[0m, \u001b[38;5;34m192\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ emb_tok[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n",
"│ (\u001b[38;5;33mConcatenate\u001b[0m) │ │ │ emb_ner[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n",
"│ │ │ │ emb_srl[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ any_2 (\u001b[38;5;33mAny\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m11\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ concatenate_5[\u001b[38;5;34m0\u001b[0m]… │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ dec_a_in │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m4\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n",
"│ (\u001b[38;5;33mInputLayer\u001b[0m) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ emb_q (\u001b[38;5;33mEmbedding\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m9\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m3,968\u001b[0m │ dec_q_in[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ enc_lstm (\u001b[38;5;33mLSTM\u001b[0m) │ [(\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256\u001b[0m), │ \u001b[38;5;34m459,776\u001b[0m │ concatenate_4[\u001b[38;5;34m0\u001b[0m]… │\n",
"│ │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256\u001b[0m), │ │ any_2[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"│ │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256\u001b[0m)] │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ emb_a (\u001b[38;5;33mEmbedding\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m4\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m1,792\u001b[0m │ dec_a_in[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ lstm_q (\u001b[38;5;33mLSTM\u001b[0m) │ [(\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m9\u001b[0m, \u001b[38;5;34m256\u001b[0m), │ \u001b[38;5;34m394,240\u001b[0m │ emb_q[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n",
"│ │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256\u001b[0m), │ │ enc_lstm[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m1\u001b[0m], │\n",
"│ │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256\u001b[0m)] │ │ enc_lstm[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m2\u001b[0m] │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ not_equal_9 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m9\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ dec_q_in[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"│ (\u001b[38;5;33mNotEqual\u001b[0m) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ lstm_a (\u001b[38;5;33mLSTM\u001b[0m) │ [(\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m4\u001b[0m, \u001b[38;5;34m256\u001b[0m), │ \u001b[38;5;34m394,240\u001b[0m │ emb_a[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n",
"│ │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256\u001b[0m), │ │ enc_lstm[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m1\u001b[0m], │\n",
"│ │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256\u001b[0m)] │ │ enc_lstm[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m2\u001b[0m] │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ not_equal_10 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m4\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ dec_a_in[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"│ (\u001b[38;5;33mNotEqual\u001b[0m) │ │ │ │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ q_out │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m9\u001b[0m, \u001b[38;5;34m31\u001b[0m) │ \u001b[38;5;34m7,967\u001b[0m │ lstm_q[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n",
"│ (\u001b[38;5;33mTimeDistributed\u001b[0m) │ │ │ not_equal_9[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ a_out │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m4\u001b[0m, \u001b[38;5;34m14\u001b[0m) │ \u001b[38;5;34m3,598\u001b[0m │ lstm_a[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n",
"│ (\u001b[38;5;33mTimeDistributed\u001b[0m) │ │ │ not_equal_10[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m…\u001b[0m │\n",
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
"│ type_out (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m3\u001b[0m) │ \u001b[38;5;34m771\u001b[0m │ enc_lstm[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"└─────────────────────┴───────────────────┴────────────┴───────────────────┘\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Total params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">1,271,984</span> (4.85 MB)\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m Total params: \u001b[0m\u001b[38;5;34m1,271,984\u001b[0m (4.85 MB)\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">1,271,984</span> (4.85 MB)\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m1,271,984\u001b[0m (4.85 MB)\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Non-trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> (0.00 B)\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"d_tok, d_tag, units = 128, 32, 256\n",
"pad_tok, pad_q, pad_a = vocab_tok[\"<pad>\"], vocab_q[\"<pad>\"], vocab_a[\"<pad>\"]\n",
"\n",
"# ---- Encoder ----------------------------------------------------\n",
"inp_tok = Input((MAX_SENT,), name=\"tok_in\")\n",
"inp_ner = Input((MAX_SENT,), name=\"ner_in\")\n",
"inp_srl = Input((MAX_SENT,), name=\"srl_in\")\n",
"\n",
"emb_tok = Embedding(len(vocab_tok), d_tok, mask_zero=True, name=\"emb_tok\")(inp_tok)\n",
"emb_ner = Embedding(len(vocab_ner), d_tag, mask_zero=False, name=\"emb_ner\")(inp_ner)\n",
"emb_srl = Embedding(len(vocab_srl), d_tag, mask_zero=False, name=\"emb_srl\")(inp_srl)\n",
"\n",
"enc_concat = Concatenate()([emb_tok, emb_ner, emb_srl])\n",
"enc_out, state_h, state_c = LSTM(units, return_state=True, name=\"enc_lstm\")(enc_concat)\n",
"\n",
"# ---- Decoder : Question ----------------------------------------\n",
"dec_q_inp = Input((MAX_Q,), name=\"dec_q_in\")\n",
"dec_emb_q = Embedding(len(vocab_q), d_tok, mask_zero=True, name=\"emb_q\")(dec_q_inp)\n",
"dec_q_seq, _, _ = LSTM(units, return_sequences=True, return_state=True,\n",
" name=\"lstm_q\")(dec_emb_q, initial_state=[state_h, state_c])\n",
"q_out = TimeDistributed(Dense(len(vocab_q), activation=\"softmax\"), name=\"q_out\")(dec_q_seq)\n",
"\n",
"# ---- Decoder : Answer ------------------------------------------\n",
"dec_a_inp = Input((MAX_A,), name=\"dec_a_in\")\n",
"dec_emb_a = Embedding(len(vocab_a), d_tok, mask_zero=True, name=\"emb_a\")(dec_a_inp)\n",
"dec_a_seq, _, _ = LSTM(units, return_sequences=True, return_state=True,\n",
" name=\"lstm_a\")(dec_emb_a, initial_state=[state_h, state_c])\n",
"a_out = TimeDistributed(Dense(len(vocab_a), activation=\"softmax\"), name=\"a_out\")(dec_a_seq)\n",
"\n",
"# ---- Classifier -------------------------------------------------\n",
"type_out = Dense(len(vocab_typ), activation=\"softmax\", name=\"type_out\")(enc_out)\n",
"\n",
"model = Model(\n",
" [inp_tok, inp_ner, inp_srl, dec_q_inp, dec_a_inp],\n",
" [q_out, a_out, type_out]\n",
")\n",
"\n",
"# ---- Masked loss helpers ---------------------------------------\n",
"scce = tf.keras.losses.SparseCategoricalCrossentropy(reduction=\"none\")\n",
"def masked_loss_factory(pad_id):\n",
" def loss(y_true, y_pred):\n",
" l = scce(y_true, y_pred)\n",
" mask = tf.cast(tf.not_equal(y_true, pad_id), tf.float32)\n",
" return tf.reduce_sum(l*mask) / tf.reduce_sum(mask)\n",
" return loss\n",
"\n",
"model.compile(\n",
" optimizer=\"adam\",\n",
" loss = {\"q_out\":masked_loss_factory(pad_q),\n",
" \"a_out\":masked_loss_factory(pad_a),\n",
" \"type_out\":\"sparse_categorical_crossentropy\"},\n",
" loss_weights={\"q_out\":1.0, \"a_out\":1.0, \"type_out\":0.3},\n",
" metrics={\"q_out\":\"sparse_categorical_accuracy\",\n",
" \"a_out\":\"sparse_categorical_accuracy\",\n",
" \"type_out\":tf.keras.metrics.SparseCategoricalAccuracy(name=\"type_acc\")}\n",
")\n",
"model.summary()\n"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "44d36899",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/30\n"
]
},
{
"ename": "TypeError",
"evalue": "Exception encountered when calling BroadcastTo.call().\n\n\u001b[1mFailed to convert elements of (None, 11, 128) to Tensor. Consider casting elements to a supported type. See https://www.tensorflow.org/api_docs/python/tf/dtypes for supported TF dtypes.\u001b[0m\n\nArguments received by BroadcastTo.call():\n • x=tf.Tensor(shape=(None, 11, 1), dtype=bool)",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[18], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m history \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 2\u001b[0m \u001b[43m \u001b[49m\u001b[43m[\u001b[49m\u001b[43mX_tok\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mX_ner\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mX_srl\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdec_q_in\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdec_a_in\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3\u001b[0m \u001b[43m \u001b[49m\u001b[43m[\u001b[49m\u001b[43mdec_q_out\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdec_a_out\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my_type\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 4\u001b[0m \u001b[43m \u001b[49m\u001b[43mvalidation_split\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m0.1\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 5\u001b[0m \u001b[43m \u001b[49m\u001b[43mepochs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m30\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 6\u001b[0m \u001b[43m \u001b[49m\u001b[43mbatch_size\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m64\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 7\u001b[0m \u001b[43m \u001b[49m\u001b[43mcallbacks\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m[\u001b[49m\u001b[43mtf\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mkeras\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcallbacks\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mEarlyStopping\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpatience\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m4\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mrestore_best_weights\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 8\u001b[0m \u001b[43m \u001b[49m\u001b[43mverbose\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m2\u001b[39;49m\n\u001b[1;32m 9\u001b[0m \u001b[43m)\u001b[49m\n\u001b[1;32m 10\u001b[0m model\u001b[38;5;241m.\u001b[39msave(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfull_seq2seq.keras\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 13\u001b[0m \u001b[38;5;66;03m# -----------------------------------------------------------------\u001b[39;00m\n\u001b[1;32m 14\u001b[0m \u001b[38;5;66;03m# 5. SAVE VOCABS (.pkl keeps python dict intact)\u001b[39;00m\n\u001b[1;32m 15\u001b[0m \u001b[38;5;66;03m# -----------------------------------------------------------------\u001b[39;00m\n",
"File \u001b[0;32m/mnt/disc1/code/thesis_quiz_project/lstm-quiz/myenv/lib64/python3.10/site-packages/keras/src/utils/traceback_utils.py:122\u001b[0m, in \u001b[0;36mfilter_traceback.<locals>.error_handler\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 119\u001b[0m filtered_tb \u001b[38;5;241m=\u001b[39m _process_traceback_frames(e\u001b[38;5;241m.\u001b[39m__traceback__)\n\u001b[1;32m 120\u001b[0m \u001b[38;5;66;03m# To get the full stack trace, call:\u001b[39;00m\n\u001b[1;32m 121\u001b[0m \u001b[38;5;66;03m# `keras.config.disable_traceback_filtering()`\u001b[39;00m\n\u001b[0;32m--> 122\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m e\u001b[38;5;241m.\u001b[39mwith_traceback(filtered_tb) \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 123\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 124\u001b[0m \u001b[38;5;28;01mdel\u001b[39;00m filtered_tb\n",
"File \u001b[0;32m/mnt/disc1/code/thesis_quiz_project/lstm-quiz/myenv/lib64/python3.10/site-packages/keras/src/utils/traceback_utils.py:122\u001b[0m, in \u001b[0;36mfilter_traceback.<locals>.error_handler\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 119\u001b[0m filtered_tb \u001b[38;5;241m=\u001b[39m _process_traceback_frames(e\u001b[38;5;241m.\u001b[39m__traceback__)\n\u001b[1;32m 120\u001b[0m \u001b[38;5;66;03m# To get the full stack trace, call:\u001b[39;00m\n\u001b[1;32m 121\u001b[0m \u001b[38;5;66;03m# `keras.config.disable_traceback_filtering()`\u001b[39;00m\n\u001b[0;32m--> 122\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m e\u001b[38;5;241m.\u001b[39mwith_traceback(filtered_tb) \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 123\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 124\u001b[0m \u001b[38;5;28;01mdel\u001b[39;00m filtered_tb\n",
"\u001b[0;31mTypeError\u001b[0m: Exception encountered when calling BroadcastTo.call().\n\n\u001b[1mFailed to convert elements of (None, 11, 128) to Tensor. Consider casting elements to a supported type. See https://www.tensorflow.org/api_docs/python/tf/dtypes for supported TF dtypes.\u001b[0m\n\nArguments received by BroadcastTo.call():\n • x=tf.Tensor(shape=(None, 11, 1), dtype=bool)"
]
}
],
"source": [
"history = model.fit(\n",
" [X_tok, X_ner, X_srl, dec_q_in, dec_a_in],\n",
" [dec_q_out, dec_a_out, y_type],\n",
" validation_split=0.1,\n",
" epochs=30,\n",
" batch_size=64,\n",
" callbacks=[tf.keras.callbacks.EarlyStopping(patience=4, restore_best_weights=True)],\n",
" verbose=2\n",
")\n",
"model.save(\"full_seq2seq.keras\")\n",
"\n",
"\n",
"# -----------------------------------------------------------------\n",
"# 5. SAVE VOCABS (.pkl keeps python dict intact)\n",
"# -----------------------------------------------------------------\n",
"def save_vocab(v, name): pickle.dump(v, open(name,\"wb\"))\n",
"save_vocab(vocab_tok,\"vocab_tok.pkl\"); save_vocab(vocab_ner,\"vocab_ner.pkl\")\n",
"save_vocab(vocab_srl,\"vocab_srl.pkl\"); save_vocab(vocab_q, \"vocab_q.pkl\")\n",
"save_vocab(vocab_a, \"vocab_a.pkl\"); save_vocab(vocab_typ,\"vocab_typ.pkl\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "61003de5",
"metadata": {},
"outputs": [],
"source": [
"def build_inference_models(trained):\n",
" # encoder\n",
" t_in = Input((MAX_SENT,), name=\"t_in\")\n",
" n_in = Input((MAX_SENT,), name=\"n_in\")\n",
" s_in = Input((MAX_SENT,), name=\"s_in\")\n",
" e_t = trained.get_layer(\"emb_tok\")(t_in)\n",
" e_n = trained.get_layer(\"emb_ner\")(n_in)\n",
" e_s = trained.get_layer(\"emb_srl\")(s_in)\n",
" concat = Concatenate()([e_t,e_n,e_s])\n",
" _, h, c = trained.get_layer(\"enc_lstm\")(concat)\n",
" enc_model = Model([t_in,n_in,s_in],[h,c])\n",
"\n",
" # questiondecoder\n",
" dq_in = Input((1,), name=\"dq_tok\")\n",
" dh = Input((units,), name=\"dh\"); dc = Input((units,), name=\"dc\")\n",
" dq_emb = trained.get_layer(\"emb_q\")(dq_in)\n",
" dq_lstm, nh, nc = trained.get_layer(\"lstm_q\")(dq_emb, initial_state=[dh,dc])\n",
" dq_out = trained.get_layer(\"q_out\").layer(dq_lstm)\n",
" dec_q_model = Model([dq_in, dh, dc], [dq_out, nh, nc])\n",
"\n",
" # answerdecoder\n",
" da_in = Input((1,), name=\"da_tok\")\n",
" ah = Input((units,), name=\"ah\"); ac = Input((units,), name=\"ac\")\n",
" da_emb = trained.get_layer(\"emb_a\")(da_in)\n",
" da_lstm, nh2, nc2 = trained.get_layer(\"lstm_a\")(da_emb, initial_state=[ah,ac])\n",
" da_out = trained.get_layer(\"a_out\").layer(da_lstm)\n",
" dec_a_model = Model([da_in, ah, ac], [da_out, nh2, nc2])\n",
"\n",
" # type classifier\n",
" type_dense = trained.get_layer(\"type_out\")\n",
" type_model = Model([t_in,n_in,s_in], type_dense(_)) # use _ = enc_lstm output\n",
"\n",
" return enc_model, dec_q_model, dec_a_model, type_model\n",
"\n",
"encoder_model, decoder_q, decoder_a, classifier_model = build_inference_models(model)\n",
"\n",
"inv_q = {v:k for k,v in vocab_q.items()}\n",
"inv_a = {v:k for k,v in vocab_a.items()}\n",
"\n",
"def enc_pad(seq, vmap, maxlen):\n",
" x = [vmap.get(t, vmap[\"<unk>\"]) for t in seq]\n",
" return x + [vmap[\"<pad>\"]] * (maxlen-len(x))\n",
"\n",
"def greedy_decode(tokens, ner, srl, max_q=20, max_a=10):\n",
" et = np.array([enc_pad(tokens, vocab_tok, MAX_SENT)])\n",
" en = np.array([enc_pad(ner, vocab_ner, MAX_SENT)])\n",
" es = np.array([enc_pad(srl, vocab_srl, MAX_SENT)])\n",
"\n",
" h,c = encoder_model.predict([et,en,es], verbose=0)\n",
"\n",
" # --- question\n",
" q_ids = []\n",
" tgt = np.array([[vocab_q[\"<sos>\"]]])\n",
" for _ in range(max_q):\n",
" logits,h,c = decoder_q.predict([tgt,h,c], verbose=0)\n",
" nxt = int(logits[0,-1].argmax())\n",
" if nxt==vocab_q[\"<eos>\"]: break\n",
" q_ids.append(nxt)\n",
" tgt = np.array([[nxt]])\n",
"\n",
" # --- answer (reuse fresh h,c)\n",
" h,c = encoder_model.predict([et,en,es], verbose=0)\n",
" a_ids = []\n",
" tgt = np.array([[vocab_a[\"<sos>\"]]])\n",
" for _ in range(max_a):\n",
" logits,h,c = decoder_a.predict([tgt,h,c], verbose=0)\n",
" nxt = int(logits[0,-1].argmax())\n",
" if nxt==vocab_a[\"<eos>\"]: break\n",
" a_ids.append(nxt)\n",
" tgt = np.array([[nxt]])\n",
"\n",
" # --- type\n",
" t_id = int(classifier_model.predict([et,en,es], verbose=0).argmax())\n",
"\n",
" return [inv_q[i] for i in q_ids], [inv_a[i] for i in a_ids], \\\n",
" [k for k,v in vocab_typ.items() if v==t_id][0]\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5279b631",
"metadata": {},
"outputs": [],
"source": [
"test_tokens = [\"soekarno\",\"membacakan\",\"teks\",\"proklamasi\",\"pada\",\n",
" \"17\",\"agustus\",\"1945\"]\n",
"test_ner = [\"B-PER\",\"O\",\"O\",\"O\",\"O\",\"B-DATE\",\"I-DATE\",\"I-DATE\"]\n",
"test_srl = [\"ARG0\",\"V\",\"ARG1\",\"ARG1\",\"O\",\"ARGM-TMP\",\"ARGM-TMP\",\"ARGM-TMP\"]\n",
"\n",
"q,a,t = greedy_decode(test_tokens,test_ner,test_srl,max_q=MAX_Q,max_a=MAX_A)\n",
"print(\"\\nDEMO\\n----\")\n",
"print(\"Q :\", \" \".join(q))\n",
"print(\"A :\", \" \".join(a))\n",
"print(\"T :\", t)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "850d4905",
"metadata": {},
"outputs": [],
"source": [
"smooth = SmoothingFunction().method4\n",
"r_scorer = rouge_scorer.RougeScorer([\"rouge1\",\"rougeL\"], use_stemmer=True)\n",
"\n",
"def strip_special(seq, pad_id, eos_id):\n",
" return [x for x in seq if x not in (pad_id, eos_id)]\n",
"\n",
"def ids_to_text(ids, inv):\n",
" return \" \".join(inv[i] for i in ids)\n",
"\n",
"def evaluate(n=200):\n",
" idxs = random.sample(range(len(samples)), n)\n",
" refs, hyps = [], []\n",
" agg = scoring.BootstrapAggregator()\n",
"\n",
" for i in idxs:\n",
" gt_ids = strip_special(dec_q_out[i], pad_q, vocab_q[\"<eos>\"])\n",
" ref = ids_to_text(gt_ids, inv_q)\n",
" pred = \" \".join(greedy_decode(\n",
" samples[i][\"tokens\"],\n",
" samples[i][\"ner\"],\n",
" samples[i][\"srl\"]\n",
" )[0])\n",
" refs.append([ref.split()])\n",
" hyps.append(pred.split())\n",
" agg.add_scores(r_scorer.score(ref, pred))\n",
"\n",
" bleu = corpus_bleu(refs, hyps, smoothing_function=smooth)\n",
" r1 = agg.aggregate()[\"rouge1\"].mid\n",
" rL = agg.aggregate()[\"rougeL\"].mid\n",
"\n",
" print(f\"\\nEVAL (n={n})\")\n",
" print(f\"BLEU4 : {bleu:.4f}\")\n",
" print(f\"ROUGE1 : P={r1.precision:.3f} R={r1.recall:.3f} F1={r1.fmeasure:.3f}\")\n",
" print(f\"ROUGEL : P={rL.precision:.3f} R={rL.recall:.3f} F1={rL.fmeasure:.3f}\")\n",
"\n",
"evaluate(2) "
]
}
],
"metadata": {
"kernelspec": {
"display_name": "myenv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}