{ "cells": [ { "cell_type": "code", "execution_count": 32, "id": "fb283f23", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Total flattened samples: 596\n" ] } ], "source": [ "import json\n", "from pathlib import Path\n", "from itertools import chain\n", "\n", "RAW = json.loads(\n", " Path(\"../dataset/dev_dataset_test.json\").read_text()\n", ") # ← file contoh Anda\n", "\n", "samples = []\n", "for item in RAW:\n", " for qp in item[\"quiz_posibility\"]:\n", " samp = {\n", " \"tokens\": [tok.lower() for tok in item[\"tokens\"]],\n", " \"ner\": item[\"ner\"],\n", " \"srl\": item[\"srl\"],\n", " \"q_type\": qp[\"type\"], # isian / opsi / benar_salah\n", " \"q_toks\": [tok.lower() for tok in qp[\"question\"]]\n", " + [\"\"], # tambahkan \n", " }\n", " # Jawaban bisa multi token\n", " if isinstance(qp[\"answer\"], list):\n", " samp[\"a_toks\"] = [tok.lower() for tok in qp[\"answer\"]] + [\"\"]\n", " else:\n", " samp[\"a_toks\"] = [qp[\"answer\"].lower(), \"\"]\n", " samples.append(samp)\n", "\n", "print(\"Total flattened samples:\", len(samples))" ] }, { "cell_type": "code", "execution_count": 33, "id": "fa4f979d", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'': 0, '': 1, '': 2, '': 3, 'indra': 4, 'rahmawati': 5, 'lahir': 6, 'di': 7, 'malang': 8, 'pada': 9, '27': 10, 'februari': 11, '1972': 12, '___': 13, 'dimana': 14, 'fitri': 15, 'pratama': 16, 'joko': 17, 'wijaya': 18, 'tangerang': 19, '26': 20, 'september': 21, '1970': 22, 'citra': 23, 'jakarta': 24, '17': 25, '1983': 26, 'saputra': 27, 'kota': 28, 'mana': 29, 'nina': 30, 'ridwan': 31, 'gilang': 32, 'purnama': 33, 'depok': 34, '21': 35, 'desember': 36, '1971': 37, 'budi': 38, 'eka': 39, 'laksmi': 40, 'surabaya': 41, '24': 42, 'oktober': 43, '1996': 44, 'sari': 45, '20': 46, 'mei': 47, '1986': 48, 'maulana': 49, 'hana': 50, 'makassar': 51, '16': 52, 'agustus': 53, '1993': 54, 'restu': 55, 'hasan': 56, 'lina': 57, 'siregar': 58, '22': 59, 'april': 60, '1987': 61, 'palembang': 62, '7': 63, 'januari': 64, '1997': 65, '4': 66, 'november': 67, '2009': 68, 'dina': 69, '1978': 70, 'nanda': 71, 'putra': 72, 'tanggal': 73, 'berapa': 74, 'kusuma': 75, '13': 76, '28': 77, '1984': 78, 'medan': 79, '3': 80, 'maret': 81, '1995': 82, '6': 83, '2000': 84, 'anggraini': 85, '10': 86, '2006': 87, 'semarang': 88, 'juni': 89, '1976': 90, 'mega': 91, 'dwiputra': 92, '1992': 93, 'yuni': 94, 'baskara': 95, 'utami': 96, 'diaz': 97, '14': 98, '2008': 99, 'vina': 100, 'lestari': 101, 'mariani': 102, 'islami': 103, '23': 104, 'juli': 105, 'permata': 106, '2007': 107, '1994': 108, '1': 109, '1981': 110, 'toni': 111, 'hakim': 112, 'hendra': 113, 'dwi': 114, 'santoso': 115, 'andi': 116, '25': 117, '1998': 118, '19': 119, 'bayu': 120, 'suryanto': 121, 'kartika': 122, '2010': 123, 'joni': 124, '1985': 125, 'qori': 126, '2003': 127, '1999': 128, 'bandung': 129, 'adi': 130, 'angga': 131, '18': 132, '2004': 133, 'putri': 134, 'laksono': 135, 'hani': 136, '12': 137, '15': 138, '11': 139, '1980': 140, '2002': 141, 'trisna': 142, '1973': 143, 'oki': 144, 'hidayat': 145, '1982': 146, 'suryadi': 147, 'galih': 148, '2': 149, '1977': 150, 'agus': 151, 'sudarto': 152, 'endah': 153, '1991': 154, 'yuliani': 155, 'seno': 156, 'nandito': 157, 'syah': 158, '5': 159, 'mira': 160, '9': 161, 'indah': 162, 'widodo': 163, '1974': 164, 'rio': 165, '1979': 166, '1988': 167, 'yana': 168, 'elisa': 169, 'irma': 170, 'irawan': 171, 'fajar': 172, '1990': 173, 'mustofa': 174, 'heri': 175, 'fina': 176, 'cahya': 177, 'ramadhan': 178, 'setiawan': 179, 'prasetyo': 180, 'ariani': 181, 'kurniawan': 182, 'farhan': 183, 'dwisaputra': 184, 'hardian': 185, '1975': 186, '8': 187, '2001': 188, 'iqbal': 189, 'kurnia': 190, 'cahyono': 191, '1989': 192, 'kiki': 193, 'suwandi': 194, 'gita': 195, '2005': 196, 'rahman': 197, 'danang': 198, 'maharani': 199, 'lisa': 200, 'firmansyah': 201, 'wulan': 202, 'zaki': 203, 'wicak': 204, 'sita': 205, 'ahmad': 206, 'wahyuni': 207, 'saputri': 208, 'tria': 209, 'fadhil': 210}\n" ] } ], "source": [ "def build_vocab(seq_iter, reserved=[\"\", \"\", \"\", \"\"]):\n", " vocab = {tok: idx for idx, tok in enumerate(reserved)}\n", " for tok in chain.from_iterable(seq_iter):\n", " if tok not in vocab:\n", " vocab[tok] = len(vocab)\n", " return vocab\n", "\n", "\n", "vocab_tok = build_vocab((s[\"tokens\"] for s in samples))\n", "vocab_ner = build_vocab((s[\"ner\"] for s in samples), reserved=[\"\", \"\"])\n", "vocab_srl = build_vocab((s[\"srl\"] for s in samples), reserved=[\"\", \"\"])\n", "vocab_q = build_vocab((s[\"q_toks\"] for s in samples))\n", "vocab_a = build_vocab((s[\"a_toks\"] for s in samples))\n", "\n", "vocab_typ = {\"isian\": 0, \"opsi\": 1, \"true_false\": 2}\n", "\n", "print(vocab_q)" ] }, { "cell_type": "code", "execution_count": 34, "id": "d1a5b324", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "from tensorflow.keras.preprocessing.sequence import pad_sequences\n", "\n", "\n", "def encode(seq, vmap): # token → id\n", " return [vmap.get(t, vmap[\"\"]) for t in seq]\n", "\n", "\n", "MAX_SENT = max(len(s[\"tokens\"]) for s in samples)\n", "MAX_Q = max(len(s[\"q_toks\"]) for s in samples)\n", "MAX_A = max(len(s[\"a_toks\"]) for s in samples)\n", "\n", "X_tok = pad_sequences(\n", " [encode(s[\"tokens\"], vocab_tok) for s in samples], maxlen=MAX_SENT, padding=\"post\"\n", ")\n", "X_ner = pad_sequences(\n", " [encode(s[\"ner\"], vocab_ner) for s in samples], maxlen=MAX_SENT, padding=\"post\"\n", ")\n", "X_srl = pad_sequences(\n", " [encode(s[\"srl\"], vocab_srl) for s in samples], maxlen=MAX_SENT, padding=\"post\"\n", ")\n", "\n", "# Decoder input = + target[:-1]\n", "dec_q_in = pad_sequences(\n", " [[vocab_q[\"\"], *encode(s[\"q_toks\"][:-1], vocab_q)] for s in samples],\n", " maxlen=MAX_Q,\n", " padding=\"post\",\n", ")\n", "dec_q_out = pad_sequences(\n", " [encode(s[\"q_toks\"], vocab_q) for s in samples], maxlen=MAX_Q, padding=\"post\"\n", ")\n", "\n", "dec_a_in = pad_sequences(\n", " [[vocab_a[\"\"], *encode(s[\"a_toks\"][:-1], vocab_a)] for s in samples],\n", " maxlen=MAX_A,\n", " padding=\"post\",\n", ")\n", "dec_a_out = pad_sequences(\n", " [encode(s[\"a_toks\"], vocab_a) for s in samples], maxlen=MAX_A, padding=\"post\"\n", ")\n", "\n", "MAX_SENT = max(len(s[\"tokens\"]) for s in samples)\n", "MAX_Q = max(len(s[\"q_toks\"]) for s in samples)\n", "MAX_A = max(len(s[\"a_toks\"]) for s in samples)\n", "y_type = np.array([vocab_typ[s[\"q_type\"]] for s in samples])" ] }, { "cell_type": "code", "execution_count": 35, "id": "ff5bd85f", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
Model: \"functional_3\"\n",
       "
\n" ], "text/plain": [ "\u001b[1mModel: \"functional_3\"\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓\n",
       "┃ Layer (type)         Output Shape          Param #  Connected to      ┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩\n",
       "│ tok_in (InputLayer) │ (None, 9)         │          0 │ -                 │\n",
       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
       "│ ner_in (InputLayer) │ (None, 9)         │          0 │ -                 │\n",
       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
       "│ srl_in (InputLayer) │ (None, 9)         │          0 │ -                 │\n",
       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
       "│ embedding_tok       │ (None, 9, 128)    │     31,232 │ tok_in[0][0]      │\n",
       "│ (Embedding)         │                   │            │                   │\n",
       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
       "│ embedding_ner       │ (None, 9, 32)     │        288 │ ner_in[0][0]      │\n",
       "│ (Embedding)         │                   │            │                   │\n",
       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
       "│ embedding_srl       │ (None, 9, 32)     │        224 │ srl_in[0][0]      │\n",
       "│ (Embedding)         │                   │            │                   │\n",
       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
       "│ dec_q_in            │ (None, 11)        │          0 │ -                 │\n",
       "│ (InputLayer)        │                   │            │                   │\n",
       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
       "│ concatenate_3       │ (None, 9, 192)    │          0 │ embedding_tok[0]… │\n",
       "│ (Concatenate)       │                   │            │ embedding_ner[0]… │\n",
       "│                     │                   │            │ embedding_srl[0]… │\n",
       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
       "│ dec_a_in            │ (None, 4)         │          0 │ -                 │\n",
       "│ (InputLayer)        │                   │            │                   │\n",
       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
       "│ embedding_q_decoder │ (None, 11, 128)   │     27,008 │ dec_q_in[0][0]    │\n",
       "│ (Embedding)         │                   │            │                   │\n",
       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
       "│ encoder_lstm (LSTM) │ [(None, 256),     │    459,776 │ concatenate_3[0]… │\n",
       "│                     │ (None, 256),      │            │                   │\n",
       "│                     │ (None, 256)]      │            │                   │\n",
       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
       "│ embedding_a_decoder │ (None, 4, 128)    │     14,336 │ dec_a_in[0][0]    │\n",
       "│ (Embedding)         │                   │            │                   │\n",
       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
       "│ lstm_q_decoder      │ [(None, 11, 256), │    394,240 │ embedding_q_deco… │\n",
       "│ (LSTM)              │ (None, 256),      │            │ encoder_lstm[0][ │\n",
       "│                     │ (None, 256)]      │            │ encoder_lstm[0][ │\n",
       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
       "│ not_equal_12        │ (None, 11)        │          0 │ dec_q_in[0][0]    │\n",
       "│ (NotEqual)          │                   │            │                   │\n",
       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
       "│ lstm_a_decoder      │ [(None, 4, 256),  │    394,240 │ embedding_a_deco… │\n",
       "│ (LSTM)              │ (None, 256),      │            │ encoder_lstm[0][ │\n",
       "│                     │ (None, 256)]      │            │ encoder_lstm[0][ │\n",
       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
       "│ not_equal_13        │ (None, 4)         │          0 │ dec_a_in[0][0]    │\n",
       "│ (NotEqual)          │                   │            │                   │\n",
       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
       "│ q_output            │ (None, 11, 211)   │     54,227 │ lstm_q_decoder[0… │\n",
       "│ (TimeDistributed)   │                   │            │ not_equal_12[0][ │\n",
       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
       "│ a_output            │ (None, 4, 112)    │     28,784 │ lstm_a_decoder[0… │\n",
       "│ (TimeDistributed)   │                   │            │ not_equal_13[0][ │\n",
       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
       "│ type_output (Dense) │ (None, 3)         │        771 │ encoder_lstm[0][ │\n",
       "└─────────────────────┴───────────────────┴────────────┴───────────────────┘\n",
       "
\n" ], "text/plain": [ "┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓\n", "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mConnected to \u001b[0m\u001b[1m \u001b[0m┃\n", "┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩\n", "│ tok_in (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m9\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", "│ ner_in (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m9\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", "│ srl_in (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m9\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", "│ embedding_tok │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m9\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m31,232\u001b[0m │ tok_in[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", "│ (\u001b[38;5;33mEmbedding\u001b[0m) │ │ │ │\n", "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", "│ embedding_ner │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m9\u001b[0m, \u001b[38;5;34m32\u001b[0m) │ \u001b[38;5;34m288\u001b[0m │ ner_in[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", "│ (\u001b[38;5;33mEmbedding\u001b[0m) │ │ │ │\n", "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", "│ embedding_srl │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m9\u001b[0m, \u001b[38;5;34m32\u001b[0m) │ \u001b[38;5;34m224\u001b[0m │ srl_in[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", "│ (\u001b[38;5;33mEmbedding\u001b[0m) │ │ │ │\n", "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", "│ dec_q_in │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m11\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", "│ (\u001b[38;5;33mInputLayer\u001b[0m) │ │ │ │\n", "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", "│ concatenate_3 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m9\u001b[0m, \u001b[38;5;34m192\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ embedding_tok[\u001b[38;5;34m0\u001b[0m]… │\n", "│ (\u001b[38;5;33mConcatenate\u001b[0m) │ │ │ embedding_ner[\u001b[38;5;34m0\u001b[0m]… │\n", "│ │ │ │ embedding_srl[\u001b[38;5;34m0\u001b[0m]… │\n", "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", "│ dec_a_in │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m4\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", "│ (\u001b[38;5;33mInputLayer\u001b[0m) │ │ │ │\n", "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", "│ embedding_q_decoder │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m11\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m27,008\u001b[0m │ dec_q_in[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", "│ (\u001b[38;5;33mEmbedding\u001b[0m) │ │ │ │\n", "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", "│ encoder_lstm (\u001b[38;5;33mLSTM\u001b[0m) │ [(\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256\u001b[0m), │ \u001b[38;5;34m459,776\u001b[0m │ concatenate_3[\u001b[38;5;34m0\u001b[0m]… │\n", "│ │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256\u001b[0m), │ │ │\n", "│ │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256\u001b[0m)] │ │ │\n", "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", "│ embedding_a_decoder │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m4\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m14,336\u001b[0m │ dec_a_in[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", "│ (\u001b[38;5;33mEmbedding\u001b[0m) │ │ │ │\n", "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", "│ lstm_q_decoder │ [(\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m11\u001b[0m, \u001b[38;5;34m256\u001b[0m), │ \u001b[38;5;34m394,240\u001b[0m │ embedding_q_deco… │\n", "│ (\u001b[38;5;33mLSTM\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256\u001b[0m), │ │ encoder_lstm[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m…\u001b[0m │\n", "│ │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256\u001b[0m)] │ │ encoder_lstm[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m…\u001b[0m │\n", "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", "│ not_equal_12 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m11\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ dec_q_in[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", "│ (\u001b[38;5;33mNotEqual\u001b[0m) │ │ │ │\n", "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", "│ lstm_a_decoder │ [(\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m4\u001b[0m, \u001b[38;5;34m256\u001b[0m), │ \u001b[38;5;34m394,240\u001b[0m │ embedding_a_deco… │\n", "│ (\u001b[38;5;33mLSTM\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256\u001b[0m), │ │ encoder_lstm[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m…\u001b[0m │\n", "│ │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256\u001b[0m)] │ │ encoder_lstm[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m…\u001b[0m │\n", "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", "│ not_equal_13 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m4\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ dec_a_in[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", "│ (\u001b[38;5;33mNotEqual\u001b[0m) │ │ │ │\n", "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", "│ q_output │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m11\u001b[0m, \u001b[38;5;34m211\u001b[0m) │ \u001b[38;5;34m54,227\u001b[0m │ lstm_q_decoder[\u001b[38;5;34m0\u001b[0m… │\n", "│ (\u001b[38;5;33mTimeDistributed\u001b[0m) │ │ │ not_equal_12[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m…\u001b[0m │\n", "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", "│ a_output │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m4\u001b[0m, \u001b[38;5;34m112\u001b[0m) │ \u001b[38;5;34m28,784\u001b[0m │ lstm_a_decoder[\u001b[38;5;34m0\u001b[0m… │\n", "│ (\u001b[38;5;33mTimeDistributed\u001b[0m) │ │ │ not_equal_13[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m…\u001b[0m │\n", "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", "│ type_output (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m3\u001b[0m) │ \u001b[38;5;34m771\u001b[0m │ encoder_lstm[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m…\u001b[0m │\n", "└─────────────────────┴───────────────────┴────────────┴───────────────────┘\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
 Total params: 1,405,126 (5.36 MB)\n",
       "
\n" ], "text/plain": [ "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m1,405,126\u001b[0m (5.36 MB)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
 Trainable params: 1,405,126 (5.36 MB)\n",
       "
\n" ], "text/plain": [ "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m1,405,126\u001b[0m (5.36 MB)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
 Non-trainable params: 0 (0.00 B)\n",
       "
\n" ], "text/plain": [ "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import tensorflow as tf\n", "from tensorflow.keras.layers import (\n", " Input,\n", " Embedding,\n", " LSTM,\n", " Concatenate,\n", " Dense,\n", " TimeDistributed,\n", ")\n", "from tensorflow.keras.models import Model\n", "\n", "# ---- constants ---------------------------------------------------\n", "d_tok = 128 # token embedding dim\n", "d_tag = 32 # NER / SRL embedding dim\n", "units = 256\n", "\n", "# ---- encoder -----------------------------------------------------\n", "inp_tok = Input((MAX_SENT,), name=\"tok_in\")\n", "inp_ner = Input((MAX_SENT,), name=\"ner_in\")\n", "inp_srl = Input((MAX_SENT,), name=\"srl_in\")\n", "\n", "# make ALL streams mask the same way (here: no masking,\n", "# we'll just pad with 0s and let the LSTM ignore them)\n", "emb_tok = Embedding(len(vocab_tok), d_tok, mask_zero=False, name=\"embedding_tok\")(\n", " inp_tok\n", ")\n", "emb_ner = Embedding(len(vocab_ner), d_tag, mask_zero=False, name=\"embedding_ner\")(\n", " inp_ner\n", ")\n", "emb_srl = Embedding(len(vocab_srl), d_tag, mask_zero=False, name=\"embedding_srl\")(\n", " inp_srl\n", ")\n", "\n", "enc_concat = Concatenate()([emb_tok, emb_ner, emb_srl])\n", "enc_out, state_h, state_c = LSTM(units, return_state=True, name=\"encoder_lstm\")(\n", " enc_concat\n", ")\n", "\n", "\n", "# ---------- DECODER : Question ----------\n", "dec_q_inp = Input(shape=(MAX_Q,), name=\"dec_q_in\")\n", "dec_emb_q = Embedding(len(vocab_q), d_tok, mask_zero=True, name=\"embedding_q_decoder\")(\n", " dec_q_inp\n", ")\n", "dec_q, _, _ = LSTM(\n", " units, return_state=True, return_sequences=True, name=\"lstm_q_decoder\"\n", ")(dec_emb_q, initial_state=[state_h, state_c])\n", "q_out = TimeDistributed(\n", " Dense(len(vocab_q), activation=\"softmax\", name=\"dense_q_output\"), name=\"q_output\"\n", ")(dec_q)\n", "\n", "# ---------- DECODER : Answer ----------\n", "dec_a_inp = Input(shape=(MAX_A,), name=\"dec_a_in\")\n", "dec_emb_a = Embedding(len(vocab_a), d_tok, mask_zero=True, name=\"embedding_a_decoder\")(\n", " dec_a_inp\n", ")\n", "dec_a, _, _ = LSTM(\n", " units, return_state=True, return_sequences=True, name=\"lstm_a_decoder\"\n", ")(dec_emb_a, initial_state=[state_h, state_c])\n", "a_out = TimeDistributed(\n", " Dense(len(vocab_a), activation=\"softmax\", name=\"dense_a_output\"), name=\"a_output\"\n", ")(dec_a)\n", "\n", "# ---------- CLASSIFIER : Question Type ----------\n", "type_out = Dense(len(vocab_typ), activation=\"softmax\", name=\"type_output\")(enc_out)\n", "\n", "model = Model(\n", " inputs=[inp_tok, inp_ner, inp_srl, dec_q_inp, dec_a_inp],\n", " outputs=[q_out, a_out, type_out],\n", ")\n", "\n", "model.summary()" ] }, { "cell_type": "code", "execution_count": 36, "id": "fece1ae9", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/30\n", "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 83ms/step - a_output_loss: 4.5185 - a_output_sparse_categorical_accuracy: 0.1853 - loss: 10.1289 - q_output_loss: 5.2751 - q_output_sparse_categorical_accuracy: 0.1679 - type_output_accuracy: 0.3966 - type_output_loss: 1.0344 - val_a_output_loss: 2.2993 - val_a_output_sparse_categorical_accuracy: 0.2500 - val_loss: 6.6556 - val_q_output_loss: 4.1554 - val_q_output_sparse_categorical_accuracy: 0.1606 - val_type_output_accuracy: 0.6167 - val_type_output_loss: 0.6698\n", "Epoch 2/30\n", "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 28ms/step - a_output_loss: 2.7148 - a_output_sparse_categorical_accuracy: 0.2674 - loss: 6.7993 - q_output_loss: 3.8706 - q_output_sparse_categorical_accuracy: 0.1625 - type_output_accuracy: 0.5096 - type_output_loss: 0.7111 - val_a_output_loss: 1.9687 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 5.6731 - val_q_output_loss: 3.4983 - val_q_output_sparse_categorical_accuracy: 0.2848 - val_type_output_accuracy: 0.8167 - val_type_output_loss: 0.6870\n", "Epoch 3/30\n", "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 29ms/step - a_output_loss: 2.4494 - a_output_sparse_categorical_accuracy: 0.3225 - loss: 6.0211 - q_output_loss: 3.3610 - q_output_sparse_categorical_accuracy: 0.2195 - type_output_accuracy: 0.6316 - type_output_loss: 0.6890 - val_a_output_loss: 1.8200 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 5.1733 - val_q_output_loss: 3.1541 - val_q_output_sparse_categorical_accuracy: 0.2136 - val_type_output_accuracy: 0.6167 - val_type_output_loss: 0.6640\n", "Epoch 4/30\n", "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 29ms/step - a_output_loss: 2.2031 - a_output_sparse_categorical_accuracy: 0.3726 - loss: 5.4427 - q_output_loss: 3.0396 - q_output_sparse_categorical_accuracy: 0.2869 - type_output_accuracy: 0.5478 - type_output_loss: 0.6786 - val_a_output_loss: 1.5951 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 4.6514 - val_q_output_loss: 2.8462 - val_q_output_sparse_categorical_accuracy: 0.3758 - val_type_output_accuracy: 0.3833 - val_type_output_loss: 0.7003\n", "Epoch 5/30\n", "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 30ms/step - a_output_loss: 2.0493 - a_output_sparse_categorical_accuracy: 0.3739 - loss: 4.9806 - q_output_loss: 2.7312 - q_output_sparse_categorical_accuracy: 0.3659 - type_output_accuracy: 0.5789 - type_output_loss: 0.6663 - val_a_output_loss: 1.5595 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 4.3330 - val_q_output_loss: 2.5741 - val_q_output_sparse_categorical_accuracy: 0.4197 - val_type_output_accuracy: 0.4833 - val_type_output_loss: 0.6650\n", "Epoch 6/30\n", "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 28ms/step - a_output_loss: 1.9310 - a_output_sparse_categorical_accuracy: 0.3769 - loss: 4.5682 - q_output_loss: 2.4562 - q_output_sparse_categorical_accuracy: 0.4147 - type_output_accuracy: 0.6767 - type_output_loss: 0.6249 - val_a_output_loss: 1.4318 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 3.9148 - val_q_output_loss: 2.3074 - val_q_output_sparse_categorical_accuracy: 0.4333 - val_type_output_accuracy: 0.8167 - val_type_output_loss: 0.5853\n", "Epoch 7/30\n", "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 28ms/step - a_output_loss: 1.8407 - a_output_sparse_categorical_accuracy: 0.3722 - loss: 4.1857 - q_output_loss: 2.1853 - q_output_sparse_categorical_accuracy: 0.4229 - type_output_accuracy: 0.7938 - type_output_loss: 0.5382 - val_a_output_loss: 1.3413 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 3.5705 - val_q_output_loss: 2.0979 - val_q_output_sparse_categorical_accuracy: 0.4515 - val_type_output_accuracy: 0.8500 - val_type_output_loss: 0.4377\n", "Epoch 8/30\n", "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 28ms/step - a_output_loss: 1.8257 - a_output_sparse_categorical_accuracy: 0.3606 - loss: 3.9578 - q_output_loss: 1.9912 - q_output_sparse_categorical_accuracy: 0.4426 - type_output_accuracy: 0.7711 - type_output_loss: 0.4701 - val_a_output_loss: 1.2847 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 3.3498 - val_q_output_loss: 1.9433 - val_q_output_sparse_categorical_accuracy: 0.4712 - val_type_output_accuracy: 0.8500 - val_type_output_loss: 0.4062\n", "Epoch 9/30\n", "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 29ms/step - a_output_loss: 1.6802 - a_output_sparse_categorical_accuracy: 0.3683 - loss: 3.6282 - q_output_loss: 1.8210 - q_output_sparse_categorical_accuracy: 0.4542 - type_output_accuracy: 0.7996 - type_output_loss: 0.4131 - val_a_output_loss: 1.2484 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 3.1973 - val_q_output_loss: 1.8318 - val_q_output_sparse_categorical_accuracy: 0.4667 - val_type_output_accuracy: 0.8500 - val_type_output_loss: 0.3905\n", "Epoch 10/30\n", "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 29ms/step - a_output_loss: 1.6225 - a_output_sparse_categorical_accuracy: 0.3738 - loss: 3.4468 - q_output_loss: 1.7095 - q_output_sparse_categorical_accuracy: 0.4516 - type_output_accuracy: 0.8104 - type_output_loss: 0.3915 - val_a_output_loss: 1.2075 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 3.0751 - val_q_output_loss: 1.7371 - val_q_output_sparse_categorical_accuracy: 0.4758 - val_type_output_accuracy: 0.8167 - val_type_output_loss: 0.4349\n", "Epoch 11/30\n", "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 28ms/step - a_output_loss: 1.5862 - a_output_sparse_categorical_accuracy: 0.3721 - loss: 3.3171 - q_output_loss: 1.6153 - q_output_sparse_categorical_accuracy: 0.4538 - type_output_accuracy: 0.8102 - type_output_loss: 0.3785 - val_a_output_loss: 1.1730 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 2.9582 - val_q_output_loss: 1.6659 - val_q_output_sparse_categorical_accuracy: 0.4697 - val_type_output_accuracy: 0.8500 - val_type_output_loss: 0.3979\n", "Epoch 12/30\n", "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 29ms/step - a_output_loss: 1.5335 - a_output_sparse_categorical_accuracy: 0.3664 - loss: 3.1939 - q_output_loss: 1.5335 - q_output_sparse_categorical_accuracy: 0.4518 - type_output_accuracy: 0.7775 - type_output_loss: 0.4105 - val_a_output_loss: 1.1304 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 2.8616 - val_q_output_loss: 1.6143 - val_q_output_sparse_categorical_accuracy: 0.4727 - val_type_output_accuracy: 0.8500 - val_type_output_loss: 0.3897\n", "Epoch 13/30\n", "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 28ms/step - a_output_loss: 1.4389 - a_output_sparse_categorical_accuracy: 0.3754 - loss: 3.0570 - q_output_loss: 1.5013 - q_output_sparse_categorical_accuracy: 0.4627 - type_output_accuracy: 0.8116 - type_output_loss: 0.3873 - val_a_output_loss: 1.1223 - val_a_output_sparse_categorical_accuracy: 0.4083 - val_loss: 2.8335 - val_q_output_loss: 1.5881 - val_q_output_sparse_categorical_accuracy: 0.4515 - val_type_output_accuracy: 0.8333 - val_type_output_loss: 0.4107\n", "Epoch 14/30\n", "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 30ms/step - a_output_loss: 1.4312 - a_output_sparse_categorical_accuracy: 0.3714 - loss: 2.9962 - q_output_loss: 1.4529 - q_output_sparse_categorical_accuracy: 0.4608 - type_output_accuracy: 0.7943 - type_output_loss: 0.3844 - val_a_output_loss: 1.1235 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 2.8020 - val_q_output_loss: 1.5603 - val_q_output_sparse_categorical_accuracy: 0.4652 - val_type_output_accuracy: 0.8500 - val_type_output_loss: 0.3939\n", "Epoch 15/30\n", "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 28ms/step - a_output_loss: 1.3670 - a_output_sparse_categorical_accuracy: 0.3787 - loss: 2.9339 - q_output_loss: 1.4391 - q_output_sparse_categorical_accuracy: 0.4609 - type_output_accuracy: 0.7846 - type_output_loss: 0.4245 - val_a_output_loss: 1.1046 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 2.7626 - val_q_output_loss: 1.5316 - val_q_output_sparse_categorical_accuracy: 0.4697 - val_type_output_accuracy: 0.8000 - val_type_output_loss: 0.4212\n", "Epoch 16/30\n", "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 28ms/step - a_output_loss: 1.3007 - a_output_sparse_categorical_accuracy: 0.3830 - loss: 2.8316 - q_output_loss: 1.4239 - q_output_sparse_categorical_accuracy: 0.4647 - type_output_accuracy: 0.8005 - type_output_loss: 0.3726 - val_a_output_loss: 1.1026 - val_a_output_sparse_categorical_accuracy: 0.4083 - val_loss: 2.7498 - val_q_output_loss: 1.5221 - val_q_output_sparse_categorical_accuracy: 0.4652 - val_type_output_accuracy: 0.8000 - val_type_output_loss: 0.4171\n", "Epoch 17/30\n", "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 31ms/step - a_output_loss: 1.3868 - a_output_sparse_categorical_accuracy: 0.3768 - loss: 2.8855 - q_output_loss: 1.3850 - q_output_sparse_categorical_accuracy: 0.4683 - type_output_accuracy: 0.8190 - type_output_loss: 0.3552 - val_a_output_loss: 1.0983 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 2.7381 - val_q_output_loss: 1.5079 - val_q_output_sparse_categorical_accuracy: 0.4636 - val_type_output_accuracy: 0.8000 - val_type_output_loss: 0.4401\n", "Epoch 18/30\n", "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 28ms/step - a_output_loss: 1.2897 - a_output_sparse_categorical_accuracy: 0.3837 - loss: 2.7760 - q_output_loss: 1.3759 - q_output_sparse_categorical_accuracy: 0.4749 - type_output_accuracy: 0.8087 - type_output_loss: 0.3802 - val_a_output_loss: 1.1044 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 2.7490 - val_q_output_loss: 1.5155 - val_q_output_sparse_categorical_accuracy: 0.4636 - val_type_output_accuracy: 0.8000 - val_type_output_loss: 0.4305\n", "Epoch 19/30\n", "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 30ms/step - a_output_loss: 1.3054 - a_output_sparse_categorical_accuracy: 0.3802 - loss: 2.7626 - q_output_loss: 1.3517 - q_output_sparse_categorical_accuracy: 0.4699 - type_output_accuracy: 0.8228 - type_output_loss: 0.3645 - val_a_output_loss: 1.0990 - val_a_output_sparse_categorical_accuracy: 0.4125 - val_loss: 2.7261 - val_q_output_loss: 1.4992 - val_q_output_sparse_categorical_accuracy: 0.4667 - val_type_output_accuracy: 0.8000 - val_type_output_loss: 0.4264\n", "Epoch 20/30\n", "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 28ms/step - a_output_loss: 1.2454 - a_output_sparse_categorical_accuracy: 0.3856 - loss: 2.7195 - q_output_loss: 1.3633 - q_output_sparse_categorical_accuracy: 0.4751 - type_output_accuracy: 0.8284 - type_output_loss: 0.3665 - val_a_output_loss: 1.1154 - val_a_output_sparse_categorical_accuracy: 0.4125 - val_loss: 2.7285 - val_q_output_loss: 1.4769 - val_q_output_sparse_categorical_accuracy: 0.4652 - val_type_output_accuracy: 0.7833 - val_type_output_loss: 0.4540\n", "Epoch 21/30\n", "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 29ms/step - a_output_loss: 1.2583 - a_output_sparse_categorical_accuracy: 0.3896 - loss: 2.7032 - q_output_loss: 1.3383 - q_output_sparse_categorical_accuracy: 0.4719 - type_output_accuracy: 0.8246 - type_output_loss: 0.3594 - val_a_output_loss: 1.1381 - val_a_output_sparse_categorical_accuracy: 0.4125 - val_loss: 2.8064 - val_q_output_loss: 1.4643 - val_q_output_sparse_categorical_accuracy: 0.4788 - val_type_output_accuracy: 0.6500 - val_type_output_loss: 0.6802\n", "Epoch 22/30\n", "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 28ms/step - a_output_loss: 1.3214 - a_output_sparse_categorical_accuracy: 0.3743 - loss: 2.7428 - q_output_loss: 1.3119 - q_output_sparse_categorical_accuracy: 0.4669 - type_output_accuracy: 0.8213 - type_output_loss: 0.3563 - val_a_output_loss: 1.1002 - val_a_output_sparse_categorical_accuracy: 0.4125 - val_loss: 2.7029 - val_q_output_loss: 1.4733 - val_q_output_sparse_categorical_accuracy: 0.4727 - val_type_output_accuracy: 0.8333 - val_type_output_loss: 0.4315\n", "Epoch 23/30\n", "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 29ms/step - a_output_loss: 1.2769 - a_output_sparse_categorical_accuracy: 0.3907 - loss: 2.7114 - q_output_loss: 1.3188 - q_output_sparse_categorical_accuracy: 0.4772 - type_output_accuracy: 0.8281 - type_output_loss: 0.3824 - val_a_output_loss: 1.1039 - val_a_output_sparse_categorical_accuracy: 0.4042 - val_loss: 2.7224 - val_q_output_loss: 1.4924 - val_q_output_sparse_categorical_accuracy: 0.4727 - val_type_output_accuracy: 0.8167 - val_type_output_loss: 0.4205\n", "Epoch 24/30\n", "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 29ms/step - a_output_loss: 1.1957 - a_output_sparse_categorical_accuracy: 0.3969 - loss: 2.6335 - q_output_loss: 1.3299 - q_output_sparse_categorical_accuracy: 0.4776 - type_output_accuracy: 0.8342 - type_output_loss: 0.3683 - val_a_output_loss: 1.0974 - val_a_output_sparse_categorical_accuracy: 0.4167 - val_loss: 2.7003 - val_q_output_loss: 1.4677 - val_q_output_sparse_categorical_accuracy: 0.4667 - val_type_output_accuracy: 0.8000 - val_type_output_loss: 0.4504\n", "Epoch 25/30\n", "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 28ms/step - a_output_loss: 1.2393 - a_output_sparse_categorical_accuracy: 0.3898 - loss: 2.6442 - q_output_loss: 1.2937 - q_output_sparse_categorical_accuracy: 0.4804 - type_output_accuracy: 0.8273 - type_output_loss: 0.3564 - val_a_output_loss: 1.1525 - val_a_output_sparse_categorical_accuracy: 0.4125 - val_loss: 2.7743 - val_q_output_loss: 1.4611 - val_q_output_sparse_categorical_accuracy: 0.4682 - val_type_output_accuracy: 0.7833 - val_type_output_loss: 0.5356\n", "Epoch 26/30\n", "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 29ms/step - a_output_loss: 1.1872 - a_output_sparse_categorical_accuracy: 0.4001 - loss: 2.5917 - q_output_loss: 1.2930 - q_output_sparse_categorical_accuracy: 0.4802 - type_output_accuracy: 0.8280 - type_output_loss: 0.3555 - val_a_output_loss: 1.1505 - val_a_output_sparse_categorical_accuracy: 0.4083 - val_loss: 2.7684 - val_q_output_loss: 1.4587 - val_q_output_sparse_categorical_accuracy: 0.4742 - val_type_output_accuracy: 0.8000 - val_type_output_loss: 0.5307\n", "Epoch 27/30\n", "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 28ms/step - a_output_loss: 1.1337 - a_output_sparse_categorical_accuracy: 0.4077 - loss: 2.5274 - q_output_loss: 1.2917 - q_output_sparse_categorical_accuracy: 0.4856 - type_output_accuracy: 0.8328 - type_output_loss: 0.3424 - val_a_output_loss: 1.1274 - val_a_output_sparse_categorical_accuracy: 0.4208 - val_loss: 2.7139 - val_q_output_loss: 1.4500 - val_q_output_sparse_categorical_accuracy: 0.4788 - val_type_output_accuracy: 0.8000 - val_type_output_loss: 0.4548\n", "Epoch 28/30\n", "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 30ms/step - a_output_loss: 1.1081 - a_output_sparse_categorical_accuracy: 0.4104 - loss: 2.4859 - q_output_loss: 1.2903 - q_output_sparse_categorical_accuracy: 0.4845 - type_output_accuracy: 0.8714 - type_output_loss: 0.2936 - val_a_output_loss: 1.1394 - val_a_output_sparse_categorical_accuracy: 0.4167 - val_loss: 2.7244 - val_q_output_loss: 1.4512 - val_q_output_sparse_categorical_accuracy: 0.4652 - val_type_output_accuracy: 0.8167 - val_type_output_loss: 0.4457\n" ] } ], "source": [ "losses = {\n", " \"q_output\": \"sparse_categorical_crossentropy\",\n", " \"a_output\": \"sparse_categorical_crossentropy\",\n", " \"type_output\": \"sparse_categorical_crossentropy\",\n", "}\n", "loss_weights = {\"q_output\": 1.0, \"a_output\": 1.0, \"type_output\": 0.3}\n", "\n", "model.compile(\n", " optimizer=\"adam\",\n", " loss=losses,\n", " loss_weights=loss_weights,\n", " metrics={\n", " \"q_output\": \"sparse_categorical_accuracy\",\n", " \"a_output\": \"sparse_categorical_accuracy\",\n", " \"type_output\": \"accuracy\",\n", " },\n", ")\n", "\n", "history = model.fit(\n", " [X_tok, X_ner, X_srl, dec_q_in, dec_a_in],\n", " [dec_q_out, dec_a_out, y_type],\n", " validation_split=0.1,\n", " epochs=30,\n", " batch_size=64,\n", " callbacks=[tf.keras.callbacks.EarlyStopping(patience=4, restore_best_weights=True)],\n", " verbose=1,\n", ")\n", "\n", "model.save(\"full_seq2seq.keras\")\n", "\n", "import json\n", "import pickle\n", "\n", "# def save_vocab(vocab, path):\n", "# with open(path, \"w\", encoding=\"utf-8\") as f:\n", "# json.dump(vocab, f, ensure_ascii=False, indent=2)\n", "\n", "# # Simpan semua vocab\n", "# save_vocab(vocab_tok, \"vocab_tok.json\")\n", "# save_vocab(vocab_ner, \"vocab_ner.json\")\n", "# save_vocab(vocab_srl, \"vocab_srl.json\")\n", "# save_vocab(vocab_q, \"vocab_q.json\")\n", "# save_vocab(vocab_a, \"vocab_a.json\")\n", "# save_vocab(vocab_typ, \"vocab_typ.json\")\n", "\n", "\n", "def save_vocab_pkl(vocab, path):\n", " with open(path, \"wb\") as f:\n", " pickle.dump(vocab, f)\n", "\n", "\n", "# Simpan semua vocab\n", "save_vocab_pkl(vocab_tok, \"vocab_tok.pkl\")\n", "save_vocab_pkl(vocab_ner, \"vocab_ner.pkl\")\n", "save_vocab_pkl(vocab_srl, \"vocab_srl.pkl\")\n", "save_vocab_pkl(vocab_q, \"vocab_q.pkl\")\n", "save_vocab_pkl(vocab_a, \"vocab_a.pkl\")\n", "save_vocab_pkl(vocab_typ, \"vocab_typ.pkl\")" ] }, { "cell_type": "code", "execution_count": 37, "id": "3355c0c7", "metadata": {}, "outputs": [], "source": [ "import tensorflow as tf\n", "import numpy as np\n", "import pickle\n", "from tensorflow.keras.models import load_model, Model\n", "from tensorflow.keras.layers import Input, Concatenate\n", "\n", "# === Load Model Utama ===\n", "model = load_model(\"full_seq2seq.keras\")\n", "\n", "\n", "# === Load Vocabulary dari .pkl ===\n", "def load_vocab(path):\n", " with open(path, \"rb\") as f:\n", " return pickle.load(f)\n", "\n", "\n", "vocab_tok = load_vocab(\"vocab_tok.pkl\")\n", "vocab_ner = load_vocab(\"vocab_ner.pkl\")\n", "vocab_srl = load_vocab(\"vocab_srl.pkl\")\n", "vocab_q = load_vocab(\"vocab_q.pkl\")\n", "vocab_a = load_vocab(\"vocab_a.pkl\")\n", "vocab_typ = load_vocab(\"vocab_typ.pkl\")\n", "\n", "inv_vocab_q = {v: k for k, v in vocab_q.items()}\n", "inv_vocab_a = {v: k for k, v in vocab_a.items()}\n", "\n", "# === Build Encoder Model ===\n", "MAX_SENT = model.input_shape[0][1] # Ambil shape dari model yang diload\n", "MAX_Q = model.input_shape[3][1] # Max length for question\n", "MAX_A = model.input_shape[4][1] # Max length for answer\n", "\n", "inp_tok_g = Input(shape=(MAX_SENT,), name=\"tok_in_g\")\n", "inp_ner_g = Input(shape=(MAX_SENT,), name=\"ner_in_g\")\n", "inp_srl_g = Input(shape=(MAX_SENT,), name=\"srl_in_g\")\n", "\n", "emb_tok = model.get_layer(\"embedding_tok\").call(inp_tok_g)\n", "emb_ner = model.get_layer(\"embedding_ner\").call(inp_ner_g)\n", "emb_srl = model.get_layer(\"embedding_srl\").call(inp_srl_g)\n", "\n", "enc_concat = Concatenate(name=\"concat_encoder\")([emb_tok, emb_ner, emb_srl])\n", "\n", "encoder_lstm = model.get_layer(\"encoder_lstm\")\n", "enc_out, state_h, state_c = encoder_lstm(enc_concat)\n", "\n", "# Create encoder model with full output including enc_out\n", "encoder_model = Model(\n", " inputs=[inp_tok_g, inp_ner_g, inp_srl_g],\n", " outputs=[enc_out, state_h, state_c],\n", " name=\"encoder_model\",\n", ")\n", "\n", "# === Build Decoder for Question ===\n", "dec_q_inp = Input(shape=(1,), name=\"dec_q_in\")\n", "dec_emb_q = model.get_layer(\"embedding_q_decoder\").call(dec_q_inp)\n", "\n", "state_h_dec = Input(shape=(256,), name=\"state_h_dec\")\n", "state_c_dec = Input(shape=(256,), name=\"state_c_dec\")\n", "\n", "lstm_decoder_q = model.get_layer(\"lstm_q_decoder\")\n", "\n", "dec_out_q, state_h_q, state_c_q = lstm_decoder_q(\n", " dec_emb_q, initial_state=[state_h_dec, state_c_dec]\n", ")\n", "\n", "q_time_dist_layer = model.get_layer(\"q_output\")\n", "dense_q = q_time_dist_layer.layer\n", "q_output = dense_q(dec_out_q)\n", "\n", "decoder_q = Model(\n", " inputs=[dec_q_inp, state_h_dec, state_c_dec],\n", " outputs=[q_output, state_h_q, state_c_q],\n", " name=\"decoder_question_model\",\n", ")\n", "\n", "# === Build Decoder for Answer ===\n", "dec_a_inp = Input(shape=(1,), name=\"dec_a_in\")\n", "dec_emb_a = model.get_layer(\"embedding_a_decoder\").call(dec_a_inp)\n", "\n", "state_h_a = Input(shape=(256,), name=\"state_h_a\")\n", "state_c_a = Input(shape=(256,), name=\"state_c_a\")\n", "\n", "lstm_decoder_a = model.get_layer(\"lstm_a_decoder\")\n", "\n", "dec_out_a, state_h_a_out, state_c_a_out = lstm_decoder_a(\n", " dec_emb_a, initial_state=[state_h_a, state_c_a]\n", ")\n", "\n", "a_time_dist_layer = model.get_layer(\"a_output\")\n", "dense_a = a_time_dist_layer.layer\n", "a_output = dense_a(dec_out_a)\n", "\n", "decoder_a = Model(\n", " inputs=[dec_a_inp, state_h_a, state_c_a],\n", " outputs=[a_output, state_h_a_out, state_c_a_out],\n", " name=\"decoder_answer_model\",\n", ")\n", "\n", "# === Build Classifier for Question Type ===\n", "type_dense = model.get_layer(\"type_output\")\n", "type_out = type_dense(enc_out)\n", "\n", "classifier_model = Model(\n", " inputs=[inp_tok_g, inp_ner_g, inp_srl_g], outputs=type_out, name=\"classifier_model\"\n", ")" ] }, { "cell_type": "code", "execution_count": 38, "id": "d406e6ff", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Generated Question: gilang maulana lahir di semarang pada 26 november 1983 ___\n", "Generated Answer : true\n", "Question Type : true_false\n" ] } ], "source": [ "def encode(seq, vmap):\n", " return [vmap.get(tok, vmap[\"\"]) for tok in seq]\n", "\n", "\n", "def encode_and_pad(seq, vmap, max_len=MAX_SENT):\n", " encoded = [vmap.get(tok, vmap[\"\"]) for tok in seq]\n", " # Pad with vocab[\"\"] to the right if sequence is shorter than max_len\n", " padded = encoded + [vmap[\"\"]] * (max_len - len(encoded))\n", " return padded[:max_len] # Ensure it doesn't exceed max_len\n", "\n", "\n", "def greedy_decode(tokens, ner, srl, max_q=20, max_a=10):\n", " # --- encode encoder inputs -------------------------------------------\n", " if isinstance(tokens, np.ndarray):\n", " enc_tok = tokens\n", " enc_ner = ner\n", " enc_srl = srl\n", " else:\n", " enc_tok = np.array([encode_and_pad(tokens, vocab_tok, MAX_SENT)])\n", " enc_ner = np.array([encode_and_pad(ner, vocab_ner, MAX_SENT)])\n", " enc_srl = np.array([encode_and_pad(srl, vocab_srl, MAX_SENT)])\n", "\n", " # --- Get encoder outputs ---\n", " enc_out, h, c = encoder_model.predict([enc_tok, enc_ner, enc_srl], verbose=0)\n", "\n", " # QUESTION Decoding\n", " tgt = np.array([[vocab_q[\"\"]]])\n", " question_ids = []\n", " for _ in range(max_q):\n", " logits, h, c = decoder_q.predict([tgt, h, c], verbose=0)\n", " next_id = int(logits[0, 0].argmax()) # Get the predicted token ID\n", " if next_id == vocab_q[\"\"]:\n", " break\n", " question_ids.append(next_id)\n", " tgt = np.array([[next_id]]) # Feed the predicted token back as input\n", "\n", " # ANSWER Decoding - use encoder outputs again for fresh state\n", " _, h, c = encoder_model.predict([enc_tok, enc_ner, enc_srl], verbose=0)\n", " tgt = np.array([[vocab_a[\"\"]]])\n", " answer_ids = []\n", " for _ in range(max_a):\n", " logits, h, c = decoder_a.predict([tgt, h, c], verbose=0)\n", " next_id = int(logits[0, 0].argmax())\n", " if next_id == vocab_a[\"\"]:\n", " break\n", " answer_ids.append(next_id)\n", " tgt = np.array([[next_id]])\n", "\n", " # Question Type\n", " qtype_logits = classifier_model.predict([enc_tok, enc_ner, enc_srl], verbose=0)\n", " qtype_id = int(qtype_logits.argmax())\n", "\n", " # Final output\n", " question = [inv_vocab_q.get(i, \"\") for i in question_ids]\n", " answer = [inv_vocab_a.get(i, \"\") for i in answer_ids]\n", " q_type = [k for k, v in vocab_typ.items() if v == qtype_id][0]\n", "\n", " return question, answer, q_type\n", "\n", "\n", "def test_model():\n", " test_data = {\n", " \"tokens\": [\n", " \"indra\",\n", " \"maulana\",\n", " \"lahir\",\n", " \"di\",\n", " \"semarang\",\n", " \"pada\",\n", " \"12\",\n", " \"maret\",\n", " \"1977\",\n", " ],\n", " \"ner\": [\"B-PER\", \"I-PER\", \"V\", \"O\", \"B-LOC\", \"O\", \"B-DATE\", \"I-DATE\", \"I-DATE\"],\n", " \"srl\": [\n", " \"ARG0\",\n", " \"ARG0\",\n", " \"V\",\n", " \"O\",\n", " \"ARGM-LOC\",\n", " \"O\",\n", " \"ARGM-TMP\",\n", " \"ARGM-TMP\",\n", " \"ARGM-TMP\",\n", " ],\n", " }\n", " # tokens = [\n", " # \"soekarno\",\n", " # \"membacakan\",\n", " # \"teks\",\n", " # \"proklamasi\",\n", " # \"pada\",\n", " # \"17\",\n", " # \"agustus\",\n", " # \"1945\",\n", " # ]\n", " # ner_tags = [\"B-PER\", \"O\", \"O\", \"O\", \"O\", \"B-DATE\", \"I-DATE\", \"I-DATE\"]\n", " # srl_tags = [\"ARG0\", \"V\", \"ARG1\", \"ARG1\", \"O\", \"ARGM-TMP\", \"ARGM-TMP\", \"ARGM-TMP\"]\n", "\n", " question, answer, q_type = greedy_decode(\n", " test_data[\"tokens\"], test_data[\"ner\"], test_data[\"srl\"]\n", " )\n", " print(f\"Generated Question: {' '.join(question)}\")\n", " print(f\"Generated Answer : {' '.join(answer)}\")\n", " print(f\"Question Type : {q_type}\")\n", "\n", "\n", "test_model()" ] }, { "cell_type": "code", "execution_count": 39, "id": "5adde3c3", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "BLEU : 0.1185\n", "ROUGE1: 0.3967 | ROUGE-L: 0.3967\n" ] } ], "source": [ "from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction\n", "from rouge_score import rouge_scorer\n", "\n", "smoothie = SmoothingFunction().method4\n", "scorer = rouge_scorer.RougeScorer([\"rouge1\", \"rougeL\"], use_stemmer=True)\n", "\n", "\n", "# Helper to strip special ids\n", "def strip_special(ids, vocab):\n", " pad = vocab[\"\"] if \"\" in vocab else None\n", " eos = vocab[\"\"]\n", " return [i for i in ids if i not in (pad, eos)]\n", "\n", "\n", "def ids_to_text(ids, inv_vocab):\n", " return \" \".join(inv_vocab[i] for i in ids)\n", "\n", "\n", "# ---- evaluation over a set of indices ----\n", "import random\n", "\n", "\n", "def evaluate(indices=None):\n", " if indices is None:\n", " indices = random.sample(range(len(X_tok)), k=min(100, len(X_tok)))\n", "\n", " bleu_scores, rou1, rouL = [], [], []\n", " for idx in indices:\n", " # Ground truth\n", " gt_q = strip_special(dec_q_out[idx], vocab_q)\n", " gt_a = strip_special(dec_a_out[idx], vocab_a)\n", " # Prediction\n", " q_pred, a_pred, _ = greedy_decode(\n", " X_tok[idx : idx + 1], X_ner[idx : idx + 1], X_srl[idx : idx + 1]\n", " )\n", "\n", " # BLEU on question tokens\n", " bleu_scores.append(\n", " sentence_bleu(\n", " [[inv_vocab_q[i] for i in gt_q]], q_pred, smoothing_function=smoothie\n", " )\n", " )\n", " # ROUGE on question strings\n", " r = scorer.score(ids_to_text(gt_q, inv_vocab_q), \" \".join(q_pred))\n", " rou1.append(r[\"rouge1\"].fmeasure)\n", " rouL.append(r[\"rougeL\"].fmeasure)\n", "\n", " print(f\"BLEU : {np.mean(bleu_scores):.4f}\")\n", " print(f\"ROUGE1: {np.mean(rou1):.4f} | ROUGE-L: {np.mean(rouL):.4f}\")\n", "\n", "\n", "evaluate()" ] } ], "metadata": { "kernelspec": { "display_name": "myenv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.16" } }, "nbformat": 4, "nbformat_minor": 5 }