{ "cells": [ { "cell_type": "code", "execution_count": 53, "id": "fb283f23", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[TypeError] RAW[260]: Expected dict, got with value: PER\n", "[Unexpected Error] RAW[915]: 'list' object has no attribute 'lower'. TOKENS: ['Setelah', 'masuk', 'Islam', 'raja', 'berganti', 'nama', 'menjadi', 'Sultan', 'Isma’il', 'Syah', 'Zill', 'Allah', 'fi', 'al-Alam', 'dan', 'juga', 'ketiga', 'orang', 'putra', 'dan', 'putrinya', 'yaitu', 'Sultan', 'Mudaffar', 'Syah', 'Siti', 'Aisyah', 'dan', 'Sultan', 'Mansyur']\n", "Total flattened samples: 1423\n" ] } ], "source": [ "import json\n", "from pathlib import Path\n", "\n", "file_path = \"../dataset/dev_dataset_qg.json\"\n", "\n", "\n", "raw_content = Path(file_path).read_text()\n", "RAW = json.loads(raw_content)\n", "\n", "samples = []\n", "for idx, item in enumerate(RAW):\n", " try:\n", " if not isinstance(item, dict):\n", " print(\n", " f\"[TypeError] RAW[{idx}]: Expected dict, got {type(item)} with value: {item}\"\n", " )\n", " continue\n", "\n", " for qp in item[\"qas\"]:\n", " samp = {\n", " \"tokens\": [tok.lower() for tok in item[\"tokens\"]],\n", " \"ner\": item[\"ner\"],\n", " \"srl\": item[\"srl\"],\n", " \"q_type\": qp[\"type\"],\n", " \"q_toks\": [tok.lower() for tok in qp[\"question\"]] + [\"\"],\n", " }\n", " if isinstance(qp[\"answer\"], list):\n", " samp[\"a_toks\"] = [tok.lower() for tok in qp[\"answer\"]] + [\"\"]\n", " else:\n", " samp[\"a_toks\"] = [qp[\"answer\"].lower(), \"\"]\n", " samples.append(samp)\n", "\n", " except KeyError as e:\n", " print(f\"[KeyError] RAW[{idx}]: Missing key {e}. TOKENS: {item['tokens']}\")\n", " except Exception as e:\n", " print(f\"[Unexpected Error] RAW[{idx}]: {e}. TOKENS: {item['tokens']}\")\n", "\n", "print(\"Total flattened samples:\", len(samples))" ] }, { "cell_type": "code", "execution_count": 42, "id": "fa4f979d", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'': 0, '': 1, '': 2, '': 3, 'dimana': 4, 'kartini': 5, 'lahir': 6, '___': 7, 'pada': 8, 'tanggal': 9, '21': 10, 'mei': 11, '1879': 12, 'kerajaan': 13, 'majapahit': 14, 'berdiri': 15, 'tahun': 16, '1300': 17, 'berapa': 18, 'kemerdekaan': 19, 'indonesia': 20, 'diproklamasikan': 21, 'siapa': 22, 'yang': 23, 'memproklamasikan': 24, 'lama': 25, 'bumi': 26, 'mengelilingi': 27, 'matahari': 28, 'presiden': 29, 'pertama': 30, 'planet': 31, 'apa': 32, 'paling': 33, 'dekat': 34, 'dengan': 35, 'venus': 36, 'memiliki': 37, 'suhu': 38, 'permukaan': 39, 'tinggi': 40, 'dikenal': 41, 'sebagai': 42, 'merah': 43, 'terbesar': 44, 'di': 45, 'tata': 46, 'surya': 47, 'terkenal': 48, 'cincin': 49, 'indah': 50, 'berwarna': 51, 'biru': 52, 'jauh': 53, 'dari': 54, 'apakah': 55, 'pluto': 56, 'masih': 57, 'dianggap': 58, 'soekarno': 59, 'membacakan': 60, 'teks': 61, 'proklamasi': 62, 'kapan': 63, 'sebutkan': 64, 'dibacakan': 65, 'andi': 66, 'melakukan': 67, 'pergi': 68, 'ke': 69, 'jakarta': 70, 'siti': 71, 'berangkat': 72, 'bandung': 73, 'budi': 74, 'pindah': 75, 'bali': 76, 'lina': 77, 'dan': 78, 'wati': 79, 'liburan': 80, 'medan': 81, 'agus': 82, 'wijaya': 83, 'melanjutkan': 84, 'surabaya': 85, 'nurul': 86, 'yogyakarta': 87, 'dedi': 88, 'makassar': 89, 'maya': 90, 'lestari': 91, 'roni': 92, 'tiara': 93, 'setiawan': 94, 'santoso': 95, 'saputra': 96, 'aktivitas': 97, 'dilakukan': 98, 'oleh': 99, 'maharani': 100, 'firmansyah': 101, 'gunung': 102, 'tertinggi': 103, 'dunia': 104, '?': 105, 'sungai': 106, 'terpanjang': 107, 'bangunan': 108, 'dibangun': 109, 
'sekitar': 110, '2560': 111, 'sm': 112, 'benua': 113, 'keajaiban': 114, 'berada': 115, 'italia': 116, 'negara': 117, 'mana': 118, 'terletak': 119, 'colosseum': 120, 'nama': 121, 'letaknya': 122, 'taj': 123, 'mahal': 124, 'india': 125, 'petra': 126, 'ada': 127, 'yordania': 128, 'china': 129, 'meksiko': 130, 'chichen': 131, 'itza': 132, 'patung': 133, 'yesus': 134, 'penebus': 135, 'brasil': 136, 'peru': 137, 'inggris': 138, 'stonehenge': 139, 'menara': 140, 'pisa': 141, 'angkot': 142, 'wat': 143, 'kamodja': 144, 'ketinggian': 145, 'everest': 146, '9000': 147, 'meter': 148, 'merdeka': 149, 'merumuskan': 150, 'teori': 151, 'relativitas': 152, 'albert': 153, 'einstein': 154, '1910': 155, 'organ': 156, 'memompa': 157, 'darah': 158, 'seluruh': 159, 'tubuh': 160, 'fungsi': 161, 'jantung': 162, 'manusia': 163, 'ibukota': 164, 'jepang': 165, 'kota': 166, 'menjadi': 167, 'air': 168, 'mendidih': 169, '90': 170, 'derajat': 171, 'celsius': 172, 'penemu': 173, 'bola': 174, 'lampu': 175, 'thomas': 176, 'alva': 177, 'edison': 178, 'menemukan': 179, 'urutan': 180, 'adalah': 181, 'keempat': 182, 'pelukis': 183, 'mona': 184, 'lisa': 185, 'lukisan': 186, 'dibuat': 187, 'jarak': 188, 'dalam': 189, 'satu': 190, 'cahaya': 191, 'setara': 192, '10': 193, 'triliun': 194, 'kilometer': 195, 'pemimpin': 196, 'gerakan': 197, 'mahatma': 198, 'gandhi': 199, 'memimpin': 200, 'pakistan': 201, 'nasa': 202, 'didirikan': 203, 'bagian': 204, 'terluar': 205, 'mata': 206, 'retina': 207, 'napoleon': 208, 'bonaparte': 209, 'dikalahkan': 210, 'pertempuran': 211, 'waterloo': 212, 'terjadi': 213, 'komodo': 214, 'ditemukan': 215, 'dapat': 216, 'australia': 217, 'pemenang': 218, 'nobel': 219, 'bidang': 220, 'fisika': 221, 'kimia': 222, 'marie': 223, 'curie': 224, 'memenangkan': 225, 'machu': 226, 'picchu': 227, 'situs': 228, 'peradaban': 229, 'dihasilkan': 230, 'fotosintesis': 231, 'selain': 232, 'glukosa': 233, 'bahan': 234, 'saja': 235, 'dibutuhkan': 236, 'seniman': 237, 'memotong': 238, 'telinganya': 239, 'sendiri': 240, 'vincent': 241, 'van': 242, 'gogh': 243, '1890': 244, 'bagaimana': 245, 'bentuk': 246, 'molekul': 247, 'dna': 248, 'struktur': 249, 'berbentuk': 250, 'penisilin': 251, 'secara': 252, 'sengaja': 253, 'setelah': 254, 'penelitian': 255, 'bertahun-tahun': 256, 'buah': 257, 'mengandung': 258, 'banyak': 259, 'vitamin': 260, 'c': 261, 'terkandung': 262, 'jeruk': 263, 'mengembangkan': 264, 'sistem': 265, 'arus': 266, 'listrik': 267, 'bolak-balik': 268, 'panjang': 269, 'nil': 270, 'hewan': 271, 'pernah': 272, 'hidup': 273, 'gajah': 274, 'memproduksi': 275, 'insulin': 276, 'mengemukakan': 277, 'evolusi': 278, 'dikemukakan': 279, 'isaac': 280, 'newton': 281, 'membangun': 282, 'untuk': 283, 'asia': 284, 'luas': 285, '44.58': 286, 'juta': 287, 'km²': 288, 'afrika': 289, 'utara': 290, 'laut': 291, 'mediterania': 292, 'manakah': 293, 'terkecil': 294, 'kedua': 295, 'eropa': 296, 'lebih': 297, 'besar': 298, 'samudera': 299, 'berbatasan': 300, 'amerika': 301, 'timur': 302, 'hindia': 303, 'barat': 304, 'selatan': 305, 'terdiri': 306, 'hutan': 307, 'amazon': 308, 'belahan': 309, 'kutub': 310, 'antartika': 311, 'hampir': 312, 'seluruhnya': 313, 'tertutup': 314, 'es': 315, 'populasi': 316, '4.7': 317, 'miliar': 318, 'penduduk': 319, 'jumlah': 320, 'kilimanjaro': 321, '5,895': 322, 'gurun': 323, 'ketiga': 324, 'sahara': 325, 'merupakan': 326, 'peringkat': 327, '6,650': 328, 'km': 329, 'pegunungan': 330, 'alpen': 331, 'membentang': 332, '8': 333, 'danau': 334, 'superior': 335, 'tawar': 336, 'menghadiri': 337, 'turnamen': 338, 'catur': 339, 
'ali': 340, '15': 341, 'juli': 342, '2023': 343, 'rapat': 344, 'organisasi': 345, 'nina': 346, '25': 347, 'desember': 348, 'farhan': 349, 'workshop': 350, 'fotografi': 351, 'pameran': 352, 'teknologi': 353, '5': 354, 'malang': 355, 'iqbal': 356, 'perlombaan': 357, 'renang': 358, 'padang': 359, 'konser': 360, 'musik': 361, 'agustus': 362, 'fajar': 363, 'dina': 364, '1': 365, 'januari': 366, '2024': 367, 'festival': 368, 'kuliner': 369, 'rian': 370, 'bazar': 371, 'amal': 372, 'tari': 373, 'seminar': 374, 'pendidikan': 375, 'kompetisi': 376, 'robotik': 377, 'rudi': 378, 'semarang': 379, 'putri': 380, 'hana': 381, 'raka': 382, 'dewi': 383, 'pahlawan': 384, 'jawa': 385, 'pelajar': 386, 'kaya': 387, 'akan': 388, 'budaya': 389, 'pusat': 390, 'pemerintahan': 391, 'kembang': 392, 'fashion': 393, 'sejuk': 394, 'destinasi': 395, 'wisata': 396, 'alam': 397, 'pulau': 398, 'dewata': 399, 'masakan': 400, 'rendang': 401, 'mendunia': 402, 'pelabuhan': 403, 'utama': 404, 'sumatra': 405, 'khas': 406, 'kemerdekaannya': 407, 'sumpah': 408, 'pemuda': 409, 'diikrarkan': 410, 'isi': 411, 'proses': 412, 'berlangsung': 413, 'diubah': 414, 'pembelahan': 415, 'mitosis': 416, 'sel': 417, 'menghasilkan': 418, 'gamet': 419, 'benda': 420, 'ditarik': 421, 'magnet': 422, 'dilaksanakan': 423, 'ginjal': 424, 'paru-paru': 425, 'ekskresi': 426, 'hati': 427, 'itu': 428, 'ramah': 429, 'lingkungan': 430, 'dampak': 431, 'negatif': 432, 'penerapan': 433, 'bioteknologi': 434, 'pembentukan': 435, 'urine': 436, 'peredaran': 437, 'berperan': 438, 'penting': 439, 'masa': 440, 'reformasi': 441, 'dimulai': 442, 'peristiwa': 443, 'politik': 444, 'pencernaan': 445, 'penyerapan': 446, 'zat': 447, 'makanan': 448, 'pernapasan': 449, 'contoh': 450, 'alat': 451, 'optik': 452, 'kehidupan': 453, 'sehari-hari': 454, 'kacamata': 455, 'pembuluh': 456, 'xilem': 457, 'tumbuhan': 458, 'floem': 459, 'pemanfaatan': 460, 'getaran': 461, 'gelombang': 462, 'hasil': 463, 'gangguan': 464, 'penyebab': 465, 'penyakit': 466, 'asma': 467, 'mohammad': 468, 'hatta': 469, 'bung': 470, 'tomo': 471, 'i': 472, 'gusti': 473, 'ngurah': 474, 'rai': 475, 'gugur': 476, 'cut': 477, 'nyak': 478, 'dien': 479, 'wafat': 480, 'teuku': 481, 'umar': 482, 'wahidin': 483, 'sudirohusodo': 484, 'sultan': 485, 'mahmud': 486, 'badaruddin': 487, 'ii': 488, 'kh': 489, 'ahmad': 490, 'dahlan': 491, 'hasyim': 492, \"asy'ari\": 493, 'ageng': 494, 'tirtayasa': 495, 'hasanuddin': 496, 'pattimura': 497, 'pangeran': 498, 'diponegoro': 499, 'sentot': 500, 'alibasya': 501, 'prawirodirjo': 502, 'cipto': 503, 'mangunkusumo': 504, 'ernest': 505, 'douwes': 506, 'dekker': 507, 'dr.': 508, 'mas': 509, 'mansur': 510, 'sutan': 511, 'sjahrir': 512, 'abdul': 513, 'muis': 514, 'otto': 515, 'iskandardinata': 516, 'abikusno': 517, 'tjokrosujoso': 518, 'wahid': 519, 'bpupki': 520, 'ketua': 521, 'ppki': 522, 'pendiri': 523, 'nahdlatul': 524, 'ulama': 525, 'jong': 526, 'islamieten': 527, 'bond': 528, 'muhammadiyah': 529, 'muda': 530, 'perhimpunan': 531, 'partai': 532, 'nasional': 533, 'voc': 534, 'dibubarkan': 535, 'utomo': 536, 'tokoh': 537, 'sarekat': 538, 'islam': 539, 'komunis': 540, 'fonds': 541, 'mardika': 542, 'kutai': 543, 'peranan': 544, 'mahakam': 545, 'bagi': 546, 'perekonomian': 547, 'sumber': 548, 'sejarah': 549, 'raja': 550, 'memerintah': 551, 'saat': 552, 'yupa': 553, 'dikeluarkan': 554, 'ditulis': 555, 'huruf': 556, 'bahasa': 557, 'prasasti': 558, 'diperkirakan': 559, 'kakek': 560, 'mulawarman': 561, 'dinasti': 562, 'lembu': 563, 'dikorbankan': 564, 'zaman': 565, 'keemasan': 566, 'melalui': 567, 
'ekonomi': 568, 'berkembang': 569, 'pesat': 570, 'jalur': 571, 'perdagangan': 572, 'internasional': 573, 'hingga': 574, 'memberi': 575, 'sedekah': 576, 'sapi': 577, 'diberikan': 578, 'kepada': 579, 'mengalami': 580, 'digunakan': 581, 'tarumanegara': 582, 'abad': 583, 'memerintahkan': 584, 'penggalian': 585, 'candrabaga': 586, 'ekor': 587, 'dipersembahkan': 588, 'ditonjolkan': 589, 'cidanghiang': 590, 'dilambangkan': 591, 'gambar': 592, 'telapak': 593, 'kaki': 594, 'kebon': 595, 'kopi': 596, 'saturnus': 597, 'mars': 598, 'bintang': 599, 'dihuni': 600, 'makhluk': 601, 'sekarang': 602, 'dikategorikan': 603, 'satelit': 604, 'dimiliki': 605, 'jupiter': 606, 'bernama': 607, 'alami': 608, 'berputar': 609, 'miring': 610, 'terhadap': 611, 'porosnya': 612, 'disebut': 613, 'panas': 614, 'tersusun': 615, 'dilalui': 616, 'tidak': 617, 'siang': 618, 'malam': 619, 'karena': 620, 'habibie': 621, 'megawati': 622, 'yudhoyono': 623, 'widodo': 624, 'sri': 625, 'mulyani': 626, 'shihab': 627, 'agnez': 628, 'mo': 629, 'zain': 630, 'dian': 631, 'denpasar': 632, 'vina': 633, 'sari': 634, 'mita': 635, 'aditya': 636, 'kediri': 637, 'rahmat': 638, 'eka': 639, 'hendra': 640, 'zulfa': 641, 'nadya': 642, 'luthfi': 643, 'mario': 644, 'citra': 645, 'udin': 646, 'ambon': 647, 'cilegon': 648, 'joni': 649, 'jayapura': 650, 'xenia': 651, 'gina': 652, 'probolinggo': 653, 'candra': 654, 'faisal': 655, 'intan': 656, 'tegal': 657, 'joko': 658, 'palembang': 659, 'vivin': 660, 'wawan': 661, 'jember': 662, 'hari': 663, 'bayu': 664, 'qori': 665, 'solo': 666, 'depok': 667, 'kiki': 668, 'ernita': 669, 'laila': 670, 'teguh': 671, 'banjarmasin': 672, 'samarinda': 673, 'pontianak': 674, 'galuh': 675, 'pratiwi': 676, 'gianyar': 677, 'khansa': 678, 'oki': 679, 'yosef': 680, 'umi': 681, 'blitar': 682, 'manado': 683, 'rani': 684, 'kupang': 685, 'tasikmalaya': 686, 'yana': 687, 'salatiga': 688, 'pekanbaru': 689, 'magelang': 690, 'tania': 691, 'ilham': 692, 'sidoarjo': 693, 'purwokerto': 694, 'sukabumi': 695, 'opik': 696, 'cirebon': 697, 'bekasi': 698, 'mataram': 699, 'wahyu': 700, 'serang': 701}\n" ] } ], "source": [ "from itertools import chain\n", "\n", "\n", "def build_vocab(seq_iter, reserved=(\"<pad>\", \"<unk>\", \"<sos>\", \"<eos>\")):\n", "    vocab = {tok: idx for idx, tok in enumerate(reserved)}\n", "    for tok in chain.from_iterable(seq_iter):\n", "        if tok not in vocab:\n", "            vocab[tok] = len(vocab)\n", "    return vocab\n", "\n", "\n", "vocab_tok = build_vocab((s[\"tokens\"] for s in samples))\n", "vocab_ner = build_vocab((s[\"ner\"] for s in samples), reserved=(\"<pad>\", \"<unk>\"))\n", "vocab_srl = build_vocab((s[\"srl\"] for s in samples), reserved=(\"<pad>\", \"<unk>\"))\n", "vocab_q = build_vocab((s[\"q_toks\"] for s in samples))\n", "vocab_a = build_vocab((s[\"a_toks\"] for s in samples))\n", "\n", "vocab_typ = {\"isian\": 0, \"opsi\": 1, \"true_false\": 2}\n", "\n", "print(vocab_q)" ] },
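{ "cell_type": "markdown", "id": "qg-vocab-md", "metadata": {}, "source": [ "A quick sanity check of `build_vocab` on a toy corpus (illustrative; the real vocabularies above come from `samples`): reserved specials keep the lowest ids, later words get ids in first-seen order, and out-of-vocabulary words fall back to the `<unk>` id at encoding time." ] }, { "cell_type": "code", "execution_count": null, "id": "qg-vocab-demo", "metadata": {}, "outputs": [], "source": [ "# Reserved specials take ids 0-3; corpus tokens follow in first-seen order.\n", "toy_vocab = build_vocab([[\"siapa\", \"presiden\"], [\"presiden\", \"pertama\"]])\n", "print(toy_vocab)\n", "\n", "# An OOV word maps to the <unk> id during encoding.\n", "print(toy_vocab.get(\"soekarno\", toy_vocab[\"<unk>\"]))" ] },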
{ "cell_type": "code", "execution_count": 43, "id": "d1a5b324", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "from tensorflow.keras.preprocessing.sequence import pad_sequences\n", "\n", "\n", "def encode(seq, vmap):  # token → id\n", "    return [vmap.get(t, vmap[\"<unk>\"]) for t in seq]\n", "\n", "\n", "MAX_SENT = max(len(s[\"tokens\"]) for s in samples)\n", "MAX_Q = max(len(s[\"q_toks\"]) for s in samples)\n", "MAX_A = max(len(s[\"a_toks\"]) for s in samples)\n", "\n", "X_tok = pad_sequences(\n", "    [encode(s[\"tokens\"], vocab_tok) for s in samples], maxlen=MAX_SENT, padding=\"post\"\n", ")\n", "X_ner = pad_sequences(\n", "    [encode(s[\"ner\"], vocab_ner) for s in samples], maxlen=MAX_SENT, padding=\"post\"\n", ")\n", "X_srl = pad_sequences(\n", "    [encode(s[\"srl\"], vocab_srl) for s in samples], maxlen=MAX_SENT, padding=\"post\"\n", ")\n", "\n", "# Decoder input = <sos> + target[:-1]\n", "dec_q_in = pad_sequences(\n", "    [[vocab_q[\"<sos>\"], *encode(s[\"q_toks\"][:-1], vocab_q)] for s in samples],\n", "    maxlen=MAX_Q,\n", "    padding=\"post\",\n", ")\n", "dec_q_out = pad_sequences(\n", "    [encode(s[\"q_toks\"], vocab_q) for s in samples], maxlen=MAX_Q, padding=\"post\"\n", ")\n", "\n", "dec_a_in = pad_sequences(\n", "    [[vocab_a[\"<sos>\"], *encode(s[\"a_toks\"][:-1], vocab_a)] for s in samples],\n", "    maxlen=MAX_A,\n", "    padding=\"post\",\n", ")\n", "dec_a_out = pad_sequences(\n", "    [encode(s[\"a_toks\"], vocab_a) for s in samples], maxlen=MAX_A, padding=\"post\"\n", ")\n", "y_type = np.array([vocab_typ[s[\"q_type\"]] for s in samples])" ] },
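{ "cell_type": "markdown", "id": "qg-tf-md", "metadata": {}, "source": [ "Teacher forcing in one picture, on an invented three-token question: the decoder input is the gold sequence shifted right behind `<sos>`, and the target is the same sequence ending in `<eos>`, exactly the pairing built above." ] }, { "cell_type": "code", "execution_count": null, "id": "qg-tf-demo", "metadata": {}, "outputs": [], "source": [ "# Invented target under the shift used above.\n", "toy_q = [\"siapa\", \"presiden\", \"<eos>\"]\n", "toy_in = [\"<sos>\", *toy_q[:-1]]  # what the decoder sees at each step\n", "toy_out = toy_q  # what it must predict at each step\n", "for t, (i, o) in enumerate(zip(toy_in, toy_out)):\n", "    print(f\"step {t}: input={i!r} -> target={o!r}\")" ] },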
Model: \"functional_4\"\n",
       "
\n" ], "text/plain": [ "\u001b[1mModel: \"functional_4\"\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓\n",
       "┃ Layer (type)         Output Shape          Param #  Connected to      ┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩\n",
       "│ tok_in (InputLayer) │ (None, 34)        │          0 │ -                 │\n",
       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
       "│ ner_in (InputLayer) │ (None, 34)        │          0 │ -                 │\n",
       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
       "│ srl_in (InputLayer) │ (None, 34)        │          0 │ -                 │\n",
       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
       "│ embedding_tok       │ (None, 34, 32)    │     36,096 │ tok_in[0][0]      │\n",
       "│ (Embedding)         │                   │            │                   │\n",
       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
       "│ embedding_ner       │ (None, 34, 16)    │        448 │ ner_in[0][0]      │\n",
       "│ (Embedding)         │                   │            │                   │\n",
       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
       "│ embedding_srl       │ (None, 34, 16)    │        352 │ srl_in[0][0]      │\n",
       "│ (Embedding)         │                   │            │                   │\n",
       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
       "│ dec_q_in            │ (None, 13)        │          0 │ -                 │\n",
       "│ (InputLayer)        │                   │            │                   │\n",
       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
       "│ concatenate_4       │ (None, 34, 64)    │          0 │ embedding_tok[0]… │\n",
       "│ (Concatenate)       │                   │            │ embedding_ner[0]… │\n",
       "│                     │                   │            │ embedding_srl[0]… │\n",
       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
       "│ dec_a_in            │ (None, 12)        │          0 │ -                 │\n",
       "│ (InputLayer)        │                   │            │                   │\n",
       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
       "│ embedding_q_decoder │ (None, 13, 32)    │     22,464 │ dec_q_in[0][0]    │\n",
       "│ (Embedding)         │                   │            │                   │\n",
       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
       "│ encoder_lstm (LSTM) │ [(None, 64),      │     33,024 │ concatenate_4[0]… │\n",
       "│                     │ (None, 64),       │            │                   │\n",
       "│                     │ (None, 64)]       │            │                   │\n",
       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
       "│ embedding_a_decoder │ (None, 12, 32)    │     19,200 │ dec_a_in[0][0]    │\n",
       "│ (Embedding)         │                   │            │                   │\n",
       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
       "│ lstm_q_decoder      │ [(None, 13, 64),  │     24,832 │ embedding_q_deco… │\n",
       "│ (LSTM)              │ (None, 64),       │            │ encoder_lstm[0][ │\n",
       "│                     │ (None, 64)]       │            │ encoder_lstm[0][ │\n",
       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
       "│ not_equal_16        │ (None, 13)        │          0 │ dec_q_in[0][0]    │\n",
       "│ (NotEqual)          │                   │            │                   │\n",
       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
       "│ lstm_a_decoder      │ [(None, 12, 64),  │     24,832 │ embedding_a_deco… │\n",
       "│ (LSTM)              │ (None, 64),       │            │ encoder_lstm[0][ │\n",
       "│                     │ (None, 64)]       │            │ encoder_lstm[0][ │\n",
       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
       "│ not_equal_17        │ (None, 12)        │          0 │ dec_a_in[0][0]    │\n",
       "│ (NotEqual)          │                   │            │                   │\n",
       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
       "│ q_output            │ (None, 13, 702)   │     45,630 │ lstm_q_decoder[0… │\n",
       "│ (TimeDistributed)   │                   │            │ not_equal_16[0][ │\n",
       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
       "│ a_output            │ (None, 12, 600)   │     39,000 │ lstm_a_decoder[0… │\n",
       "│ (TimeDistributed)   │                   │            │ not_equal_17[0][ │\n",
       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
       "│ type_output (Dense) │ (None, 3)         │        195 │ encoder_lstm[0][ │\n",
       "└─────────────────────┴───────────────────┴────────────┴───────────────────┘\n",
       "
\n" ], "text/plain": [ "┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓\n", "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mConnected to \u001b[0m\u001b[1m \u001b[0m┃\n", "┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩\n", "│ tok_in (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m34\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", "│ ner_in (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m34\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", "│ srl_in (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m34\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", "│ embedding_tok │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m34\u001b[0m, \u001b[38;5;34m32\u001b[0m) │ \u001b[38;5;34m36,096\u001b[0m │ tok_in[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", "│ (\u001b[38;5;33mEmbedding\u001b[0m) │ │ │ │\n", "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", "│ embedding_ner │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m34\u001b[0m, \u001b[38;5;34m16\u001b[0m) │ \u001b[38;5;34m448\u001b[0m │ ner_in[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", "│ (\u001b[38;5;33mEmbedding\u001b[0m) │ │ │ │\n", "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", "│ embedding_srl │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m34\u001b[0m, \u001b[38;5;34m16\u001b[0m) │ \u001b[38;5;34m352\u001b[0m │ srl_in[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", "│ (\u001b[38;5;33mEmbedding\u001b[0m) │ │ │ │\n", "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", "│ dec_q_in │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m13\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", "│ (\u001b[38;5;33mInputLayer\u001b[0m) │ │ │ │\n", "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", "│ concatenate_4 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m34\u001b[0m, \u001b[38;5;34m64\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ embedding_tok[\u001b[38;5;34m0\u001b[0m]… │\n", "│ (\u001b[38;5;33mConcatenate\u001b[0m) │ │ │ embedding_ner[\u001b[38;5;34m0\u001b[0m]… │\n", "│ │ │ │ embedding_srl[\u001b[38;5;34m0\u001b[0m]… │\n", "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", "│ dec_a_in │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m12\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", "│ (\u001b[38;5;33mInputLayer\u001b[0m) │ │ │ │\n", "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", "│ embedding_q_decoder │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m13\u001b[0m, \u001b[38;5;34m32\u001b[0m) │ \u001b[38;5;34m22,464\u001b[0m │ dec_q_in[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", "│ (\u001b[38;5;33mEmbedding\u001b[0m) │ │ │ │\n", "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", "│ encoder_lstm (\u001b[38;5;33mLSTM\u001b[0m) │ [(\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m64\u001b[0m), │ 
\u001b[38;5;34m33,024\u001b[0m │ concatenate_4[\u001b[38;5;34m0\u001b[0m]… │\n", "│ │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m64\u001b[0m), │ │ │\n", "│ │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m64\u001b[0m)] │ │ │\n", "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", "│ embedding_a_decoder │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m12\u001b[0m, \u001b[38;5;34m32\u001b[0m) │ \u001b[38;5;34m19,200\u001b[0m │ dec_a_in[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", "│ (\u001b[38;5;33mEmbedding\u001b[0m) │ │ │ │\n", "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", "│ lstm_q_decoder │ [(\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m13\u001b[0m, \u001b[38;5;34m64\u001b[0m), │ \u001b[38;5;34m24,832\u001b[0m │ embedding_q_deco… │\n", "│ (\u001b[38;5;33mLSTM\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m64\u001b[0m), │ │ encoder_lstm[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m…\u001b[0m │\n", "│ │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m64\u001b[0m)] │ │ encoder_lstm[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m…\u001b[0m │\n", "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", "│ not_equal_16 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m13\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ dec_q_in[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", "│ (\u001b[38;5;33mNotEqual\u001b[0m) │ │ │ │\n", "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", "│ lstm_a_decoder │ [(\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m12\u001b[0m, \u001b[38;5;34m64\u001b[0m), │ \u001b[38;5;34m24,832\u001b[0m │ embedding_a_deco… │\n", "│ (\u001b[38;5;33mLSTM\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m64\u001b[0m), │ │ encoder_lstm[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m…\u001b[0m │\n", "│ │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m64\u001b[0m)] │ │ encoder_lstm[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m…\u001b[0m │\n", "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", "│ not_equal_17 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m12\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ dec_a_in[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", "│ (\u001b[38;5;33mNotEqual\u001b[0m) │ │ │ │\n", "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", "│ q_output │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m13\u001b[0m, \u001b[38;5;34m702\u001b[0m) │ \u001b[38;5;34m45,630\u001b[0m │ lstm_q_decoder[\u001b[38;5;34m0\u001b[0m… │\n", "│ (\u001b[38;5;33mTimeDistributed\u001b[0m) │ │ │ not_equal_16[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m…\u001b[0m │\n", "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", "│ a_output │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m12\u001b[0m, \u001b[38;5;34m600\u001b[0m) │ \u001b[38;5;34m39,000\u001b[0m │ lstm_a_decoder[\u001b[38;5;34m0\u001b[0m… │\n", "│ (\u001b[38;5;33mTimeDistributed\u001b[0m) │ │ │ not_equal_17[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m…\u001b[0m │\n", "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", "│ type_output (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m3\u001b[0m) │ \u001b[38;5;34m195\u001b[0m │ encoder_lstm[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m…\u001b[0m │\n", "└─────────────────────┴───────────────────┴────────────┴───────────────────┘\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": 
[ "
 Total params: 246,073 (961.22 KB)\n",
       "
\n" ], "text/plain": [ "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m246,073\u001b[0m (961.22 KB)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
 Trainable params: 246,073 (961.22 KB)\n",
       "
\n" ], "text/plain": [ "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m246,073\u001b[0m (961.22 KB)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
 Non-trainable params: 0 (0.00 B)\n",
       "
\n" ], "text/plain": [ "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import tensorflow as tf\n", "from tensorflow.keras.layers import (\n", " Input,\n", " Embedding,\n", " LSTM,\n", " Concatenate,\n", " Dense,\n", " TimeDistributed,\n", ")\n", "from tensorflow.keras.models import Model\n", "\n", "# ---- constants ---------------------------------------------------\n", "d_tok = 32 # token embedding dim\n", "d_tag = 16 # NER / SRL embedding dim\n", "units = 64\n", "\n", "# ---- encoder -----------------------------------------------------\n", "inp_tok = Input((MAX_SENT,), name=\"tok_in\")\n", "inp_ner = Input((MAX_SENT,), name=\"ner_in\")\n", "inp_srl = Input((MAX_SENT,), name=\"srl_in\")\n", "\n", "# make ALL streams mask the same way (here: no masking,\n", "# we'll just pad with 0s and let the LSTM ignore them)\n", "emb_tok = Embedding(len(vocab_tok), d_tok, mask_zero=False, name=\"embedding_tok\")(\n", " inp_tok\n", ")\n", "emb_ner = Embedding(len(vocab_ner), d_tag, mask_zero=False, name=\"embedding_ner\")(\n", " inp_ner\n", ")\n", "emb_srl = Embedding(len(vocab_srl), d_tag, mask_zero=False, name=\"embedding_srl\")(\n", " inp_srl\n", ")\n", "\n", "enc_concat = Concatenate()([emb_tok, emb_ner, emb_srl])\n", "enc_out, state_h, state_c = LSTM(units, return_state=True, name=\"encoder_lstm\")(\n", " enc_concat\n", ")\n", "\n", "\n", "# ---------- DECODER : Question ----------\n", "dec_q_inp = Input(shape=(MAX_Q,), name=\"dec_q_in\")\n", "dec_emb_q = Embedding(len(vocab_q), d_tok, mask_zero=True, name=\"embedding_q_decoder\")(\n", " dec_q_inp\n", ")\n", "dec_q, _, _ = LSTM(\n", " units, return_state=True, return_sequences=True, name=\"lstm_q_decoder\"\n", ")(dec_emb_q, initial_state=[state_h, state_c])\n", "q_out = TimeDistributed(\n", " Dense(len(vocab_q), activation=\"softmax\", name=\"dense_q_output\"), name=\"q_output\"\n", ")(dec_q)\n", "\n", "# ---------- DECODER : Answer ----------\n", "dec_a_inp = Input(shape=(MAX_A,), name=\"dec_a_in\")\n", "dec_emb_a = Embedding(len(vocab_a), d_tok, mask_zero=True, name=\"embedding_a_decoder\")(\n", " dec_a_inp\n", ")\n", "dec_a, _, _ = LSTM(\n", " units, return_state=True, return_sequences=True, name=\"lstm_a_decoder\"\n", ")(dec_emb_a, initial_state=[state_h, state_c])\n", "a_out = TimeDistributed(\n", " Dense(len(vocab_a), activation=\"softmax\", name=\"dense_a_output\"), name=\"a_output\"\n", ")(dec_a)\n", "\n", "# ---------- CLASSIFIER : Question Type ----------\n", "type_out = Dense(len(vocab_typ), activation=\"softmax\", name=\"type_output\")(enc_out)\n", "\n", "model = Model(\n", " inputs=[inp_tok, inp_ner, inp_srl, dec_q_inp, dec_a_inp],\n", " outputs=[q_out, a_out, type_out],\n", ")\n", "\n", "model.summary()" ] }, { "cell_type": "code", "execution_count": 45, "id": "fece1ae9", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/30\n", "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 46ms/step - a_output_loss: 6.2977 - a_output_sparse_categorical_accuracy: 0.0491 - loss: 13.0486 - q_output_loss: 6.5023 - q_output_sparse_categorical_accuracy: 0.0484 - type_output_accuracy: 0.6917 - type_output_loss: 0.8213 - val_a_output_loss: 5.7298 - val_a_output_sparse_categorical_accuracy: 0.0833 - val_loss: 11.6245 - val_q_output_loss: 5.8765 - val_q_output_sparse_categorical_accuracy: 0.0949 - val_type_output_accuracy: 1.0000 - val_type_output_loss: 
0.0610\n", "Epoch 2/30\n", "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 20ms/step - a_output_loss: 5.4265 - a_output_sparse_categorical_accuracy: 0.0833 - loss: 11.3418 - q_output_loss: 5.7511 - q_output_sparse_categorical_accuracy: 0.0848 - type_output_accuracy: 0.8554 - type_output_loss: 0.5355 - val_a_output_loss: 4.9895 - val_a_output_sparse_categorical_accuracy: 0.0833 - val_loss: 9.5240 - val_q_output_loss: 4.4785 - val_q_output_sparse_categorical_accuracy: 0.0949 - val_type_output_accuracy: 1.0000 - val_type_output_loss: 0.1890\n", "Epoch 3/30\n", "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 22ms/step - a_output_loss: 4.5486 - a_output_sparse_categorical_accuracy: 0.0833 - loss: 9.3727 - q_output_loss: 4.6587 - q_output_sparse_categorical_accuracy: 0.0860 - type_output_accuracy: 0.8450 - type_output_loss: 0.5396 - val_a_output_loss: 4.6774 - val_a_output_sparse_categorical_accuracy: 0.0833 - val_loss: 8.0876 - val_q_output_loss: 3.3677 - val_q_output_sparse_categorical_accuracy: 0.1526 - val_type_output_accuracy: 1.0000 - val_type_output_loss: 0.1459\n", "Epoch 4/30\n", "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 24ms/step - a_output_loss: 4.2761 - a_output_sparse_categorical_accuracy: 0.0833 - loss: 8.5210 - q_output_loss: 4.0894 - q_output_sparse_categorical_accuracy: 0.1156 - type_output_accuracy: 0.8572 - type_output_loss: 0.5103 - val_a_output_loss: 4.6189 - val_a_output_sparse_categorical_accuracy: 0.0833 - val_loss: 7.7501 - val_q_output_loss: 3.0787 - val_q_output_sparse_categorical_accuracy: 0.1942 - val_type_output_accuracy: 1.0000 - val_type_output_loss: 0.1810\n", "Epoch 5/30\n", "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 22ms/step - a_output_loss: 4.2011 - a_output_sparse_categorical_accuracy: 0.0833 - loss: 8.4394 - q_output_loss: 4.0754 - q_output_sparse_categorical_accuracy: 0.1411 - type_output_accuracy: 0.8433 - type_output_loss: 0.5441 - val_a_output_loss: 4.5352 - val_a_output_sparse_categorical_accuracy: 0.0833 - val_loss: 7.5588 - val_q_output_loss: 2.9835 - val_q_output_sparse_categorical_accuracy: 0.2128 - val_type_output_accuracy: 1.0000 - val_type_output_loss: 0.1384\n", "Epoch 6/30\n", "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 20ms/step - a_output_loss: 4.1145 - a_output_sparse_categorical_accuracy: 0.0833 - loss: 8.2565 - q_output_loss: 3.9795 - q_output_sparse_categorical_accuracy: 0.1437 - type_output_accuracy: 0.8453 - type_output_loss: 0.5410 - val_a_output_loss: 4.4676 - val_a_output_sparse_categorical_accuracy: 0.0833 - val_loss: 7.4411 - val_q_output_loss: 2.9275 - val_q_output_sparse_categorical_accuracy: 0.2122 - val_type_output_accuracy: 1.0000 - val_type_output_loss: 0.1587\n", "Epoch 7/30\n", "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 26ms/step - a_output_loss: 3.9719 - a_output_sparse_categorical_accuracy: 0.0833 - loss: 7.9923 - q_output_loss: 3.8560 - q_output_sparse_categorical_accuracy: 0.1567 - type_output_accuracy: 0.8416 - type_output_loss: 0.5470 - val_a_output_loss: 4.4240 - val_a_output_sparse_categorical_accuracy: 0.0833 - val_loss: 7.2738 - val_q_output_loss: 2.8047 - val_q_output_sparse_categorical_accuracy: 0.2444 - val_type_output_accuracy: 1.0000 - 
val_type_output_loss: 0.1557\n", "Epoch 8/30\n", "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 24ms/step - a_output_loss: 3.8570 - a_output_sparse_categorical_accuracy: 0.0835 - loss: 7.7904 - q_output_loss: 3.7669 - q_output_sparse_categorical_accuracy: 0.1771 - type_output_accuracy: 0.8376 - type_output_loss: 0.5608 - val_a_output_loss: 4.3630 - val_a_output_sparse_categorical_accuracy: 0.0833 - val_loss: 7.0905 - val_q_output_loss: 2.6867 - val_q_output_sparse_categorical_accuracy: 0.2481 - val_type_output_accuracy: 1.0000 - val_type_output_loss: 0.1411\n", "Epoch 9/30\n", "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 24ms/step - a_output_loss: 3.8458 - a_output_sparse_categorical_accuracy: 0.0895 - loss: 7.6486 - q_output_loss: 3.6402 - q_output_sparse_categorical_accuracy: 0.1849 - type_output_accuracy: 0.8434 - type_output_loss: 0.5451 - val_a_output_loss: 4.3015 - val_a_output_sparse_categorical_accuracy: 0.0833 - val_loss: 6.9016 - val_q_output_loss: 2.5552 - val_q_output_sparse_categorical_accuracy: 0.2475 - val_type_output_accuracy: 1.0000 - val_type_output_loss: 0.1544\n", "Epoch 10/30\n", "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 19ms/step - a_output_loss: 3.7376 - a_output_sparse_categorical_accuracy: 0.0886 - loss: 7.4231 - q_output_loss: 3.5317 - q_output_sparse_categorical_accuracy: 0.1886 - type_output_accuracy: 0.8567 - type_output_loss: 0.5139 - val_a_output_loss: 4.2232 - val_a_output_sparse_categorical_accuracy: 0.0833 - val_loss: 6.7062 - val_q_output_loss: 2.4321 - val_q_output_sparse_categorical_accuracy: 0.2481 - val_type_output_accuracy: 1.0000 - val_type_output_loss: 0.1747\n", "Epoch 11/30\n", "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 18ms/step - a_output_loss: 3.6927 - a_output_sparse_categorical_accuracy: 0.0887 - loss: 7.3313 - q_output_loss: 3.4814 - q_output_sparse_categorical_accuracy: 0.1886 - type_output_accuracy: 0.8507 - type_output_loss: 0.5278 - val_a_output_loss: 4.1498 - val_a_output_sparse_categorical_accuracy: 0.0833 - val_loss: 6.5367 - val_q_output_loss: 2.3453 - val_q_output_sparse_categorical_accuracy: 0.2481 - val_type_output_accuracy: 1.0000 - val_type_output_loss: 0.1433\n", "Epoch 12/30\n", "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 19ms/step - a_output_loss: 3.6399 - a_output_sparse_categorical_accuracy: 0.0876 - loss: 7.1554 - q_output_loss: 3.3614 - q_output_sparse_categorical_accuracy: 0.1936 - type_output_accuracy: 0.8574 - type_output_loss: 0.5086 - val_a_output_loss: 4.0757 - val_a_output_sparse_categorical_accuracy: 0.0833 - val_loss: 6.3824 - val_q_output_loss: 2.2535 - val_q_output_sparse_categorical_accuracy: 0.2655 - val_type_output_accuracy: 1.0000 - val_type_output_loss: 0.1821\n", "Epoch 13/30\n", "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 20ms/step - a_output_loss: 3.5070 - a_output_sparse_categorical_accuracy: 0.0884 - loss: 6.9808 - q_output_loss: 3.3093 - q_output_sparse_categorical_accuracy: 0.2028 - type_output_accuracy: 0.8412 - type_output_loss: 0.5476 - val_a_output_loss: 4.0188 - val_a_output_sparse_categorical_accuracy: 0.0833 - val_loss: 6.2431 - val_q_output_loss: 2.1877 - val_q_output_sparse_categorical_accuracy: 0.2655 - val_type_output_accuracy: 
1.0000 - val_type_output_loss: 0.1261\n", "Epoch 14/30\n", "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 19ms/step - a_output_loss: 3.4826 - a_output_sparse_categorical_accuracy: 0.0878 - loss: 6.9060 - q_output_loss: 3.2682 - q_output_sparse_categorical_accuracy: 0.2067 - type_output_accuracy: 0.8555 - type_output_loss: 0.5168 - val_a_output_loss: 3.9746 - val_a_output_sparse_categorical_accuracy: 0.0833 - val_loss: 6.1472 - val_q_output_loss: 2.1195 - val_q_output_sparse_categorical_accuracy: 0.2655 - val_type_output_accuracy: 1.0000 - val_type_output_loss: 0.1812\n", "Epoch 15/30\n", "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 21ms/step - a_output_loss: 3.4678 - a_output_sparse_categorical_accuracy: 0.0893 - loss: 6.7741 - q_output_loss: 3.1490 - q_output_sparse_categorical_accuracy: 0.2098 - type_output_accuracy: 0.8540 - type_output_loss: 0.5190 - val_a_output_loss: 3.9369 - val_a_output_sparse_categorical_accuracy: 0.0833 - val_loss: 6.0406 - val_q_output_loss: 2.0556 - val_q_output_sparse_categorical_accuracy: 0.2655 - val_type_output_accuracy: 1.0000 - val_type_output_loss: 0.1640\n", "Epoch 16/30\n", "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 20ms/step - a_output_loss: 3.3488 - a_output_sparse_categorical_accuracy: 0.0884 - loss: 6.5965 - q_output_loss: 3.1009 - q_output_sparse_categorical_accuracy: 0.2117 - type_output_accuracy: 0.8654 - type_output_loss: 0.4899 - val_a_output_loss: 3.8978 - val_a_output_sparse_categorical_accuracy: 0.0833 - val_loss: 5.9461 - val_q_output_loss: 1.9961 - val_q_output_sparse_categorical_accuracy: 0.2829 - val_type_output_accuracy: 1.0000 - val_type_output_loss: 0.1780\n", "Epoch 17/30\n", "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 20ms/step - a_output_loss: 3.3488 - a_output_sparse_categorical_accuracy: 0.0915 - loss: 6.5416 - q_output_loss: 3.0414 - q_output_sparse_categorical_accuracy: 0.2198 - type_output_accuracy: 0.8612 - type_output_loss: 0.5010 - val_a_output_loss: 3.8658 - val_a_output_sparse_categorical_accuracy: 0.0840 - val_loss: 5.8560 - val_q_output_loss: 1.9388 - val_q_output_sparse_categorical_accuracy: 0.3009 - val_type_output_accuracy: 1.0000 - val_type_output_loss: 0.1751\n", "Epoch 18/30\n", "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 19ms/step - a_output_loss: 3.2773 - a_output_sparse_categorical_accuracy: 0.0932 - loss: 6.4819 - q_output_loss: 3.0429 - q_output_sparse_categorical_accuracy: 0.2364 - type_output_accuracy: 0.8417 - type_output_loss: 0.5477 - val_a_output_loss: 3.8428 - val_a_output_sparse_categorical_accuracy: 0.0867 - val_loss: 5.7718 - val_q_output_loss: 1.8904 - val_q_output_sparse_categorical_accuracy: 0.3009 - val_type_output_accuracy: 1.0000 - val_type_output_loss: 0.1327\n", "Epoch 19/30\n", "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 19ms/step - a_output_loss: 3.2296 - a_output_sparse_categorical_accuracy: 0.0988 - loss: 6.3270 - q_output_loss: 2.9284 - q_output_sparse_categorical_accuracy: 0.2408 - type_output_accuracy: 0.8385 - type_output_loss: 0.5597 - val_a_output_loss: 3.8217 - val_a_output_sparse_categorical_accuracy: 0.0867 - val_loss: 5.6938 - val_q_output_loss: 1.8297 - val_q_output_sparse_categorical_accuracy: 0.3009 - 
val_type_output_accuracy: 1.0000 - val_type_output_loss: 0.1456\n", "Epoch 20/30\n", "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 20ms/step - a_output_loss: 3.1702 - a_output_sparse_categorical_accuracy: 0.0973 - loss: 6.1988 - q_output_loss: 2.8774 - q_output_sparse_categorical_accuracy: 0.2407 - type_output_accuracy: 0.8577 - type_output_loss: 0.5078 - val_a_output_loss: 3.8066 - val_a_output_sparse_categorical_accuracy: 0.0867 - val_loss: 5.6503 - val_q_output_loss: 1.7954 - val_q_output_sparse_categorical_accuracy: 0.3009 - val_type_output_accuracy: 1.0000 - val_type_output_loss: 0.1651\n", "Epoch 21/30\n", "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 19ms/step - a_output_loss: 3.1889 - a_output_sparse_categorical_accuracy: 0.0984 - loss: 6.1064 - q_output_loss: 2.7800 - q_output_sparse_categorical_accuracy: 0.2473 - type_output_accuracy: 0.8713 - type_output_loss: 0.4700 - val_a_output_loss: 3.7826 - val_a_output_sparse_categorical_accuracy: 0.0860 - val_loss: 5.5638 - val_q_output_loss: 1.7345 - val_q_output_sparse_categorical_accuracy: 0.3009 - val_type_output_accuracy: 1.0000 - val_type_output_loss: 0.1591\n", "Epoch 22/30\n", "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 19ms/step - a_output_loss: 3.1265 - a_output_sparse_categorical_accuracy: 0.0988 - loss: 6.0308 - q_output_loss: 2.7526 - q_output_sparse_categorical_accuracy: 0.2497 - type_output_accuracy: 0.8545 - type_output_loss: 0.5097 - val_a_output_loss: 3.7677 - val_a_output_sparse_categorical_accuracy: 0.0860 - val_loss: 5.5064 - val_q_output_loss: 1.7013 - val_q_output_sparse_categorical_accuracy: 0.3009 - val_type_output_accuracy: 1.0000 - val_type_output_loss: 0.1283\n", "Epoch 23/30\n", "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 19ms/step - a_output_loss: 3.1438 - a_output_sparse_categorical_accuracy: 0.0992 - loss: 5.9302 - q_output_loss: 2.6496 - q_output_sparse_categorical_accuracy: 0.2625 - type_output_accuracy: 0.8705 - type_output_loss: 0.4574 - val_a_output_loss: 3.7538 - val_a_output_sparse_categorical_accuracy: 0.0860 - val_loss: 5.4513 - val_q_output_loss: 1.6560 - val_q_output_sparse_categorical_accuracy: 0.3009 - val_type_output_accuracy: 1.0000 - val_type_output_loss: 0.1424\n", "Epoch 24/30\n", "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 20ms/step - a_output_loss: 3.0850 - a_output_sparse_categorical_accuracy: 0.1027 - loss: 5.9167 - q_output_loss: 2.6764 - q_output_sparse_categorical_accuracy: 0.2587 - type_output_accuracy: 0.8494 - type_output_loss: 0.5071 - val_a_output_loss: 3.7385 - val_a_output_sparse_categorical_accuracy: 0.0860 - val_loss: 5.4008 - val_q_output_loss: 1.6286 - val_q_output_sparse_categorical_accuracy: 0.3009 - val_type_output_accuracy: 1.0000 - val_type_output_loss: 0.1169\n", "Epoch 25/30\n", "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 19ms/step - a_output_loss: 3.0401 - a_output_sparse_categorical_accuracy: 0.1000 - loss: 5.8064 - q_output_loss: 2.6111 - q_output_sparse_categorical_accuracy: 0.2587 - type_output_accuracy: 0.8506 - type_output_loss: 0.5047 - val_a_output_loss: 3.7129 - val_a_output_sparse_categorical_accuracy: 0.0880 - val_loss: 5.3261 - val_q_output_loss: 1.5804 - val_q_output_sparse_categorical_accuracy: 
0.3009 - val_type_output_accuracy: 1.0000 - val_type_output_loss: 0.1143\n", "Epoch 26/30\n", "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 19ms/step - a_output_loss: 3.0136 - a_output_sparse_categorical_accuracy: 0.1035 - loss: 5.7400 - q_output_loss: 2.5725 - q_output_sparse_categorical_accuracy: 0.2655 - type_output_accuracy: 0.8451 - type_output_loss: 0.5104 - val_a_output_loss: 3.7015 - val_a_output_sparse_categorical_accuracy: 0.0874 - val_loss: 5.2855 - val_q_output_loss: 1.5530 - val_q_output_sparse_categorical_accuracy: 0.3009 - val_type_output_accuracy: 1.0000 - val_type_output_loss: 0.1083\n", "Epoch 27/30\n", "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 19ms/step - a_output_loss: 3.0272 - a_output_sparse_categorical_accuracy: 0.1043 - loss: 5.7157 - q_output_loss: 2.5314 - q_output_sparse_categorical_accuracy: 0.2678 - type_output_accuracy: 0.8359 - type_output_loss: 0.5309 - val_a_output_loss: 3.6950 - val_a_output_sparse_categorical_accuracy: 0.0867 - val_loss: 5.2692 - val_q_output_loss: 1.5477 - val_q_output_sparse_categorical_accuracy: 0.3009 - val_type_output_accuracy: 1.0000 - val_type_output_loss: 0.0954\n", "Epoch 28/30\n", "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 19ms/step - a_output_loss: 2.9941 - a_output_sparse_categorical_accuracy: 0.1071 - loss: 5.5785 - q_output_loss: 2.4341 - q_output_sparse_categorical_accuracy: 0.2799 - type_output_accuracy: 0.8506 - type_output_loss: 0.5006 - val_a_output_loss: 3.6802 - val_a_output_sparse_categorical_accuracy: 0.0867 - val_loss: 5.2171 - val_q_output_loss: 1.5063 - val_q_output_sparse_categorical_accuracy: 0.3009 - val_type_output_accuracy: 1.0000 - val_type_output_loss: 0.1080\n", "Epoch 29/30\n", "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 19ms/step - a_output_loss: 2.9584 - a_output_sparse_categorical_accuracy: 0.1039 - loss: 5.5143 - q_output_loss: 2.3997 - q_output_sparse_categorical_accuracy: 0.2848 - type_output_accuracy: 0.8399 - type_output_loss: 0.5188 - val_a_output_loss: 3.6568 - val_a_output_sparse_categorical_accuracy: 0.0860 - val_loss: 5.1558 - val_q_output_loss: 1.4744 - val_q_output_sparse_categorical_accuracy: 0.3009 - val_type_output_accuracy: 1.0000 - val_type_output_loss: 0.0881\n", "Epoch 30/30\n", "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 19ms/step - a_output_loss: 2.9210 - a_output_sparse_categorical_accuracy: 0.1048 - loss: 5.4172 - q_output_loss: 2.3643 - q_output_sparse_categorical_accuracy: 0.2832 - type_output_accuracy: 0.8666 - type_output_loss: 0.4503 - val_a_output_loss: 3.6499 - val_a_output_sparse_categorical_accuracy: 0.0860 - val_loss: 5.1438 - val_q_output_loss: 1.4678 - val_q_output_sparse_categorical_accuracy: 0.3009 - val_type_output_accuracy: 1.0000 - val_type_output_loss: 0.0931\n" ] } ], "source": [ "losses = {\n", " \"q_output\": \"sparse_categorical_crossentropy\",\n", " \"a_output\": \"sparse_categorical_crossentropy\",\n", " \"type_output\": \"sparse_categorical_crossentropy\",\n", "}\n", "loss_weights = {\"q_output\": 1.0, \"a_output\": 1.0, \"type_output\": 0.3}\n", "\n", "model.compile(\n", " optimizer=\"adam\",\n", " loss=losses,\n", " loss_weights=loss_weights,\n", " metrics={\n", " \"q_output\": \"sparse_categorical_accuracy\",\n", " \"a_output\": 
\"sparse_categorical_accuracy\",\n", " \"type_output\": \"accuracy\",\n", " },\n", ")\n", "\n", "history = model.fit(\n", " [X_tok, X_ner, X_srl, dec_q_in, dec_a_in],\n", " [dec_q_out, dec_a_out, y_type],\n", " validation_split=0.1,\n", " epochs=30,\n", " batch_size=64,\n", " callbacks=[tf.keras.callbacks.EarlyStopping(patience=4, restore_best_weights=True)],\n", " verbose=1,\n", ")\n", "\n", "model.save(\"full_seq2seq.keras\")\n", "\n", "import json\n", "import pickle\n", "\n", "# def save_vocab(vocab, path):\n", "# with open(path, \"w\", encoding=\"utf-8\") as f:\n", "# json.dump(vocab, f, ensure_ascii=False, indent=2)\n", "\n", "# # Simpan semua vocab\n", "# save_vocab(vocab_tok, \"vocab_tok.json\")\n", "# save_vocab(vocab_ner, \"vocab_ner.json\")\n", "# save_vocab(vocab_srl, \"vocab_srl.json\")\n", "# save_vocab(vocab_q, \"vocab_q.json\")\n", "# save_vocab(vocab_a, \"vocab_a.json\")\n", "# save_vocab(vocab_typ, \"vocab_typ.json\")\n", "\n", "\n", "def save_vocab_pkl(vocab, path):\n", " with open(path, \"wb\") as f:\n", " pickle.dump(vocab, f)\n", "\n", "\n", "# Simpan semua vocab\n", "save_vocab_pkl(vocab_tok, \"vocab_tok.pkl\")\n", "save_vocab_pkl(vocab_ner, \"vocab_ner.pkl\")\n", "save_vocab_pkl(vocab_srl, \"vocab_srl.pkl\")\n", "save_vocab_pkl(vocab_q, \"vocab_q.pkl\")\n", "save_vocab_pkl(vocab_a, \"vocab_a.pkl\")\n", "save_vocab_pkl(vocab_typ, \"vocab_typ.pkl\")" ] }, { "cell_type": "code", "execution_count": 46, "id": "3355c0c7", "metadata": {}, "outputs": [], "source": [ "import tensorflow as tf\n", "import numpy as np\n", "import pickle\n", "from tensorflow.keras.models import load_model, Model\n", "from tensorflow.keras.layers import Input, Concatenate\n", "\n", "# === Load Model Utama ===\n", "model = load_model(\"full_seq2seq.keras\")\n", "\n", "\n", "# === Load Vocabulary dari .pkl ===\n", "def load_vocab(path):\n", " with open(path, \"rb\") as f:\n", " return pickle.load(f)\n", "\n", "\n", "vocab_tok = load_vocab(\"vocab_tok.pkl\")\n", "vocab_ner = load_vocab(\"vocab_ner.pkl\")\n", "vocab_srl = load_vocab(\"vocab_srl.pkl\")\n", "vocab_q = load_vocab(\"vocab_q.pkl\")\n", "vocab_a = load_vocab(\"vocab_a.pkl\")\n", "vocab_typ = load_vocab(\"vocab_typ.pkl\")\n", "\n", "inv_vocab_q = {v: k for k, v in vocab_q.items()}\n", "inv_vocab_a = {v: k for k, v in vocab_a.items()}\n", "\n", "# === Build Encoder Model ===\n", "MAX_SENT = model.input_shape[0][1] # Ambil shape dari model yang diload\n", "MAX_Q = model.input_shape[3][1] # Max length for question\n", "MAX_A = model.input_shape[4][1] # Max length for answer\n", "\n", "inp_tok_g = Input(shape=(MAX_SENT,), name=\"tok_in_g\")\n", "inp_ner_g = Input(shape=(MAX_SENT,), name=\"ner_in_g\")\n", "inp_srl_g = Input(shape=(MAX_SENT,), name=\"srl_in_g\")\n", "\n", "emb_tok = model.get_layer(\"embedding_tok\").call(inp_tok_g)\n", "emb_ner = model.get_layer(\"embedding_ner\").call(inp_ner_g)\n", "emb_srl = model.get_layer(\"embedding_srl\").call(inp_srl_g)\n", "\n", "enc_concat = Concatenate(name=\"concat_encoder\")([emb_tok, emb_ner, emb_srl])\n", "\n", "encoder_lstm = model.get_layer(\"encoder_lstm\")\n", "enc_out, state_h, state_c = encoder_lstm(enc_concat)\n", "\n", "# Create encoder model with full output including enc_out\n", "encoder_model = Model(\n", " inputs=[inp_tok_g, inp_ner_g, inp_srl_g],\n", " outputs=[enc_out, state_h, state_c],\n", " name=\"encoder_model\",\n", ")\n", "\n", "# === Build Decoder for Question ===\n", "dec_q_inp = Input(shape=(1,), name=\"dec_q_in\")\n", "dec_emb_q = 
model.get_layer(\"embedding_q_decoder\").call(dec_q_inp)\n", "\n", "state_h_dec = Input(shape=(units,), name=\"state_h_dec\")\n", "state_c_dec = Input(shape=(units,), name=\"state_c_dec\")\n", "\n", "lstm_decoder_q = model.get_layer(\"lstm_q_decoder\")\n", "\n", "dec_out_q, state_h_q, state_c_q = lstm_decoder_q(\n", " dec_emb_q, initial_state=[state_h_dec, state_c_dec]\n", ")\n", "\n", "q_time_dist_layer = model.get_layer(\"q_output\")\n", "dense_q = q_time_dist_layer.layer\n", "q_output = dense_q(dec_out_q)\n", "\n", "decoder_q = Model(\n", " inputs=[dec_q_inp, state_h_dec, state_c_dec],\n", " outputs=[q_output, state_h_q, state_c_q],\n", " name=\"decoder_question_model\",\n", ")\n", "\n", "# === Build Decoder for Answer ===\n", "dec_a_inp = Input(shape=(1,), name=\"dec_a_in\")\n", "dec_emb_a = model.get_layer(\"embedding_a_decoder\").call(dec_a_inp)\n", "\n", "state_h_a = Input(shape=(units,), name=\"state_h_a\")\n", "state_c_a = Input(shape=(units,), name=\"state_c_a\")\n", "\n", "lstm_decoder_a = model.get_layer(\"lstm_a_decoder\")\n", "\n", "dec_out_a, state_h_a_out, state_c_a_out = lstm_decoder_a(\n", " dec_emb_a, initial_state=[state_h_a, state_c_a]\n", ")\n", "\n", "a_time_dist_layer = model.get_layer(\"a_output\")\n", "dense_a = a_time_dist_layer.layer\n", "a_output = dense_a(dec_out_a)\n", "\n", "decoder_a = Model(\n", " inputs=[dec_a_inp, state_h_a, state_c_a],\n", " outputs=[a_output, state_h_a_out, state_c_a_out],\n", " name=\"decoder_answer_model\",\n", ")\n", "\n", "# === Build Classifier for Question Type ===\n", "type_dense = model.get_layer(\"type_output\")\n", "type_out = type_dense(enc_out)\n", "\n", "classifier_model = Model(\n", " inputs=[inp_tok_g, inp_ner_g, inp_srl_g], outputs=type_out, name=\"classifier_model\"\n", ")" ] }, { "cell_type": "code", "execution_count": 47, "id": "d406e6ff", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Generated Question: siapa yang lahir di ___\n", "Generated Answer : 4 juli 1927\n", "Question Type : isian\n" ] } ], "source": [ "def encode(seq, vmap):\n", " return [vmap.get(tok, vmap[\"\"]) for tok in seq]\n", "\n", "\n", "def encode_and_pad(seq, vmap, max_len=MAX_SENT):\n", " encoded = [vmap.get(tok, vmap[\"\"]) for tok in seq]\n", " # Pad with vocab[\"\"] to the right if sequence is shorter than max_len\n", " padded = encoded + [vmap[\"\"]] * (max_len - len(encoded))\n", " return padded[:max_len] # Ensure it doesn't exceed max_len\n", "\n", "\n", "def greedy_decode(tokens, ner, srl, max_q=20, max_a=10):\n", " # --- encode encoder inputs -------------------------------------------\n", " if isinstance(tokens, np.ndarray):\n", " enc_tok = tokens\n", " enc_ner = ner\n", " enc_srl = srl\n", " else:\n", " enc_tok = np.array([encode_and_pad(tokens, vocab_tok, MAX_SENT)])\n", " enc_ner = np.array([encode_and_pad(ner, vocab_ner, MAX_SENT)])\n", " enc_srl = np.array([encode_and_pad(srl, vocab_srl, MAX_SENT)])\n", "\n", " # --- Get encoder outputs ---\n", " enc_out, h, c = encoder_model.predict([enc_tok, enc_ner, enc_srl], verbose=0)\n", "\n", " # QUESTION Decoding\n", " tgt = np.array([[vocab_q[\"\"]]])\n", " question_ids = []\n", " for _ in range(max_q):\n", " logits, h, c = decoder_q.predict([tgt, h, c], verbose=0)\n", " next_id = int(logits[0, 0].argmax()) # Get the predicted token ID\n", " if next_id == vocab_q[\"\"]:\n", " break\n", " question_ids.append(next_id)\n", " tgt = np.array([[next_id]]) # Feed the predicted token back as input\n", "\n", " # ANSWER Decoding - use encoder outputs again for 
fresh state\n", "    _, h, c = encoder_model.predict([enc_tok, enc_ner, enc_srl], verbose=0)\n", "    tgt = np.array([[vocab_a[\"<sos>\"]]])\n", "    answer_ids = []\n", "    for _ in range(max_a):\n", "        logits, h, c = decoder_a.predict([tgt, h, c], verbose=0)\n", "        next_id = int(logits[0, 0].argmax())\n", "        if next_id == vocab_a[\"<eos>\"]:\n", "            break\n", "        answer_ids.append(next_id)\n", "        tgt = np.array([[next_id]])\n", "\n", "    # Question Type\n", "    qtype_logits = classifier_model.predict([enc_tok, enc_ner, enc_srl], verbose=0)\n", "    qtype_id = int(qtype_logits.argmax())\n", "\n", "    # Final output\n", "    question = [inv_vocab_q.get(i, \"<unk>\") for i in question_ids]\n", "    answer = [inv_vocab_a.get(i, \"<unk>\") for i in answer_ids]\n", "    q_type = [k for k, v in vocab_typ.items() if v == qtype_id][0]\n", "\n", "    return question, answer, q_type\n", "\n", "\n", "def test_model():\n", "    test_data = {\n", "        \"tokens\": [\n", "            \"joko\",\n", "            \"opik\",\n", "            \"widodo\",\n", "            \"lahir\",\n", "            \"pada\",\n", "            \"27\",\n", "            \"maret\",\n", "            \"1992\",\n", "            \"di\",\n", "            \"solo\",\n", "        ],\n", "        \"ner\": [\n", "            \"B-PER\",\n", "            \"I-PER\",\n", "            \"I-PER\",\n", "            \"V\",\n", "            \"O\",\n", "            \"B-DATE\",\n", "            \"I-DATE\",\n", "            \"I-DATE\",\n", "            \"O\",\n", "            \"B-LOC\",\n", "        ],\n", "        \"srl\": [\n", "            \"ARG0\",\n", "            \"ARG0\",\n", "            \"ARG0\",\n", "            \"V\",\n", "            \"O\",\n", "            \"ARGM-TMP\",\n", "            \"ARGM-TMP\",\n", "            \"ARGM-TMP\",\n", "            \"O\",\n", "            \"ARGM-LOC\",\n", "        ],\n", "    }\n", "    # tokens = [\n", "    #     \"soekarno\",\n", "    #     \"membacakan\",\n", "    #     \"teks\",\n", "    #     \"proklamasi\",\n", "    #     \"pada\",\n", "    #     \"17\",\n", "    #     \"agustus\",\n", "    #     \"1945\",\n", "    # ]\n", "    # ner_tags = [\"B-PER\", \"O\", \"O\", \"O\", \"O\", \"B-DATE\", \"I-DATE\", \"I-DATE\"]\n", "    # srl_tags = [\"ARG0\", \"V\", \"ARG1\", \"ARG1\", \"O\", \"ARGM-TMP\", \"ARGM-TMP\", \"ARGM-TMP\"]\n", "\n", "    question, answer, q_type = greedy_decode(\n", "        test_data[\"tokens\"], test_data[\"ner\"], test_data[\"srl\"]\n", "    )\n", "    print(f\"Generated Question: {' '.join(question)}\")\n", "    print(f\"Generated Answer : {' '.join(answer)}\")\n", "    print(f\"Question Type : {q_type}\")\n", "\n", "\n", "test_model()" ] },
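{ "cell_type": "markdown", "id": "qg-spot-md", "metadata": {}, "source": [ "Before the aggregate BLEU/ROUGE scores, an illustrative spot check: decode a few random training rows and print the generated question next to the gold one." ] }, { "cell_type": "code", "execution_count": null, "id": "qg-spot-demo", "metadata": {}, "outputs": [], "source": [ "import random\n", "\n", "# Decode a few random training rows and compare with the gold question.\n", "for idx in random.sample(range(len(X_tok)), k=3):\n", "    gold_ids = [\n", "        i for i in dec_q_out[idx] if i not in (vocab_q[\"<pad>\"], vocab_q[\"<eos>\"])\n", "    ]\n", "    gold = \" \".join(inv_vocab_q[i] for i in gold_ids)\n", "    pred, _, _ = greedy_decode(\n", "        X_tok[idx : idx + 1], X_ner[idx : idx + 1], X_srl[idx : idx + 1]\n", "    )\n", "    print(f\"gold: {gold}\")\n", "    print(f\"pred: {' '.join(pred)}\")\n", "    print()" ] },
{ "cell_type": "code", "execution_count": 48, "id": "5adde3c3", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "BLEU : 10.75%\n", "ROUGE1: 27.63% | ROUGE-L: 27.63%\n" ] } ], "source": [ "from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction\n", "from rouge_score import rouge_scorer\n", "\n", "smoothie = SmoothingFunction().method4\n", "scorer = rouge_scorer.RougeScorer([\"rouge1\", \"rougeL\"], use_stemmer=True)\n", "\n", "\n", "# Helper to strip special ids (<pad>/<eos>) before scoring\n", "def strip_special(ids, vocab):\n", "    pad = vocab[\"<pad>\"] if \"<pad>\" in vocab else None\n", "    eos = vocab[\"<eos>\"]\n", "    return [i for i in ids if i not in (pad, eos)]\n", "\n", "\n", "def ids_to_text(ids, inv_vocab):\n", "    return \" \".join(inv_vocab[i] for i in ids)\n", "\n", "\n", "# ---- evaluation over a set of indices ----\n", "import random\n", "\n", "\n", "def evaluate(indices=None):\n", "    if indices is None:\n", "        indices = random.sample(range(len(X_tok)), k=min(100, len(X_tok)))\n", "\n", "    bleu_scores, rou1, rouL = [], [], []\n", "    for idx in indices:\n", "        # Ground truth\n", "        gt_q = strip_special(dec_q_out[idx], vocab_q)\n", "        gt_a = strip_special(dec_a_out[idx], vocab_a)\n", "        # Prediction\n", "        q_pred, a_pred, _ = greedy_decode(\n", "            X_tok[idx : idx + 1], X_ner[idx : idx + 1], X_srl[idx : idx + 1]\n", "        )\n", "\n", "        # BLEU on question 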
tokens\n", " bleu_scores.append(\n", " sentence_bleu(\n", " [[inv_vocab_q[i] for i in gt_q]], q_pred, smoothing_function=smoothie\n", " )\n", " )\n", " # ROUGE on question strings\n", " r = scorer.score(ids_to_text(gt_q, inv_vocab_q), \" \".join(q_pred))\n", " rou1.append(r[\"rouge1\"].fmeasure)\n", " rouL.append(r[\"rougeL\"].fmeasure)\n", "\n", " print(f\"BLEU : {np.mean(bleu_scores) * 100:.2f}%\")\n", " print(f\"ROUGE1: {np.mean(rou1) * 100:.2f}% | ROUGE-L: {np.mean(rouL) * 100:.2f}%\")\n", "\n", "\n", "evaluate()" ] } ], "metadata": { "kernelspec": { "display_name": "myenv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.16" } }, "nbformat": 4, "nbformat_minor": 5 }