867 lines
129 KiB
Plaintext
867 lines
129 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 232,
|
|
"id": "02cbdb19",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import numpy as np\n",
|
|
"import json\n",
|
|
"import random\n",
|
|
"import tensorflow as tf\n",
|
|
"from tensorflow.keras.preprocessing.text import Tokenizer\n",
|
|
"from tensorflow.keras.preprocessing.sequence import pad_sequences\n",
|
|
"from tensorflow.keras.models import Model, load_model\n",
|
|
"from tensorflow.keras.layers import (\n",
|
|
" Input,\n",
|
|
" LSTM,\n",
|
|
" Dense,\n",
|
|
" Embedding,\n",
|
|
" Bidirectional,\n",
|
|
" Concatenate,\n",
|
|
" Dropout,\n",
|
|
")\n",
|
|
"from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping\n",
|
|
"from sklearn.model_selection import train_test_split\n",
|
|
"import matplotlib.pyplot as plt\n",
|
|
"import re\n",
|
|
"from rouge_score import rouge_scorer\n",
|
|
"from nltk.translate.bleu_score import sentence_bleu\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 233,
|
|
"id": "f9c0af74",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"total context 885\n",
|
|
"total question 1547\n"
|
|
]
|
|
}
|
|
],
"source": [
"# Load data (explicit UTF-8 so the JSON decodes identically on every platform)\n",
"with open(\"../dataset/valid_data.json\", \"r\", encoding=\"utf-8\") as f:\n",
"    data = json.load(f)\n",
"\n",
"\n",
"# Preprocessing function\n",
"def preprocess_text(text):\n",
"    \"\"\"Basic text normalisation: lowercase and collapse whitespace.\n",
"\n",
"    Every run of whitespace is replaced by a single space and the result is\n",
"    stripped of leading/trailing whitespace.\n",
"    \"\"\"\n",
"    text = text.lower()\n",
"    text = re.sub(r\"\\s+\", \" \", text).strip()\n",
"    return text\n",
"\n",
"\n",
"# Prepare data for the question-prediction model\n",
"def prepare_question_prediction_data(data):\n",
"    \"\"\"Flatten the dataset into one row per question.\n",
"\n",
"    Each item must provide the keys ``context``, ``tokens``, ``ner``, ``srl``\n",
"    and ``qas``; each entry of ``qas`` must provide ``question`` and ``type``.\n",
"    Item-level fields are repeated for every question of that item. Answers\n",
"    are deliberately not used as model input.\n",
"\n",
"    Returns:\n",
"        ``(contexts, tokens_list, ner_list, srl_list, questions, q_types)``,\n",
"        six lists with one element per question.\n",
"    \"\"\"\n",
"    contexts = []\n",
"    tokens_list = []\n",
"    ner_list = []\n",
"    srl_list = []\n",
"    questions = []\n",
"    q_types = []\n",
"\n",
"    for item in data:\n",
"        # Item-level preprocessing does not depend on the question, so do it\n",
"        # once per item instead of once per question.\n",
"        context = preprocess_text(item[\"context\"])\n",
"        tokens = [preprocess_text(token) for token in item[\"tokens\"]]\n",
"\n",
"        # NOTE: questions that are empty strings are kept; skip them here if\n",
"        # they should not be used as training targets.\n",
"        for qa in item[\"qas\"]:\n",
"            contexts.append(context)\n",
"            tokens_list.append(list(tokens))\n",
"            ner_list.append(item[\"ner\"])\n",
"            srl_list.append(item[\"srl\"])\n",
"            questions.append(preprocess_text(qa[\"question\"]))\n",
"            q_types.append(qa[\"type\"])\n",
"\n",
"    print(\"total context \", len(data))\n",
"    print(\"total question \", len(questions))\n",
"    return contexts, tokens_list, ner_list, srl_list, questions, q_types\n",
"\n",
"\n",
"contexts, tokens_list, ner_list, srl_list, questions, q_types = (\n",
"    prepare_question_prediction_data(data)\n",
")"
]
|
|
},
|
|
{
"cell_type": "code",
"execution_count": 234,
"id": "952f71da",
"metadata": {},
"outputs": [],
"source": [
"# Tokenizer for free text (context, question, tokens)\n",
"max_vocab_size = 10000\n",
"tokenizer = Tokenizer(num_words=max_vocab_size, oov_token=\"<OOV>\")\n",
"# Token lists are fitted as lists, not joined strings: Keras' Tokenizer takes\n",
"# list elements as-is, so each pre-split token (e.g. one containing '-')\n",
"# stays a single vocabulary entry instead of being re-split or filtered out.\n",
"all_texts = contexts + questions + [list(item) for item in tokens_list]\n",
"tokenizer.fit_on_texts(all_texts)\n",
"# With num_words set, texts_to_sequences never emits an id >= num_words.\n",
"vocab_size = min(len(tokenizer.word_index) + 1, max_vocab_size)\n",
"\n",
"# NER / SRL tag tokenizers. filters=\"\" stops the default punctuation filter\n",
"# from splitting tags that contain characters such as '-'.\n",
"ner_tokenizer = Tokenizer(oov_token=\"<OOV>\", filters=\"\")\n",
"ner_tokenizer.fit_on_texts([list(ner) for ner in ner_list])\n",
"ner_vocab_size = len(ner_tokenizer.word_index) + 1\n",
"\n",
"srl_tokenizer = Tokenizer(oov_token=\"<OOV>\", filters=\"\")\n",
"srl_tokenizer.fit_on_texts([list(srl) for srl in srl_list])\n",
"srl_vocab_size = len(srl_tokenizer.word_index) + 1\n",
"\n",
"# Question-type vocabulary: each type string is one entry, matched exactly\n",
"# (no lowercasing/filtering), so the word_index lookup below cannot miss.\n",
"q_type_tokenizer = Tokenizer(filters=\"\", lower=False)\n",
"q_type_tokenizer.fit_on_texts([[q_type] for q_type in q_types])\n",
"q_type_vocab_size = len(q_type_tokenizer.word_index) + 1\n",
"\n",
"\n",
"# Convert tokens, NER and SRL to id sequences\n",
"def tokens_to_sequences(tokens, ner, srl):\n",
"    \"\"\"Convert per-example token / NER / SRL lists to id sequences.\n",
"\n",
"    Inputs are passed to the tokenizers as lists, so every element maps to\n",
"    exactly one id (unknown elements map to the <OOV> id). The output\n",
"    sequences of an example therefore have the same lengths as its input\n",
"    lists, keeping tokens aligned with their NER and SRL tags.\n",
"    \"\"\"\n",
"    token_seqs = tokenizer.texts_to_sequences([list(t) for t in tokens])\n",
"    ner_seqs = ner_tokenizer.texts_to_sequences([list(n) for n in ner])\n",
"    srl_seqs = srl_tokenizer.texts_to_sequences([list(s) for s in srl])\n",
"    return token_seqs, ner_seqs, srl_seqs\n",
"\n",
"\n",
"# Sequences\n",
"context_seqs = tokenizer.texts_to_sequences(contexts)\n",
"question_seqs = tokenizer.texts_to_sequences(questions)\n",
"token_seqs, ner_seqs, srl_seqs = tokens_to_sequences(tokens_list, ner_list, srl_list)\n",
"\n",
"# Tags are only meaningful if they line up with tokens; surface data problems.\n",
"n_misaligned = sum(\n",
"    not (len(t) == len(n) == len(s))\n",
"    for t, n, s in zip(token_seqs, ner_seqs, srl_seqs)\n",
")\n",
"if n_misaligned:\n",
"    print(f\"warning: {n_misaligned} examples have token/NER/SRL length mismatch\")\n",
"\n",
"# Maximum lengths used for padding\n",
"max_context_len = max(len(seq) for seq in context_seqs)\n",
"max_question_len = max(len(seq) for seq in question_seqs)\n",
"# Include the tag sequences so none of them is truncated by pad_sequences\n",
"# (whose default truncation drops elements from the front).\n",
"max_token_len = max(\n",
"    len(seq) for seqs in (token_seqs, ner_seqs, srl_seqs) for seq in seqs\n",
")\n",
"\n",
"\n",
"# Pad all sequences to a common length per input\n",
"def pad_all_sequences(context_seqs, token_seqs, ner_seqs, srl_seqs, question_seqs):\n",
"    \"\"\"Post-pad every sequence group to its notebook-level maximum length.\n",
"\n",
"    Uses max_context_len, max_token_len (shared by tokens, NER and SRL) and\n",
"    max_question_len defined above.\n",
"    \"\"\"\n",
"    context_padded = pad_sequences(context_seqs, maxlen=max_context_len, padding=\"post\")\n",
"    token_padded = pad_sequences(token_seqs, maxlen=max_token_len, padding=\"post\")\n",
"    ner_padded = pad_sequences(ner_seqs, maxlen=max_token_len, padding=\"post\")\n",
"    srl_padded = pad_sequences(srl_seqs, maxlen=max_token_len, padding=\"post\")\n",
"    question_padded = pad_sequences(\n",
"        question_seqs, maxlen=max_question_len, padding=\"post\"\n",
"    )\n",
"    return (\n",
"        context_padded,\n",
"        token_padded,\n",
"        ner_padded,\n",
"        srl_padded,\n",
"        question_padded,\n",
"    )\n",
"\n",
"\n",
"# Encode question types (every type was seen while fitting, so none maps to 0)\n",
"q_type_indices = np.array(\n",
"    [q_type_tokenizer.word_index.get(q_type, 0) for q_type in q_types]\n",
")\n",
"\n",
"# One-hot encode question types\n",
"q_type_categorical = tf.keras.utils.to_categorical(\n",
"    q_type_indices, num_classes=q_type_vocab_size\n",
")\n",
"\n",
"# Pad sequences\n",
"context_padded, token_padded, ner_padded, srl_padded, question_padded = (\n",
"    pad_all_sequences(context_seqs, token_seqs, ner_seqs, srl_seqs, question_seqs)\n",
")"
]
},
|
|
{
"cell_type": "code",
"execution_count": 235,
"id": "37ffc0e5",
"metadata": {},
"outputs": [],
"source": [
"# Train/test split grouped by context. Several questions share one context;\n",
"# splitting individual questions would place the same context in both sets\n",
"# and leak it into the test evaluation.\n",
"indices = list(range(len(context_padded)))\n",
"unique_contexts = sorted(set(contexts))\n",
"_, test_contexts = train_test_split(unique_contexts, test_size=0.1, random_state=42)\n",
"test_contexts = set(test_contexts)\n",
"train_indices = [i for i in indices if contexts[i] not in test_contexts]\n",
"test_indices = [i for i in indices if contexts[i] in test_contexts]\n",
"\n",
"\n",
"def get_subset(data, indices):\n",
"    \"\"\"Return the rows of `data` selected by `indices` as a NumPy array.\"\"\"\n",
"    return np.asarray(data)[np.asarray(indices, dtype=int)]\n",
"\n",
"\n",
"# Train data\n",
"train_context = get_subset(context_padded, train_indices)\n",
"train_token = get_subset(token_padded, train_indices)\n",
"train_ner = get_subset(ner_padded, train_indices)\n",
"train_srl = get_subset(srl_padded, train_indices)\n",
"train_q_type = get_subset(q_type_categorical, train_indices)\n",
"train_question = get_subset(question_padded, train_indices)\n",
"\n",
"# Test data\n",
"test_context = get_subset(context_padded, test_indices)\n",
"test_token = get_subset(token_padded, test_indices)\n",
"test_ner = get_subset(ner_padded, test_indices)\n",
"test_srl = get_subset(srl_padded, test_indices)\n",
"test_q_type = get_subset(q_type_categorical, test_indices)\n",
"test_question = get_subset(question_padded, test_indices)"
]
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 236,
|
|
"id": "df580682",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Model: \"functional_28\"</span>\n",
|
|
"</pre>\n"
|
|
],
|
|
"text/plain": [
|
|
"\u001b[1mModel: \"functional_28\"\u001b[0m\n"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓\n",
|
|
"┃<span style=\"font-weight: bold\"> Layer (type) </span>┃<span style=\"font-weight: bold\"> Output Shape </span>┃<span style=\"font-weight: bold\"> Param # </span>┃<span style=\"font-weight: bold\"> Connected to </span>┃\n",
|
|
"┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩\n",
|
|
"│ context_input │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">38</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ - │\n",
|
|
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">InputLayer</span>) │ │ │ │\n",
|
|
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
|
"│ token_input │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">38</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ - │\n",
|
|
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">InputLayer</span>) │ │ │ │\n",
|
|
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
|
"│ ner_input │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">38</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ - │\n",
|
|
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">InputLayer</span>) │ │ │ │\n",
|
|
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
|
"│ srl_input │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">38</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ - │\n",
|
|
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">InputLayer</span>) │ │ │ │\n",
|
|
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
|
"│ text_embedding │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">38</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">100</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">223,000</span> │ context_input[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>]… │\n",
|
|
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Embedding</span>) │ │ │ token_input[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
|
|
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
|
"│ ner_embedding │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">38</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">50</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">1,700</span> │ ner_input[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
|
|
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Embedding</span>) │ │ │ │\n",
|
|
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
|
"│ srl_embedding │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">38</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">50</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">1,500</span> │ srl_input[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
|
|
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Embedding</span>) │ │ │ │\n",
|
|
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
|
"│ bidirectional_56 │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">38</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">256</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">234,496</span> │ text_embedding[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>… │\n",
|
|
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Bidirectional</span>) │ │ │ │\n",
|
|
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
|
"│ token_features │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">38</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">200</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ text_embedding[<span style=\"color: #00af00; text-decoration-color: #00af00\">1</span>… │\n",
|
|
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Concatenate</span>) │ │ │ ner_embedding[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>]… │\n",
|
|
"│ │ │ │ srl_embedding[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>]… │\n",
|
|
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
|
"│ context_attention │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">38</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">256</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ bidirectional_56… │\n",
|
|
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Attention</span>) │ │ │ bidirectional_56… │\n",
|
|
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
|
"│ bidirectional_57 │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">38</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">256</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">336,896</span> │ token_features[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>… │\n",
|
|
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Bidirectional</span>) │ │ │ │\n",
|
|
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
|
"│ context_att_pool │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">256</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ context_attentio… │\n",
|
|
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">GlobalMaxPooling1…</span> │ │ │ │\n",
|
|
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
|
"│ token_pool │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">256</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ bidirectional_57… │\n",
|
|
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">GlobalMaxPooling1…</span> │ │ │ │\n",
|
|
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
|
"│ q_type_input │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">5</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ - │\n",
|
|
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">InputLayer</span>) │ │ │ │\n",
|
|
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
|
"│ all_features │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">517</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ context_att_pool… │\n",
|
|
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Concatenate</span>) │ │ │ token_pool[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>], │\n",
|
|
"│ │ │ │ q_type_input[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">…</span> │\n",
|
|
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
|
"│ dense_1 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">512</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">265,216</span> │ all_features[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">…</span> │\n",
|
|
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
|
"│ dropout_56 │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">512</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ dense_1[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
|
|
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dropout</span>) │ │ │ │\n",
|
|
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
|
"│ dense_2 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">256</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">131,328</span> │ dropout_56[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
|
|
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
|
"│ dropout_57 │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">256</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ dense_2[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
|
|
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dropout</span>) │ │ │ │\n",
|
|
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
|
"│ decoder_input │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">256</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">65,792</span> │ dropout_57[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
|
|
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ │ │ │\n",
|
|
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
|
"│ repeat_vector_28 │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">15</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">256</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ decoder_input[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>]… │\n",
|
|
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">RepeatVector</span>) │ │ │ │\n",
|
|
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
|
"│ decoder_lstm (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">LSTM</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">15</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">256</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">525,312</span> │ repeat_vector_28… │\n",
|
|
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
|
"│ time_distributed_28 │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">15</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">2230</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">573,110</span> │ decoder_lstm[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">…</span> │\n",
|
|
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">TimeDistributed</span>) │ │ │ │\n",
|
|
"└─────────────────────┴───────────────────┴────────────┴───────────────────┘\n",
|
|
"</pre>\n"
|
|
],
|
|
"text/plain": [
|
|
"┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓\n",
|
|
"┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mConnected to \u001b[0m\u001b[1m \u001b[0m┃\n",
|
|
"┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩\n",
|
|
"│ context_input │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m38\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n",
|
|
"│ (\u001b[38;5;33mInputLayer\u001b[0m) │ │ │ │\n",
|
|
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
|
"│ token_input │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m38\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n",
|
|
"│ (\u001b[38;5;33mInputLayer\u001b[0m) │ │ │ │\n",
|
|
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
|
"│ ner_input │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m38\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n",
|
|
"│ (\u001b[38;5;33mInputLayer\u001b[0m) │ │ │ │\n",
|
|
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
|
"│ srl_input │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m38\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n",
|
|
"│ (\u001b[38;5;33mInputLayer\u001b[0m) │ │ │ │\n",
|
|
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
|
"│ text_embedding │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m38\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m223,000\u001b[0m │ context_input[\u001b[38;5;34m0\u001b[0m]… │\n",
|
|
"│ (\u001b[38;5;33mEmbedding\u001b[0m) │ │ │ token_input[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
|
|
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
|
"│ ner_embedding │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m38\u001b[0m, \u001b[38;5;34m50\u001b[0m) │ \u001b[38;5;34m1,700\u001b[0m │ ner_input[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
|
|
"│ (\u001b[38;5;33mEmbedding\u001b[0m) │ │ │ │\n",
|
|
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
|
"│ srl_embedding │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m38\u001b[0m, \u001b[38;5;34m50\u001b[0m) │ \u001b[38;5;34m1,500\u001b[0m │ srl_input[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
|
|
"│ (\u001b[38;5;33mEmbedding\u001b[0m) │ │ │ │\n",
|
|
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
|
"│ bidirectional_56 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m38\u001b[0m, \u001b[38;5;34m256\u001b[0m) │ \u001b[38;5;34m234,496\u001b[0m │ text_embedding[\u001b[38;5;34m0\u001b[0m… │\n",
|
|
"│ (\u001b[38;5;33mBidirectional\u001b[0m) │ │ │ │\n",
|
|
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
|
"│ token_features │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m38\u001b[0m, \u001b[38;5;34m200\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ text_embedding[\u001b[38;5;34m1\u001b[0m… │\n",
|
|
"│ (\u001b[38;5;33mConcatenate\u001b[0m) │ │ │ ner_embedding[\u001b[38;5;34m0\u001b[0m]… │\n",
|
|
"│ │ │ │ srl_embedding[\u001b[38;5;34m0\u001b[0m]… │\n",
|
|
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
|
"│ context_attention │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m38\u001b[0m, \u001b[38;5;34m256\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ bidirectional_56… │\n",
|
|
"│ (\u001b[38;5;33mAttention\u001b[0m) │ │ │ bidirectional_56… │\n",
|
|
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
|
"│ bidirectional_57 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m38\u001b[0m, \u001b[38;5;34m256\u001b[0m) │ \u001b[38;5;34m336,896\u001b[0m │ token_features[\u001b[38;5;34m0\u001b[0m… │\n",
|
|
"│ (\u001b[38;5;33mBidirectional\u001b[0m) │ │ │ │\n",
|
|
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
|
"│ context_att_pool │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ context_attentio… │\n",
|
|
"│ (\u001b[38;5;33mGlobalMaxPooling1…\u001b[0m │ │ │ │\n",
|
|
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
|
"│ token_pool │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ bidirectional_57… │\n",
|
|
"│ (\u001b[38;5;33mGlobalMaxPooling1…\u001b[0m │ │ │ │\n",
|
|
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
|
"│ q_type_input │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m5\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n",
|
|
"│ (\u001b[38;5;33mInputLayer\u001b[0m) │ │ │ │\n",
|
|
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
|
"│ all_features │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m517\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ context_att_pool… │\n",
|
|
"│ (\u001b[38;5;33mConcatenate\u001b[0m) │ │ │ token_pool[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n",
|
|
"│ │ │ │ q_type_input[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m…\u001b[0m │\n",
|
|
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
|
"│ dense_1 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m265,216\u001b[0m │ all_features[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m…\u001b[0m │\n",
|
|
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
|
"│ dropout_56 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ dense_1[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
|
|
"│ (\u001b[38;5;33mDropout\u001b[0m) │ │ │ │\n",
|
|
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
|
"│ dense_2 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256\u001b[0m) │ \u001b[38;5;34m131,328\u001b[0m │ dropout_56[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
|
|
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
|
"│ dropout_57 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ dense_2[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
|
|
"│ (\u001b[38;5;33mDropout\u001b[0m) │ │ │ │\n",
|
|
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
|
"│ decoder_input │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256\u001b[0m) │ \u001b[38;5;34m65,792\u001b[0m │ dropout_57[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
|
|
"│ (\u001b[38;5;33mDense\u001b[0m) │ │ │ │\n",
|
|
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
|
"│ repeat_vector_28 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m15\u001b[0m, \u001b[38;5;34m256\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ decoder_input[\u001b[38;5;34m0\u001b[0m]… │\n",
|
|
"│ (\u001b[38;5;33mRepeatVector\u001b[0m) │ │ │ │\n",
|
|
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
|
"│ decoder_lstm (\u001b[38;5;33mLSTM\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m15\u001b[0m, \u001b[38;5;34m256\u001b[0m) │ \u001b[38;5;34m525,312\u001b[0m │ repeat_vector_28… │\n",
|
|
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
|
"│ time_distributed_28 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m15\u001b[0m, \u001b[38;5;34m2230\u001b[0m) │ \u001b[38;5;34m573,110\u001b[0m │ decoder_lstm[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m…\u001b[0m │\n",
|
|
"│ (\u001b[38;5;33mTimeDistributed\u001b[0m) │ │ │ │\n",
|
|
"└─────────────────────┴───────────────────┴────────────┴───────────────────┘\n"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Total params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">2,358,350</span> (9.00 MB)\n",
|
|
"</pre>\n"
|
|
],
|
|
"text/plain": [
|
|
"\u001b[1m Total params: \u001b[0m\u001b[38;5;34m2,358,350\u001b[0m (9.00 MB)\n"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">2,358,350</span> (9.00 MB)\n",
|
|
"</pre>\n"
|
|
],
|
|
"text/plain": [
|
|
"\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m2,358,350\u001b[0m (9.00 MB)\n"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Non-trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> (0.00 B)\n",
|
|
"</pre>\n"
|
|
],
|
|
"text/plain": [
|
|
"\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"embedding_dim = 100\n",
"lstm_units = 128\n",
"ner_embedding_dim = 50\n",
"srl_embedding_dim = 50\n",
"dropout_rate = 0.3\n",
"\n",
"\n",
"def create_question_prediction_model():\n",
"    \"\"\"Build and compile the question-generation model.\n",
"\n",
"    Inputs (their names must match the arrays passed to ``model.fit``):\n",
"      * context_input - padded word ids of the context, length max_context_len\n",
"      * token_input   - padded token ids, length max_token_len\n",
"      * ner_input     - NER tag ids, one per position of token_input\n",
"      * srl_input     - SRL tag ids, one per position of token_input\n",
"      * q_type_input  - question-type vector of size q_type_vocab_size\n",
"                        (presumably one-hot - confirm against the data prep)\n",
"\n",
"    Output: softmax over the word vocabulary at every question position,\n",
"    shape (batch, max_question_len, vocab_size).\n",
"\n",
"    Reads notebook globals defined in earlier cells: vocab_size, ner_vocab_size,\n",
"    srl_vocab_size, q_type_vocab_size, max_context_len, max_token_len and\n",
"    max_question_len, plus the hyperparameters defined above.\n",
"    \"\"\"\n",
"    # Input layers\n",
"    context_input = Input(shape=(max_context_len,), name=\"context_input\")\n",
"    token_input = Input(shape=(max_token_len,), name=\"token_input\")\n",
"    ner_input = Input(shape=(max_token_len,), name=\"ner_input\")\n",
"    srl_input = Input(shape=(max_token_len,), name=\"srl_input\")\n",
"    q_type_input = Input(shape=(q_type_vocab_size,), name=\"q_type_input\")\n",
"\n",
"    # One word embedding shared by the context and token streams\n",
"    text_embedding = Embedding(vocab_size, embedding_dim, name=\"text_embedding\")\n",
"\n",
"    # Separate embeddings for the NER and SRL tag sequences\n",
"    ner_embedding = Embedding(ner_vocab_size, ner_embedding_dim, name=\"ner_embedding\")(\n",
"        ner_input\n",
"    )\n",
"    srl_embedding = Embedding(srl_vocab_size, srl_embedding_dim, name=\"srl_embedding\")(\n",
"        srl_input\n",
"    )\n",
"\n",
"    # Apply the shared word embedding to both text inputs\n",
"    context_embed = text_embedding(context_input)\n",
"    token_embed = text_embedding(token_input)\n",
"\n",
"    # Context encoder: BiLSTM over the whole passage\n",
"    context_lstm = Bidirectional(\n",
"        LSTM(lstm_units, return_sequences=True, name=\"context_lstm\")\n",
"    )(context_embed)\n",
"\n",
"    # Token encoder: word + NER + SRL features concatenated per position\n",
"    token_features = Concatenate(name=\"token_features\")(\n",
"        [token_embed, ner_embedding, srl_embedding]\n",
"    )\n",
"    token_lstm = Bidirectional(\n",
"        LSTM(lstm_units, return_sequences=True, name=\"token_lstm\")\n",
"    )(token_features)\n",
"\n",
"    # Self-attention: the context encoding attends to itself (query = value)\n",
"    context_attention = tf.keras.layers.Attention(name=\"context_attention\")(\n",
"        [context_lstm, context_lstm]\n",
"    )\n",
"\n",
"    # Max-pool both sequences down to fixed-size vectors\n",
"    context_att_pool = tf.keras.layers.GlobalMaxPooling1D(name=\"context_att_pool\")(\n",
"        context_attention\n",
"    )\n",
"    token_pool = tf.keras.layers.GlobalMaxPooling1D(name=\"token_pool\")(token_lstm)\n",
"\n",
"    # Fuse pooled context, pooled token features and question type\n",
"    # (no answer feature is used)\n",
"    all_features = Concatenate(name=\"all_features\")(\n",
"        [context_att_pool, token_pool, q_type_input]\n",
"    )\n",
"\n",
"    # Fully connected fusion layers\n",
"    x = Dense(512, activation=\"relu\", name=\"dense_1\")(all_features)\n",
"    x = Dropout(dropout_rate)(x)\n",
"    x = Dense(256, activation=\"relu\", name=\"dense_2\")(x)\n",
"    x = Dropout(dropout_rate)(x)\n",
"\n",
"    # Output projection onto the word vocabulary, applied to every timestep\n",
"    # below through TimeDistributed\n",
"    decoder_dense = Dense(vocab_size, activation=\"softmax\", name=\"decoder_dense\")\n",
"\n",
"    # Decoder LSTM producing one hidden state per question position\n",
"    decoder_lstm = LSTM(lstm_units * 2, return_sequences=True, name=\"decoder_lstm\")\n",
"\n",
"    # Project the fused features to the decoder width\n",
"    decoder_input = Dense(lstm_units * 2, activation=\"relu\", name=\"decoder_input\")(x)\n",
"\n",
"    # NOTE: this is not teacher forcing. The same fused vector is repeated\n",
"    # max_question_len times and is the decoder's only input at every step;\n",
"    # no previous target token is ever fed back in.\n",
"    repeated_vector = tf.keras.layers.RepeatVector(max_question_len)(decoder_input)\n",
"\n",
"    decoder_outputs = decoder_lstm(repeated_vector)\n",
"\n",
"    # Vocabulary distribution at each question position\n",
"    question_output_seq = tf.keras.layers.TimeDistributed(decoder_dense)(\n",
"        decoder_outputs\n",
"    )\n",
"\n",
"    model = Model(\n",
"        inputs=[\n",
"            context_input,\n",
"            token_input,\n",
"            ner_input,\n",
"            srl_input,\n",
"            q_type_input,\n",
"        ],\n",
"        outputs=question_output_seq,\n",
"    )\n",
"\n",
"    # Targets are integer word ids (not one-hot), hence the *sparse* variant\n",
"    # of categorical cross-entropy.\n",
"    # NOTE(review): \"accuracy\" also counts padded positions, so it overstates\n",
"    # word-level accuracy for questions shorter than max_question_len\n",
"    # (assumes pad id 0 as produced by pad_sequences - confirm).\n",
"    model.compile(\n",
"        optimizer=\"adam\", loss=\"sparse_categorical_crossentropy\", metrics=[\"accuracy\"]\n",
"    )\n",
"\n",
"    return model\n",
"\n",
"\n",
"# Build the model\n",
"model = create_question_prediction_model()\n",
"model.summary()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "6ba404db",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Epoch 1/70\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# Saves the best epoch (by val_accuracy) to disk. Not passed to fit() below:\n",
"# restore_best_weights already leaves the in-memory model on its best epoch.\n",
"# Add it to `callbacks` to also keep an on-disk copy during training.\n",
"checkpoint = ModelCheckpoint(\n",
"    \"question_prediction_model.h5\",\n",
"    monitor=\"val_accuracy\",\n",
"    save_best_only=True,\n",
"    verbose=1,\n",
")\n",
"\n",
"# restore_best_weights=True: when training stops early, roll back to the\n",
"# weights of the best val_accuracy epoch. Without it the model keeps the\n",
"# weights of the last epoch (up to `patience` epochs past the best one),\n",
"# and those are the weights that get saved afterwards.\n",
"early_stop = EarlyStopping(\n",
"    monitor=\"val_accuracy\", patience=10, verbose=1, restore_best_weights=True\n",
")\n",
"\n",
"# sparse_categorical_crossentropy takes integer targets; add a trailing axis\n",
"# of size 1 so they line up with the (batch, max_question_len, vocab_size)\n",
"# model output.\n",
"# NOTE(review): padded positions (presumably id 0) count toward the loss and\n",
"# \"accuracy\"; a padding-mask sample_weight would exclude them - confirm the pad id.\n",
"train_question_target = np.expand_dims(train_question, -1)\n",
"test_question_target = np.expand_dims(test_question, -1)\n",
"\n",
"# Training parameters\n",
"batch_size = 8\n",
"epochs = 70\n",
"\n",
"history = model.fit(\n",
"    [train_context, train_token, train_ner, train_srl, train_q_type],\n",
"    train_question_target,\n",
"    batch_size=batch_size,\n",
"    epochs=epochs,\n",
"    validation_data=(\n",
"        [test_context, test_token, test_ner, test_srl, test_q_type],\n",
"        test_question_target,\n",
"    ),\n",
"    callbacks=[early_stop],\n",
")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "184209bc",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"image/png": "",
|
|
"text/plain": [
|
|
"<Figure size 1200x400 with 2 Axes>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"WARNING:absl:You are saving your model as an HDF5 file via `model.save()` or `keras.saving.save_model(model)`. This file format is considered legacy. We recommend using instead the native Keras format, e.g. `model.save('my_model.keras')` or `keras.saving.save_model(model, 'my_model.keras')`. \n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# Training curves: accuracy (left) and loss (right), train vs. validation\n",
"fig, (acc_ax, loss_ax) = plt.subplots(1, 2, figsize=(12, 4))\n",
"panels = (\n",
"    (acc_ax, \"accuracy\", \"val_accuracy\", \"Model Accuracy\", \"Accuracy\"),\n",
"    (loss_ax, \"loss\", \"val_loss\", \"Model Loss\", \"Loss\"),\n",
")\n",
"for ax, train_key, val_key, title, ylabel in panels:\n",
"    ax.plot(history.history[train_key])\n",
"    ax.plot(history.history[val_key])\n",
"    ax.set_title(title)\n",
"    ax.set_ylabel(ylabel)\n",
"    ax.set_xlabel(\"Epoch\")\n",
"    ax.legend([\"Train\", \"Validation\"], loc=\"upper left\")\n",
"fig.tight_layout()\n",
"fig.savefig(\"question_prediction_training_history.png\")\n",
"plt.show()\n",
"\n",
"# Save the trained model.\n",
"# NOTE(review): \".h5\" is the legacy Keras format (Keras warns when saving it);\n",
"# \".keras\" is the recommended format once downstream loaders are updated.\n",
"model.save(\"question_prediction_model_final.h5\")\n",
"\n",
"# Save the tokenizers and sequence lengths needed to rebuild inputs later\n",
"tokenizer_data = {\n",
"    \"word_tokenizer\": tokenizer.to_json(),\n",
"    \"ner_tokenizer\": ner_tokenizer.to_json(),\n",
"    \"srl_tokenizer\": srl_tokenizer.to_json(),\n",
"    \"q_type_tokenizer\": q_type_tokenizer.to_json(),\n",
"    \"max_context_len\": max_context_len,\n",
"    \"max_question_len\": max_question_len,\n",
"    \"max_token_len\": max_token_len,\n",
"}\n",
"\n",
"with open(\"question_prediction_tokenizers.json\", \"w\") as out_file:\n",
"    json.dump(tokenizer_data, out_file)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "71ec455a",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n",
|
|
"Model Performance Metrics:\n",
|
|
"Average BLEU Score: 1.05%\n",
|
|
"Average BLEU-1 Precision: 23.45%\n",
|
|
"Average BLEU-2 Precision: 7.38%\n",
|
|
"Average BLEU-3 Precision: 3.56%\n",
|
|
"Average BLEU-4 Precision: 0.97%\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"from collections import Counter\n",
"import pandas as pd\n",
"import numpy as np\n",
"\n",
"\n",
"def _ngram_counts(tokens, n):\n",
"    \"\"\"Return a Counter of the n-grams (as tuples) in ``tokens``.\"\"\"\n",
"    return Counter(tuple(tokens[j : j + n]) for j in range(len(tokens) - n + 1))\n",
"\n",
"\n",
"def _sentence_bleu_breakdown(reference_tokens, candidate_tokens, max_n=4):\n",
"    \"\"\"Unsmoothed sentence-level BLEU of one candidate against one reference.\n",
"\n",
"    Returns ``(clipped_counts, total_counts, precisions, bp, bleu)``. The first\n",
"    three are lists indexed by ``n - 1`` for n = 1..max_n: clipped n-gram\n",
"    matches, number of candidate n-grams, and modified precision\n",
"    (clipped / total, 0 when the candidate has no n-grams of that order).\n",
"    ``bp`` is the brevity penalty (0 for an empty candidate) and ``bleu`` is\n",
"    ``bp`` times the uniformly weighted geometric mean of the precisions, or\n",
"    0.0 when any precision is 0.\n",
"    \"\"\"\n",
"    clipped_counts, total_counts, precisions = [], [], []\n",
"    for n in range(1, max_n + 1):\n",
"        ref_ngrams = _ngram_counts(reference_tokens, n)\n",
"        cand_ngrams = _ngram_counts(candidate_tokens, n)\n",
"        # A candidate n-gram counts at most as often as it occurs in the reference\n",
"        clip_sum = sum(min(cnt, ref_ngrams.get(ng, 0)) for ng, cnt in cand_ngrams.items())\n",
"        total = sum(cand_ngrams.values())\n",
"        clipped_counts.append(clip_sum)\n",
"        total_counts.append(total)\n",
"        precisions.append(clip_sum / total if total > 0 else 0)\n",
"\n",
"    c = len(candidate_tokens)\n",
"    r = len(reference_tokens)\n",
"    if c == 0:\n",
"        bp = 0\n",
"    else:\n",
"        bp = 1 if c > r else np.exp(1 - r / c)\n",
"\n",
"    # No smoothing: log(0) is undefined, so any zero precision makes the score 0\n",
"    # (candidates shorter than max_n tokens therefore always score 0).\n",
"    if all(p > 0 for p in precisions):\n",
"        weights = [1 / max_n] * max_n\n",
"        bleu = bp * np.exp(sum(w * np.log(p) for w, p in zip(weights, precisions)))\n",
"    else:\n",
"        bleu = 0.0\n",
"    return clipped_counts, total_counts, precisions, bp, bleu\n",
"\n",
"\n",
"def evaluate_model_performance(test_data, output_path=\"bleu_question_calculation.xlsx\"):\n",
"    \"\"\"Score generated questions with sentence-level BLEU and export them to Excel.\n",
"\n",
"    For every index, the question returned by ``predict_question`` is compared\n",
"    with the reference question (both whitespace-tokenized) using unsmoothed\n",
"    BLEU-4 with uniform weights. Reported averages are per-sentence (macro)\n",
"    averages, not corpus-level BLEU.\n",
"\n",
"    Args:\n",
"        test_data: iterable of indices into the notebook-level lists\n",
"            ``contexts``, ``tokens_list``, ``ner_list``, ``srl_list``,\n",
"            ``q_types`` and ``questions``.\n",
"        output_path: workbook to write (openpyxl engine). Sheet\n",
"            ``BLEU_Details`` has one row per sample; ``Summary`` has the\n",
"            averages and the sample count.\n",
"\n",
"    Returns:\n",
"        dict with ``avg_bleu_score`` and ``avg_bleu1_precision`` ..\n",
"        ``avg_bleu4_precision`` (NaN when ``test_data`` is empty).\n",
"    \"\"\"\n",
"    max_n = 4\n",
"    ngram_fields = (\"Clipped_Count\", \"Total_Count\", \"Precision\")\n",
"    columns = [\n",
"        \"Sample\",\n",
"        \"Actual_Question\",\n",
"        \"Predicted_Question\",\n",
"        \"Actual_Tokens_Count\",\n",
"        \"Predicted_Tokens_Count\",\n",
"        *(f\"BLEU{n}_{field}\" for n in range(1, max_n + 1) for field in ngram_fields),\n",
"        \"Brevity_Penalty\",\n",
"        \"BLEU_Score\",\n",
"    ]\n",
"\n",
"    rows = []\n",
"    for sample_no, idx in enumerate(test_data, start=1):\n",
"        actual_question = questions[idx]\n",
"        pred_question = predict_question(\n",
"            contexts[idx], tokens_list[idx], ner_list[idx], srl_list[idx], q_types[idx]\n",
"        )\n",
"        actual_tokens = actual_question.split()\n",
"        pred_tokens = pred_question.split()\n",
"\n",
"        clipped_counts, total_counts, precisions, bp, bleu = _sentence_bleu_breakdown(\n",
"            actual_tokens, pred_tokens, max_n\n",
"        )\n",
"\n",
"        row = {\n",
"            \"Sample\": sample_no,\n",
"            \"Actual_Question\": actual_question,\n",
"            \"Predicted_Question\": pred_question,\n",
"            \"Actual_Tokens_Count\": len(actual_tokens),\n",
"            \"Predicted_Tokens_Count\": len(pred_tokens),\n",
"        }\n",
"        for n, stats in enumerate(zip(clipped_counts, total_counts, precisions), start=1):\n",
"            row.update({f\"BLEU{n}_{field}\": value for field, value in zip(ngram_fields, stats)})\n",
"        row[\"Brevity_Penalty\"] = bp\n",
"        row[\"BLEU_Score\"] = bleu\n",
"        rows.append(row)\n",
"\n",
"    # Explicit columns keep the sheet layout (and the lookups below) defined\n",
"    # even when test_data is empty.\n",
"    df = pd.DataFrame(rows, columns=columns)\n",
"\n",
"    def column_mean(name):\n",
"        # Mean of a numeric column; NaN when there are no samples\n",
"        return df[name].mean() if len(df) else float(\"nan\")\n",
"\n",
"    avg_bleu = column_mean(\"BLEU_Score\")\n",
"    avg_precisions = [column_mean(f\"BLEU{n}_Precision\") for n in range(1, max_n + 1)]\n",
"\n",
"    summary_df = pd.DataFrame(\n",
"        {\n",
"            \"Metric\": [\n",
"                \"Average BLEU Score\",\n",
"                *(f\"Average BLEU-{n} Precision\" for n in range(1, max_n + 1)),\n",
"                \"Average Brevity Penalty\",\n",
"                \"Total Samples\",\n",
"            ],\n",
"            \"Value\": [\n",
"                avg_bleu,\n",
"                *avg_precisions,\n",
"                column_mean(\"Brevity_Penalty\"),\n",
"                len(rows),\n",
"            ],\n",
"        }\n",
"    )\n",
"\n",
"    with pd.ExcelWriter(output_path, engine=\"openpyxl\") as writer:\n",
"        df.to_excel(writer, sheet_name=\"BLEU_Details\", index=False)\n",
"        summary_df.to_excel(writer, sheet_name=\"Summary\", index=False)\n",
"\n",
"    results = {\"avg_bleu_score\": avg_bleu}\n",
"    for n, precision in enumerate(avg_precisions, start=1):\n",
"        results[f\"avg_bleu{n}_precision\"] = precision\n",
"    return results\n",
"\n",
"\n",
"# Run the evaluation on the held-out indices\n",
"performance_metrics = evaluate_model_performance(test_indices)\n",
"\n",
"print(\"\\nModel Performance Metrics:\")\n",
"print(f\"Average BLEU Score: {performance_metrics['avg_bleu_score'] * 100:.2f}%\")\n",
"for n in range(1, 5):\n",
"    avg_precision = performance_metrics[f\"avg_bleu{n}_precision\"]\n",
"    print(f\"Average BLEU-{n} Precision: {avg_precision * 100:.2f}%\")"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "myenv",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.10.16"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|