616 lines
63 KiB
Plaintext
616 lines
63 KiB
Plaintext
{
|
||
"cells": [
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 13,
|
||
"id": "58e41ccb",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"import json, pickle, random\n",
|
||
"from pathlib import Path\n",
|
||
"from itertools import chain\n",
|
||
"\n",
|
||
"import numpy as np\n",
|
||
"import tensorflow as tf\n",
|
||
"from tensorflow.keras.layers import (\n",
|
||
" Input, Embedding, LSTM, Concatenate,\n",
|
||
" Dense, TimeDistributed\n",
|
||
")\n",
|
||
"from tensorflow.keras.models import Model\n",
|
||
"from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction\n",
|
||
"from rouge_score import rouge_scorer, scoring\n",
|
||
"\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 14,
|
||
"id": "a94dd46a",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"flattened samples : 8\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"# Load the dev split and flatten it into one sample per quiz possibility;\n",
"# every sample carries its parent sentence's tokens / NER tags / SRL tags.\n",
"DATASET_PATH = Path(\"../dataset/dev_dataset_qg.json\")\n",
"\n",
"# JSON is UTF-8 by spec: decode explicitly instead of relying on the locale\n",
"# default encoding (e.g. cp1252 on Windows), which can garble non-ASCII text.\n",
"RAW = json.loads(DATASET_PATH.read_text(encoding=\"utf-8\"))\n",
"\n",
"samples = []\n",
"for item in RAW:\n",
"    # \"quiz_posibility\" is the key exactly as spelled in the dataset file\n",
"    for qp in item[\"quiz_posibility\"]:\n",
"        # an answer is either a single value or a list; normalise it to a list\n",
"        answer = qp[\"answer\"]\n",
"        a_toks = answer if isinstance(answer, list) else [answer]\n",
"        samples.append({\n",
"            \"tokens\" : item[\"tokens\"],\n",
"            \"ner\"    : item[\"ner\"],\n",
"            \"srl\"    : item[\"srl\"],\n",
"            # isian = fill-in, opsi = multiple choice, benar_salah = true/false\n",
"            \"q_type\" : qp[\"type\"],\n",
"            \"q_toks\" : qp[\"question\"] + [\"<eos>\"],\n",
"            \"a_toks\" : a_toks + [\"<eos>\"],\n",
"        })\n",
"\n",
"print(\"flattened samples :\", len(samples))\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 15,
|
||
"id": "852fb9a8",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"def build_vocab(seq_iter, reserved=(\"<pad>\", \"<unk>\", \"<sos>\", \"<eos>\")):\n",
"    \"\"\"Give every distinct token a contiguous integer id.\n",
"\n",
"    `reserved` tokens take ids 0, 1, ... in the order given; every other token\n",
"    is numbered in order of first appearance across all sequences of `seq_iter`.\n",
"    \"\"\"\n",
"    vocab = {}\n",
"    for idx, tok in enumerate(reserved):\n",
"        vocab[tok] = idx\n",
"    for seq in seq_iter:\n",
"        for tok in seq:\n",
"            if tok not in vocab:\n",
"                vocab[tok] = len(vocab)\n",
"    return vocab\n",
"\n",
"# \"<pad>\" is reserved first, so it maps to id 0 in every vocabulary below.\n",
"# The NER / SRL tag vocabularies only reserve \"<pad>\" and \"<unk>\".\n",
"vocab_tok = build_vocab(s[\"tokens\"] for s in samples)\n",
"vocab_ner = build_vocab([s[\"ner\"] for s in samples], reserved=(\"<pad>\", \"<unk>\"))\n",
"vocab_srl = build_vocab([s[\"srl\"] for s in samples], reserved=(\"<pad>\", \"<unk>\"))\n",
"vocab_q   = build_vocab(s[\"q_toks\"] for s in samples)\n",
"vocab_a   = build_vocab(s[\"a_toks\"] for s in samples)\n",
"\n",
"# Fixed integer labels for the three question types.\n",
"vocab_typ = {\"isian\": 0, \"opsi\": 1, \"benar_salah\": 2}"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 16,
|
||
"id": "fdf696cf",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"def enc(seq, v):\n",
"    \"\"\"Turn a token sequence into ids from vocab `v`; unseen tokens become `<unk>`.\"\"\"\n",
"    return [v.get(t, v[\"<unk>\"]) for t in seq]\n",
"\n",
"\n",
"# Pad each field to its longest example in this dataset.\n",
"MAX_SENT = max(len(s[\"tokens\"]) for s in samples)\n",
"MAX_Q    = max(len(s[\"q_toks\"]) for s in samples)\n",
"MAX_A    = max(len(s[\"a_toks\"]) for s in samples)\n",
"\n",
"\n",
"def pad_batch(seqs, vmap, maxlen):\n",
"    \"\"\"Encode each sequence with `vmap` and right-pad the batch with 0 to `maxlen`.\n",
"\n",
"    NOTE: `pad_sequences` truncates over-long sequences from the front by default\n",
"    (truncating=\"pre\"); that cannot happen for the calls in this cell because\n",
"    each `maxlen` is the dataset maximum for that field.\n",
"    \"\"\"\n",
"    encoded = [enc(s, vmap) for s in seqs]\n",
"    return tf.keras.preprocessing.sequence.pad_sequences(\n",
"        encoded, maxlen=maxlen, padding=\"post\"\n",
"    )\n",
"\n",
"\n",
"def _shift_right(tok_lists):\n",
"    \"\"\"Teacher-forcing decoder input: prepend `<sos>` and drop each sequence's last token.\"\"\"\n",
"    return [[\"<sos>\"] + toks[:-1] for toks in tok_lists]\n",
"\n",
"\n",
"# Encoder inputs: token ids plus NER / SRL tag ids, all padded to MAX_SENT.\n",
"X_tok = pad_batch([s[\"tokens\"] for s in samples], vocab_tok, MAX_SENT)\n",
"X_ner = pad_batch([s[\"ner\"] for s in samples], vocab_ner, MAX_SENT)\n",
"X_srl = pad_batch([s[\"srl\"] for s in samples], vocab_srl, MAX_SENT)\n",
"\n",
"# `*_out` are the full sequences; `*_in` are the same sequences shifted right.\n",
"dec_q_in  = pad_batch(_shift_right(s[\"q_toks\"] for s in samples), vocab_q, MAX_Q)\n",
"dec_q_out = pad_batch([s[\"q_toks\"] for s in samples], vocab_q, MAX_Q)\n",
"\n",
"dec_a_in  = pad_batch(_shift_right(s[\"a_toks\"] for s in samples), vocab_a, MAX_A)\n",
"dec_a_out = pad_batch([s[\"a_toks\"] for s in samples], vocab_a, MAX_A)\n",
"\n",
"# Question-type labels as integer class ids.\n",
"y_type = np.array([vocab_typ[s[\"q_type\"]] for s in samples])\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 17,
|
||
"id": "33074619",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Model: \"functional_2\"</span>\n",
|
||
"</pre>\n"
|
||
],
|
||
"text/plain": [
|
||
"\u001b[1mModel: \"functional_2\"\u001b[0m\n"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓\n",
|
||
"┃<span style=\"font-weight: bold\"> Layer (type) </span>┃<span style=\"font-weight: bold\"> Output Shape </span>┃<span style=\"font-weight: bold\"> Param # </span>┃<span style=\"font-weight: bold\"> Connected to </span>┃\n",
|
||
"┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩\n",
|
||
"│ tok_in (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">InputLayer</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">11</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ - │\n",
|
||
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
||
"│ ner_in (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">InputLayer</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">11</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ - │\n",
|
||
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
||
"│ srl_in (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">InputLayer</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">11</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ - │\n",
|
||
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
||
"│ not_equal_8 │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">11</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ tok_in[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
|
||
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">NotEqual</span>) │ │ │ │\n",
|
||
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
||
"│ emb_ner (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Embedding</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">11</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">32</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">352</span> │ ner_in[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
|
||
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
||
"│ emb_srl (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Embedding</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">11</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">32</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">288</span> │ srl_in[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
|
||
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
||
"│ expand_dims_4 │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">11</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">1</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ not_equal_8[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
|
||
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">ExpandDims</span>) │ │ │ │\n",
|
||
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
||
"│ broadcast_to_4 │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">11</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ expand_dims_4[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>]… │\n",
|
||
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">BroadcastTo</span>) │ │ │ │\n",
|
||
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
||
"│ ones_like_2 │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">11</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">32</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ emb_ner[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
|
||
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">OnesLike</span>) │ │ │ │\n",
|
||
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
||
"│ ones_like_3 │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">11</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">32</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ emb_srl[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
|
||
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">OnesLike</span>) │ │ │ │\n",
|
||
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
||
"│ emb_tok (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Embedding</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">11</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">4,992</span> │ tok_in[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
|
||
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
||
"│ concatenate_5 │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">11</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">192</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ broadcast_to_4[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>… │\n",
|
||
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Concatenate</span>) │ │ │ ones_like_2[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>… │\n",
|
||
"│ │ │ │ ones_like_3[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
|
||
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
||
"│ dec_q_in │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">9</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ - │\n",
|
||
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">InputLayer</span>) │ │ │ │\n",
|
||
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
||
"│ concatenate_4 │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">11</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">192</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ emb_tok[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>], │\n",
|
||
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Concatenate</span>) │ │ │ emb_ner[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>], │\n",
|
||
"│ │ │ │ emb_srl[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
|
||
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
||
"│ any_2 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Any</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">11</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ concatenate_5[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>]… │\n",
|
||
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
||
"│ dec_a_in │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">4</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ - │\n",
|
||
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">InputLayer</span>) │ │ │ │\n",
|
||
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
||
"│ emb_q (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Embedding</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">9</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">3,968</span> │ dec_q_in[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
|
||
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
||
"│ enc_lstm (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">LSTM</span>) │ [(<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">256</span>), │ <span style=\"color: #00af00; text-decoration-color: #00af00\">459,776</span> │ concatenate_4[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>]… │\n",
|
||
"│ │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">256</span>), │ │ any_2[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
|
||
"│ │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">256</span>)] │ │ │\n",
|
||
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
||
"│ emb_a (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Embedding</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">4</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">1,792</span> │ dec_a_in[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
|
||
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
||
"│ lstm_q (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">LSTM</span>) │ [(<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">9</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">256</span>), │ <span style=\"color: #00af00; text-decoration-color: #00af00\">394,240</span> │ emb_q[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>], │\n",
|
||
"│ │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">256</span>), │ │ enc_lstm[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">1</span>], │\n",
|
||
"│ │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">256</span>)] │ │ enc_lstm[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">2</span>] │\n",
|
||
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
||
"│ not_equal_9 │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">9</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ dec_q_in[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
|
||
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">NotEqual</span>) │ │ │ │\n",
|
||
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
||
"│ lstm_a (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">LSTM</span>) │ [(<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">4</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">256</span>), │ <span style=\"color: #00af00; text-decoration-color: #00af00\">394,240</span> │ emb_a[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>], │\n",
|
||
"│ │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">256</span>), │ │ enc_lstm[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">1</span>], │\n",
|
||
"│ │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">256</span>)] │ │ enc_lstm[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">2</span>] │\n",
|
||
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
||
"│ not_equal_10 │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">4</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ dec_a_in[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
|
||
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">NotEqual</span>) │ │ │ │\n",
|
||
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
||
"│ q_out │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">9</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">31</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">7,967</span> │ lstm_q[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>], │\n",
|
||
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">TimeDistributed</span>) │ │ │ not_equal_9[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
|
||
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
||
"│ a_out │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">4</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">14</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">3,598</span> │ lstm_a[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>], │\n",
|
||
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">TimeDistributed</span>) │ │ │ not_equal_10[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">…</span> │\n",
|
||
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
||
"│ type_out (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">3</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">771</span> │ enc_lstm[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
|
||
"└─────────────────────┴───────────────────┴────────────┴───────────────────┘\n",
|
||
"</pre>\n"
|
||
],
|
||
"text/plain": [
|
||
"┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓\n",
|
||
"┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mConnected to \u001b[0m\u001b[1m \u001b[0m┃\n",
|
||
"┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩\n",
|
||
"│ tok_in (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m11\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n",
|
||
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
||
"│ ner_in (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m11\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n",
|
||
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
||
"│ srl_in (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m11\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n",
|
||
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
||
"│ not_equal_8 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m11\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ tok_in[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
|
||
"│ (\u001b[38;5;33mNotEqual\u001b[0m) │ │ │ │\n",
|
||
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
||
"│ emb_ner (\u001b[38;5;33mEmbedding\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m11\u001b[0m, \u001b[38;5;34m32\u001b[0m) │ \u001b[38;5;34m352\u001b[0m │ ner_in[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
|
||
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
||
"│ emb_srl (\u001b[38;5;33mEmbedding\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m11\u001b[0m, \u001b[38;5;34m32\u001b[0m) │ \u001b[38;5;34m288\u001b[0m │ srl_in[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
|
||
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
||
"│ expand_dims_4 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m11\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ not_equal_8[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
|
||
"│ (\u001b[38;5;33mExpandDims\u001b[0m) │ │ │ │\n",
|
||
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
||
"│ broadcast_to_4 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m11\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ expand_dims_4[\u001b[38;5;34m0\u001b[0m]… │\n",
|
||
"│ (\u001b[38;5;33mBroadcastTo\u001b[0m) │ │ │ │\n",
|
||
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
||
"│ ones_like_2 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m11\u001b[0m, \u001b[38;5;34m32\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ emb_ner[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
|
||
"│ (\u001b[38;5;33mOnesLike\u001b[0m) │ │ │ │\n",
|
||
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
||
"│ ones_like_3 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m11\u001b[0m, \u001b[38;5;34m32\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ emb_srl[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
|
||
"│ (\u001b[38;5;33mOnesLike\u001b[0m) │ │ │ │\n",
|
||
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
||
"│ emb_tok (\u001b[38;5;33mEmbedding\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m11\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m4,992\u001b[0m │ tok_in[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
|
||
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
||
"│ concatenate_5 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m11\u001b[0m, \u001b[38;5;34m192\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ broadcast_to_4[\u001b[38;5;34m0\u001b[0m… │\n",
|
||
"│ (\u001b[38;5;33mConcatenate\u001b[0m) │ │ │ ones_like_2[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m… │\n",
|
||
"│ │ │ │ ones_like_3[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
|
||
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
||
"│ dec_q_in │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m9\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n",
|
||
"│ (\u001b[38;5;33mInputLayer\u001b[0m) │ │ │ │\n",
|
||
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
||
"│ concatenate_4 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m11\u001b[0m, \u001b[38;5;34m192\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ emb_tok[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n",
|
||
"│ (\u001b[38;5;33mConcatenate\u001b[0m) │ │ │ emb_ner[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n",
|
||
"│ │ │ │ emb_srl[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
|
||
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
||
"│ any_2 (\u001b[38;5;33mAny\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m11\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ concatenate_5[\u001b[38;5;34m0\u001b[0m]… │\n",
|
||
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
||
"│ dec_a_in │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m4\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n",
|
||
"│ (\u001b[38;5;33mInputLayer\u001b[0m) │ │ │ │\n",
|
||
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
||
"│ emb_q (\u001b[38;5;33mEmbedding\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m9\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m3,968\u001b[0m │ dec_q_in[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
|
||
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
||
"│ enc_lstm (\u001b[38;5;33mLSTM\u001b[0m) │ [(\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256\u001b[0m), │ \u001b[38;5;34m459,776\u001b[0m │ concatenate_4[\u001b[38;5;34m0\u001b[0m]… │\n",
|
||
"│ │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256\u001b[0m), │ │ any_2[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
|
||
"│ │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256\u001b[0m)] │ │ │\n",
|
||
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
||
"│ emb_a (\u001b[38;5;33mEmbedding\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m4\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m1,792\u001b[0m │ dec_a_in[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
|
||
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
||
"│ lstm_q (\u001b[38;5;33mLSTM\u001b[0m) │ [(\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m9\u001b[0m, \u001b[38;5;34m256\u001b[0m), │ \u001b[38;5;34m394,240\u001b[0m │ emb_q[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n",
|
||
"│ │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256\u001b[0m), │ │ enc_lstm[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m1\u001b[0m], │\n",
|
||
"│ │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256\u001b[0m)] │ │ enc_lstm[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m2\u001b[0m] │\n",
|
||
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
||
"│ not_equal_9 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m9\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ dec_q_in[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
|
||
"│ (\u001b[38;5;33mNotEqual\u001b[0m) │ │ │ │\n",
|
||
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
||
"│ lstm_a (\u001b[38;5;33mLSTM\u001b[0m) │ [(\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m4\u001b[0m, \u001b[38;5;34m256\u001b[0m), │ \u001b[38;5;34m394,240\u001b[0m │ emb_a[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n",
|
||
"│ │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256\u001b[0m), │ │ enc_lstm[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m1\u001b[0m], │\n",
|
||
"│ │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256\u001b[0m)] │ │ enc_lstm[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m2\u001b[0m] │\n",
|
||
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
||
"│ not_equal_10 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m4\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ dec_a_in[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
|
||
"│ (\u001b[38;5;33mNotEqual\u001b[0m) │ │ │ │\n",
|
||
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
||
"│ q_out │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m9\u001b[0m, \u001b[38;5;34m31\u001b[0m) │ \u001b[38;5;34m7,967\u001b[0m │ lstm_q[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n",
|
||
"│ (\u001b[38;5;33mTimeDistributed\u001b[0m) │ │ │ not_equal_9[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
|
||
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
||
"│ a_out │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m4\u001b[0m, \u001b[38;5;34m14\u001b[0m) │ \u001b[38;5;34m3,598\u001b[0m │ lstm_a[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n",
|
||
"│ (\u001b[38;5;33mTimeDistributed\u001b[0m) │ │ │ not_equal_10[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m…\u001b[0m │\n",
|
||
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
||
"│ type_out (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m3\u001b[0m) │ \u001b[38;5;34m771\u001b[0m │ enc_lstm[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
|
||
"└─────────────────────┴───────────────────┴────────────┴───────────────────┘\n"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Total params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">1,271,984</span> (4.85 MB)\n",
|
||
"</pre>\n"
|
||
],
|
||
"text/plain": [
|
||
"\u001b[1m Total params: \u001b[0m\u001b[38;5;34m1,271,984\u001b[0m (4.85 MB)\n"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">1,271,984</span> (4.85 MB)\n",
|
||
"</pre>\n"
|
||
],
|
||
"text/plain": [
|
||
"\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m1,271,984\u001b[0m (4.85 MB)\n"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Non-trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> (0.00 B)\n",
|
||
"</pre>\n"
|
||
],
|
||
"text/plain": [
|
||
"\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
"import keras  # Keras 3 ops API; assumes tf.keras resolves to Keras 3 in this env\n",
"\n",
"d_tok, d_tag, units = 128, 32, 256\n",
"pad_tok, pad_q, pad_a = vocab_tok[\"<pad>\"], vocab_q[\"<pad>\"], vocab_a[\"<pad>\"]\n",
"\n",
"# ---- Encoder ----------------------------------------------------\n",
"inp_tok = Input((MAX_SENT,), name=\"tok_in\")\n",
"inp_ner = Input((MAX_SENT,), name=\"ner_in\")\n",
"inp_srl = Input((MAX_SENT,), name=\"srl_in\")\n",
"\n",
"# No mask_zero on the encoder embeddings: when Concatenate merges masks it\n",
"# broadcasts the (None, MAX_SENT, 1) mask to (None, MAX_SENT, d), and that\n",
"# broadcast fails inside model.fit (\"Failed to convert elements of\n",
"# (None, 11, 128) to Tensor\"). The padding mask is built explicitly from the\n",
"# token ids instead and handed straight to the encoder LSTM.\n",
"emb_tok = Embedding(len(vocab_tok), d_tok, mask_zero=False, name=\"emb_tok\")(inp_tok)\n",
"emb_ner = Embedding(len(vocab_ner), d_tag, mask_zero=False, name=\"emb_ner\")(inp_ner)\n",
"emb_srl = Embedding(len(vocab_srl), d_tag, mask_zero=False, name=\"emb_srl\")(inp_srl)\n",
"\n",
"# True at real tokens, False at <pad>. Assumes the NER/SRL rows are padded at\n",
"# the same positions as the tokens (TODO confirm against the encoding cell).\n",
"enc_mask = keras.ops.not_equal(inp_tok, pad_tok)\n",
"enc_concat = Concatenate()([emb_tok, emb_ner, emb_srl])\n",
"enc_out, state_h, state_c = LSTM(units, return_state=True, name=\"enc_lstm\")(\n",
"    enc_concat, mask=enc_mask)\n",
"\n",
"# ---- Decoder : Question ----------------------------------------\n",
"# Teacher-forced decoder initialised with the encoder's final (h, c).\n",
"dec_q_inp = Input((MAX_Q,), name=\"dec_q_in\")\n",
"dec_emb_q = Embedding(len(vocab_q), d_tok, mask_zero=True, name=\"emb_q\")(dec_q_inp)\n",
"dec_q_seq, _, _ = LSTM(units, return_sequences=True, return_state=True,\n",
"                       name=\"lstm_q\")(dec_emb_q, initial_state=[state_h, state_c])\n",
"q_out = TimeDistributed(Dense(len(vocab_q), activation=\"softmax\"), name=\"q_out\")(dec_q_seq)\n",
"\n",
"# ---- Decoder : Answer ------------------------------------------\n",
"dec_a_inp = Input((MAX_A,), name=\"dec_a_in\")\n",
"dec_emb_a = Embedding(len(vocab_a), d_tok, mask_zero=True, name=\"emb_a\")(dec_a_inp)\n",
"dec_a_seq, _, _ = LSTM(units, return_sequences=True, return_state=True,\n",
"                       name=\"lstm_a\")(dec_emb_a, initial_state=[state_h, state_c])\n",
"a_out = TimeDistributed(Dense(len(vocab_a), activation=\"softmax\"), name=\"a_out\")(dec_a_seq)\n",
"\n",
"# ---- Classifier -------------------------------------------------\n",
"type_out = Dense(len(vocab_typ), activation=\"softmax\", name=\"type_out\")(enc_out)\n",
"\n",
"model = Model(\n",
"    [inp_tok, inp_ner, inp_srl, dec_q_inp, dec_a_inp],\n",
"    [q_out, a_out, type_out]\n",
")\n",
"\n",
"# ---- Masked loss helpers ---------------------------------------\n",
"scce = tf.keras.losses.SparseCategoricalCrossentropy(reduction=\"none\")\n",
"def masked_loss_factory(pad_id):\n",
"    \"\"\"Build a loss that averages per-token cross-entropy over targets != pad_id.\"\"\"\n",
"    def loss(y_true, y_pred):\n",
"        per_tok = scce(y_true, y_pred)\n",
"        mask = tf.cast(tf.not_equal(y_true, pad_id), tf.float32)\n",
"        # tf.maximum(..., 1) keeps an all-padding batch at 0 instead of 0/0 = NaN\n",
"        return tf.reduce_sum(per_tok * mask) / tf.maximum(tf.reduce_sum(mask), 1.0)\n",
"    return loss\n",
"\n",
"model.compile(\n",
"    optimizer=\"adam\",\n",
"    loss={\"q_out\": masked_loss_factory(pad_q),\n",
"          \"a_out\": masked_loss_factory(pad_a),\n",
"          \"type_out\": \"sparse_categorical_crossentropy\"},\n",
"    loss_weights={\"q_out\": 1.0, \"a_out\": 1.0, \"type_out\": 0.3},\n",
"    metrics={\"q_out\": \"sparse_categorical_accuracy\",\n",
"             \"a_out\": \"sparse_categorical_accuracy\",\n",
"             \"type_out\": tf.keras.metrics.SparseCategoricalAccuracy(name=\"type_acc\")}\n",
")\n",
"model.summary()\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 18,
|
||
"id": "44d36899",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Epoch 1/30\n"
|
||
]
|
||
},
|
||
{
|
||
"ename": "TypeError",
|
||
"evalue": "Exception encountered when calling BroadcastTo.call().\n\n\u001b[1mFailed to convert elements of (None, 11, 128) to Tensor. Consider casting elements to a supported type. See https://www.tensorflow.org/api_docs/python/tf/dtypes for supported TF dtypes.\u001b[0m\n\nArguments received by BroadcastTo.call():\n • x=tf.Tensor(shape=(None, 11, 1), dtype=bool)",
|
||
"output_type": "error",
|
||
"traceback": [
|
||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
|
||
"Cell \u001b[0;32mIn[18], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m history \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 2\u001b[0m \u001b[43m \u001b[49m\u001b[43m[\u001b[49m\u001b[43mX_tok\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mX_ner\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mX_srl\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdec_q_in\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdec_a_in\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3\u001b[0m \u001b[43m \u001b[49m\u001b[43m[\u001b[49m\u001b[43mdec_q_out\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdec_a_out\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my_type\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 4\u001b[0m \u001b[43m \u001b[49m\u001b[43mvalidation_split\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m0.1\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 5\u001b[0m \u001b[43m \u001b[49m\u001b[43mepochs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m30\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 6\u001b[0m \u001b[43m \u001b[49m\u001b[43mbatch_size\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m64\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 7\u001b[0m \u001b[43m \u001b[49m\u001b[43mcallbacks\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m[\u001b[49m\u001b[43mtf\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mkeras\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcallbacks\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mEarlyStopping\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpatience\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m4\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[43mrestore_best_weights\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 8\u001b[0m \u001b[43m \u001b[49m\u001b[43mverbose\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m2\u001b[39;49m\n\u001b[1;32m 9\u001b[0m \u001b[43m)\u001b[49m\n\u001b[1;32m 10\u001b[0m model\u001b[38;5;241m.\u001b[39msave(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfull_seq2seq.keras\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 13\u001b[0m \u001b[38;5;66;03m# -----------------------------------------------------------------\u001b[39;00m\n\u001b[1;32m 14\u001b[0m \u001b[38;5;66;03m# 5. SAVE VOCABS (.pkl keeps python dict intact)\u001b[39;00m\n\u001b[1;32m 15\u001b[0m \u001b[38;5;66;03m# -----------------------------------------------------------------\u001b[39;00m\n",
|
||
"File \u001b[0;32m/mnt/disc1/code/thesis_quiz_project/lstm-quiz/myenv/lib64/python3.10/site-packages/keras/src/utils/traceback_utils.py:122\u001b[0m, in \u001b[0;36mfilter_traceback.<locals>.error_handler\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 119\u001b[0m filtered_tb \u001b[38;5;241m=\u001b[39m _process_traceback_frames(e\u001b[38;5;241m.\u001b[39m__traceback__)\n\u001b[1;32m 120\u001b[0m \u001b[38;5;66;03m# To get the full stack trace, call:\u001b[39;00m\n\u001b[1;32m 121\u001b[0m \u001b[38;5;66;03m# `keras.config.disable_traceback_filtering()`\u001b[39;00m\n\u001b[0;32m--> 122\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m e\u001b[38;5;241m.\u001b[39mwith_traceback(filtered_tb) \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 123\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 124\u001b[0m \u001b[38;5;28;01mdel\u001b[39;00m filtered_tb\n",
|
||
"File \u001b[0;32m/mnt/disc1/code/thesis_quiz_project/lstm-quiz/myenv/lib64/python3.10/site-packages/keras/src/utils/traceback_utils.py:122\u001b[0m, in \u001b[0;36mfilter_traceback.<locals>.error_handler\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 119\u001b[0m filtered_tb \u001b[38;5;241m=\u001b[39m _process_traceback_frames(e\u001b[38;5;241m.\u001b[39m__traceback__)\n\u001b[1;32m 120\u001b[0m \u001b[38;5;66;03m# To get the full stack trace, call:\u001b[39;00m\n\u001b[1;32m 121\u001b[0m \u001b[38;5;66;03m# `keras.config.disable_traceback_filtering()`\u001b[39;00m\n\u001b[0;32m--> 122\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m e\u001b[38;5;241m.\u001b[39mwith_traceback(filtered_tb) \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 123\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 124\u001b[0m \u001b[38;5;28;01mdel\u001b[39;00m filtered_tb\n",
|
||
"\u001b[0;31mTypeError\u001b[0m: Exception encountered when calling BroadcastTo.call().\n\n\u001b[1mFailed to convert elements of (None, 11, 128) to Tensor. Consider casting elements to a supported type. See https://www.tensorflow.org/api_docs/python/tf/dtypes for supported TF dtypes.\u001b[0m\n\nArguments received by BroadcastTo.call():\n • x=tf.Tensor(shape=(None, 11, 1), dtype=bool)"
|
||
]
|
||
}
|
||
],
|
||
"source": [
"history = model.fit(\n",
"    [X_tok, X_ner, X_srl, dec_q_in, dec_a_in],\n",
"    [dec_q_out, dec_a_out, y_type],\n",
"    validation_split=0.1,\n",
"    epochs=30,\n",
"    batch_size=64,\n",
"    callbacks=[tf.keras.callbacks.EarlyStopping(patience=4, restore_best_weights=True)],\n",
"    verbose=2\n",
")\n",
"# NOTE(review): the compiled losses are closures from masked_loss_factory, so\n",
"# reloading this file will likely need custom_objects or compile=False - verify.\n",
"model.save(\"full_seq2seq.keras\")\n",
"\n",
"\n",
"# -----------------------------------------------------------------\n",
"# 5. SAVE VOCABS (.pkl keeps python dict intact)\n",
"# -----------------------------------------------------------------\n",
"def save_vocab(v, name):\n",
"    \"\"\"Pickle the vocabulary dict `v` to the file path `name`.\"\"\"\n",
"    # `with` closes (and flushes) the file deterministically instead of\n",
"    # leaving the open handle to the garbage collector.\n",
"    with open(name, \"wb\") as fh:\n",
"        pickle.dump(v, fh)\n",
"\n",
"save_vocab(vocab_tok, \"vocab_tok.pkl\")\n",
"save_vocab(vocab_ner, \"vocab_ner.pkl\")\n",
"save_vocab(vocab_srl, \"vocab_srl.pkl\")\n",
"save_vocab(vocab_q, \"vocab_q.pkl\")\n",
"save_vocab(vocab_a, \"vocab_a.pkl\")\n",
"save_vocab(vocab_typ, \"vocab_typ.pkl\")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "61003de5",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
"def build_inference_models(trained):\n",
"    \"\"\"Split the trained multi-task model into step-wise inference models.\n",
"\n",
"    Returns (enc_model, dec_q_model, dec_a_model, type_model), all sharing\n",
"    weights with `trained`:\n",
"      enc_model   : [tok_ids, ner_ids, srl_ids] -> [h, c] of \"enc_lstm\"\n",
"      dec_q_model : [prev_q_id (batch, 1), h, c] -> [q_out softmax, h', c']\n",
"      dec_a_model : [prev_a_id (batch, 1), h, c] -> [a_out softmax, h', c']\n",
"      type_model  : [tok_ids, ner_ids, srl_ids] -> type_out softmax\n",
"    \"\"\"\n",
"    # Encoder and classifier are sub-graphs of the trained model itself (its\n",
"    # first three inputs are tok/ner/srl). Re-wiring emb_* -> Concatenate ->\n",
"    # enc_lstm on fresh Inputs would drop any explicit padding mask of the\n",
"    # training graph and re-create Concatenate's mask merging; reusing the\n",
"    # graph keeps inference identical to training.\n",
"    enc_inputs = trained.inputs[:3]\n",
"    _, enc_h, enc_c = trained.get_layer(\"enc_lstm\").output\n",
"    enc_model = Model(enc_inputs, [enc_h, enc_c])\n",
"    type_model = Model(enc_inputs, trained.get_layer(\"type_out\").output)\n",
"\n",
"    def step_decoder(tok_name, h_name, c_name, emb_name, lstm_name, out_name):\n",
"        \"\"\"One-step decoder: (prev token id, h, c) -> (softmax probs, new h, new c).\"\"\"\n",
"        n_units = trained.get_layer(lstm_name).units\n",
"        tok_in = Input((1,), name=tok_name)\n",
"        h_in = Input((n_units,), name=h_name)\n",
"        c_in = Input((n_units,), name=c_name)\n",
"        emb = trained.get_layer(emb_name)(tok_in)\n",
"        seq, new_h, new_c = trained.get_layer(lstm_name)(emb, initial_state=[h_in, c_in])\n",
"        # apply the Dense wrapped by TimeDistributed directly to the (batch, 1, units) step\n",
"        probs = trained.get_layer(out_name).layer(seq)\n",
"        return Model([tok_in, h_in, c_in], [probs, new_h, new_c])\n",
"\n",
"    dec_q_model = step_decoder(\"dq_tok\", \"dh\", \"dc\", \"emb_q\", \"lstm_q\", \"q_out\")\n",
"    dec_a_model = step_decoder(\"da_tok\", \"ah\", \"ac\", \"emb_a\", \"lstm_a\", \"a_out\")\n",
"\n",
"    return enc_model, dec_q_model, dec_a_model, type_model\n",
"\n",
"encoder_model, decoder_q, decoder_a, classifier_model = build_inference_models(model)\n",
"\n",
"# Reverse lookups (id -> token) used to turn decoded ids back into words.\n",
"inv_q = {idx: tok for tok, idx in vocab_q.items()}\n",
"inv_a = {idx: tok for tok, idx in vocab_a.items()}\n",
"\n",
"def enc_pad(seq, vmap, maxlen):\n",
"    \"\"\"Encode `seq` with `vmap` into exactly `maxlen` ids.\n",
"\n",
"    Unknown tokens map to vmap[\"<unk>\"]; shorter sequences are right-padded\n",
"    with vmap[\"<pad>\"]; longer ones are truncated so the result always fits\n",
"    the fixed-length (MAX_SENT) encoder inputs.\n",
"    \"\"\"\n",
"    ids = [vmap.get(t, vmap[\"<unk>\"]) for t in seq][:maxlen]\n",
"    return ids + [vmap[\"<pad>\"]] * (maxlen - len(ids))\n",
"\n",
"def greedy_decode(tokens, ner, srl, max_q=20, max_a=10):\n",
"    \"\"\"Greedily decode (question tokens, answer tokens, question type) for one sentence.\n",
"\n",
"    `tokens`, `ner` and `srl` are parallel per-token string lists; they are\n",
"    id-encoded and padded to MAX_SENT with `enc_pad`. Each decoder stops at\n",
"    <eos> (not included in the result) or after max_q / max_a steps.\n",
"    \"\"\"\n",
"    et = np.array([enc_pad(tokens, vocab_tok, MAX_SENT)])\n",
"    en = np.array([enc_pad(ner, vocab_ner, MAX_SENT)])\n",
"    es = np.array([enc_pad(srl, vocab_srl, MAX_SENT)])\n",
"    enc_in = [et, en, es]\n",
"\n",
"    # Both decoders start from the same encoder state, so encode only once.\n",
"    h0, c0 = encoder_model.predict(enc_in, verbose=0)\n",
"\n",
"    def run_decoder(decoder, vocab, max_len):\n",
"        \"\"\"Feed back the argmax id step by step; return the ids before <eos>.\"\"\"\n",
"        ids, h, c = [], h0, c0\n",
"        tgt = np.array([[vocab[\"<sos>\"]]])\n",
"        for _ in range(max_len):\n",
"            probs, h, c = decoder.predict([tgt, h, c], verbose=0)\n",
"            nxt = int(probs[0, -1].argmax())\n",
"            if nxt == vocab[\"<eos>\"]:\n",
"                break\n",
"            ids.append(nxt)\n",
"            tgt = np.array([[nxt]])\n",
"        return ids\n",
"\n",
"    q_ids = run_decoder(decoder_q, vocab_q, max_q)\n",
"    a_ids = run_decoder(decoder_a, vocab_a, max_a)\n",
"    t_id = int(classifier_model.predict(enc_in, verbose=0).argmax())\n",
"\n",
"    return ([inv_q[i] for i in q_ids],\n",
"            [inv_a[i] for i in a_ids],\n",
"            [k for k, v in vocab_typ.items() if v == t_id][0])\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "5279b631",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
"# Hand-written example sentence with token-aligned NER and SRL tags.\n",
"test_tokens = [\"soekarno\", \"membacakan\", \"teks\", \"proklamasi\", \"pada\",\n",
"               \"17\", \"agustus\", \"1945\"]\n",
"test_ner = [\"B-PER\", \"O\", \"O\", \"O\", \"O\", \"B-DATE\", \"I-DATE\", \"I-DATE\"]\n",
"test_srl = [\"ARG0\", \"V\", \"ARG1\", \"ARG1\", \"O\", \"ARGM-TMP\", \"ARGM-TMP\", \"ARGM-TMP\"]\n",
"\n",
"q, a, t = greedy_decode(test_tokens, test_ner, test_srl, max_q=MAX_Q, max_a=MAX_A)\n",
"\n",
"print(\"\\nDEMO\\n----\")\n",
"for label, toks in ((\"Q :\", q), (\"A :\", a)):\n",
"    print(label, \" \".join(toks))\n",
"print(\"T :\", t)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "850d4905",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
"# Shared scorers for evaluate(): NLTK-smoothed BLEU and stemmed ROUGE-1 / ROUGE-L.\n",
"smooth = SmoothingFunction().method4\n",
"r_scorer = rouge_scorer.RougeScorer([\"rouge1\", \"rougeL\"], use_stemmer=True)\n",
"\n",
"def strip_special(seq, pad_id, eos_id):\n",
"    \"\"\"Return the ids of `seq` in order, without `pad_id` and `eos_id`.\"\"\"\n",
"    special = (pad_id, eos_id)\n",
"    return list(filter(lambda tok_id: tok_id not in special, seq))\n",
"\n",
"def ids_to_text(ids, inv):\n",
"    \"\"\"Join the tokens that `inv` maps each id to, separated by single spaces.\"\"\"\n",
"    return \" \".join(map(inv.__getitem__, ids))\n",
"\n",
"def evaluate(n=200, seed=None):\n",
"    \"\"\"Print BLEU-4 and ROUGE-1/L of greedily generated questions vs. references.\n",
"\n",
"    n    : number of random samples to score; clamped to len(samples) so the\n",
"           default does not raise ValueError on a small dataset.\n",
"    seed : optional seed for a reproducible draw (None keeps using the global\n",
"           `random` module state).\n",
"    Raises ValueError when there is nothing to score (n <= 0 or no samples).\n",
"    NOTE(review): `samples` is the training data, so these scores measure fit,\n",
"    not generalisation - TODO evaluate on a held-out split.\n",
"    \"\"\"\n",
"    n = min(n, len(samples))\n",
"    if n <= 0:\n",
"        raise ValueError(\"evaluate() needs at least one sample\")\n",
"    rng = random if seed is None else random.Random(seed)\n",
"    idxs = rng.sample(range(len(samples)), n)\n",
"    refs, hyps = [], []\n",
"    agg = scoring.BootstrapAggregator()\n",
"    eos_q = vocab_q[\"<eos>\"]\n",
"\n",
"    for i in idxs:\n",
"        # reference = the q_out training target ids, without <pad>/<eos>\n",
"        gt_ids = strip_special(dec_q_out[i], pad_q, eos_q)\n",
"        ref = ids_to_text(gt_ids, inv_q)\n",
"        pred = \" \".join(greedy_decode(\n",
"            samples[i][\"tokens\"],\n",
"            samples[i][\"ner\"],\n",
"            samples[i][\"srl\"]\n",
"        )[0])\n",
"        refs.append([ref.split()])\n",
"        hyps.append(pred.split())\n",
"        agg.add_scores(r_scorer.score(ref, pred))\n",
"\n",
"    bleu = corpus_bleu(refs, hyps, smoothing_function=smooth)\n",
"    # aggregate() bootstrap-resamples on every call; call it once so ROUGE-1\n",
"    # and ROUGE-L are read from the same resampling.\n",
"    rouge = agg.aggregate()\n",
"    r1 = rouge[\"rouge1\"].mid\n",
"    rL = rouge[\"rougeL\"].mid\n",
"\n",
"    print(f\"\\nEVAL (n={n})\")\n",
"    print(f\"BLEU‑4 : {bleu:.4f}\")\n",
"    print(f\"ROUGE‑1 : P={r1.precision:.3f} R={r1.recall:.3f} F1={r1.fmeasure:.3f}\")\n",
"    print(f\"ROUGE‑L : P={rL.precision:.3f} R={rL.recall:.3f} F1={rL.fmeasure:.3f}\")\n",
"\n",
"evaluate(2)"
|
||
]
|
||
}
|
||
],
|
||
"metadata": {
|
||
"kernelspec": {
|
||
"display_name": "myenv",
|
||
"language": "python",
|
||
"name": "python3"
|
||
},
|
||
"language_info": {
|
||
"codemirror_mode": {
|
||
"name": "ipython",
|
||
"version": 3
|
||
},
|
||
"file_extension": ".py",
|
||
"mimetype": "text/x-python",
|
||
"name": "python",
|
||
"nbconvert_exporter": "python",
|
||
"pygments_lexer": "ipython3",
|
||
"version": "3.10.16"
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 5
|
||
}
|