# ===============================================================
# Seq2Seq-LSTM + Luong Attention for Question-Answer Generation
# + Greedy & Beam Search decoding + BLEU-4 evaluation
# ===============================================================
# • Encoder embeddings use mask_zero=True (padding is masked);
#   decoder embeddings use mask_zero=False so Attention does not
#   clash with the encoder mask
# • Encoder = Bidirectional LSTM (return_sequences=True)
# • Decoder = LSTM + Luong Attention (keras.layers.Attention)
# • Greedy & beam-search inference sub-models are built separately
#   (encoder, decoder-Q-step, decoder-A-step)
# • BLEU score (nltk corpus_bleu) to evaluate questions & answers
# ---------------------------------------------------------------
# HOW TO USE
# 1. pip install nltk
# 2. python seq2seq_qa_attention.py   # train + save the model
# 3. run evaluate_bleu()              # compute BLEU on validation/test
# ===============================================================

import json
from pathlib import Path
from itertools import chain
import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.layers import (
    Input, Embedding, LSTM, Bidirectional, Concatenate,
    Dense, TimeDistributed, Attention
)
from tensorflow.keras.models import Model
from nltk.translate.bleu_score import corpus_bleu  # pip install nltk

# ----------------------- 1. Load & flatten data ----------------------------
RAW = json.loads(Path("../dataset/dev_dataset_test.json").read_text())

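# Each raw record is assumed (based on the fields accessed below) to look
# roughly like:
#   {"tokens": [...], "ner": [...], "srl": [...],
#    "quiz_posibility": [{"type": "isian" | "opsi" | "true_false",
#                         "question": [...tokens...],
#                         "answer": [...tokens...] or "string"}]}
# One flattened training sample is produced per quiz_posibility entry.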
samples = []
for item in RAW:
    for qp in item["quiz_posibility"]:
        samp = {
            "tokens": [t.lower() for t in item["tokens"]],
            "ner": item["ner"],
            "srl": item["srl"],
            "q_type": qp["type"],
            "q_toks": [t.lower() for t in qp["question"]] + ["<eos>"],
        }
        if isinstance(qp["answer"], list):
            samp["a_toks"] = [t.lower() for t in qp["answer"]] + ["<eos>"]
        else:
            samp["a_toks"] = [qp["answer"].lower(), "<eos>"]
        samples.append(samp)

print("Total flattened samples:", len(samples))

# ----------------------- 2. Build vocabularies -----------------------------

def build_vocab(seq_iter, reserved=("<pad>", "<unk>", "<sos>", "<eos>")):
    vocab = {tok: idx for idx, tok in enumerate(reserved)}
    for tok in chain.from_iterable(seq_iter):
        if tok not in vocab:
            vocab[tok] = len(vocab)
    return vocab

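# NOTE: ids are assigned in first-occurrence order and the reserved tokens
# come first, so "<pad>" is always id 0 — the index masked by mask_zero=True
# in the encoder embeddings below.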
v_tok = build_vocab((s["tokens"] for s in samples))
v_ner = build_vocab((s["ner"] for s in samples), reserved=("<pad>", "<unk>"))
v_srl = build_vocab((s["srl"] for s in samples), reserved=("<pad>", "<unk>"))
v_q = build_vocab((s["q_toks"] for s in samples))
v_a = build_vocab((s["a_toks"] for s in samples))
v_typ = {"isian": 0, "opsi": 1, "true_false": 2}

iv_q = {i: t for t, i in v_q.items()}
iv_a = {i: t for t, i in v_a.items()}

# ----------------------- 3. Vectorise + pad -------------------------------

def encode(seq, vmap):
    return [vmap.get(tok, vmap["<unk>"]) for tok in seq]

MAX_SENT = max(len(s["tokens"]) for s in samples)
MAX_Q = max(len(s["q_toks"]) for s in samples)
MAX_A = max(len(s["a_toks"]) for s in samples)

X_tok_ids = pad_sequences([encode(s["tokens"], v_tok) for s in samples],
                          maxlen=MAX_SENT, padding="post")
X_ner_ids = pad_sequences([encode(s["ner"], v_ner) for s in samples],
                          maxlen=MAX_SENT, padding="post")
X_srl_ids = pad_sequences([encode(s["srl"], v_srl) for s in samples],
                          maxlen=MAX_SENT, padding="post")

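# Teacher forcing: the decoder input is the target shifted right by one step
# (prepend <sos>, drop the trailing <eos>), the decoder output is the
# unshifted target ending in <eos>.  E.g. a hypothetical target
# ["siapa", "presiden", "<eos>"] becomes input ["<sos>", "siapa", "presiden"]
# and output ["siapa", "presiden", "<eos>"].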
q_in_ids = pad_sequences([[v_q["<sos>"], *encode(s["q_toks"][:-1], v_q)]
                          for s in samples], maxlen=MAX_Q, padding="post")
q_out_ids = pad_sequences([encode(s["q_toks"], v_q) for s in samples],
                          maxlen=MAX_Q, padding="post")

a_in_ids = pad_sequences([[v_a["<sos>"], *encode(s["a_toks"][:-1], v_a)]
                          for s in samples], maxlen=MAX_A, padding="post")
a_out_ids = pad_sequences([encode(s["a_toks"], v_a) for s in samples],
                          maxlen=MAX_A, padding="post")

y_type_ids = np.array([v_typ[s["q_type"]] for s in samples])

# ----------------------- 4. Hyper-params ----------------------------------
d_tok = 32    # token embedding dim
d_tag = 16    # NER / SRL embedding dim
units = 64    # per direction of BiLSTM
lat_dim = units * 2

# ----------------------- 5. Build model -----------------------------------
# Encoder ----------------------------------------------------------

tok_in = Input((MAX_SENT,), dtype="int32", name="tok_in")
ner_in = Input((MAX_SENT,), dtype="int32", name="ner_in")
srl_in = Input((MAX_SENT,), dtype="int32", name="srl_in")

emb_tok = Embedding(len(v_tok), d_tok, mask_zero=True, name="emb_tok")(tok_in)
emb_ner = Embedding(len(v_ner), d_tag, mask_zero=True, name="emb_ner")(ner_in)
emb_srl = Embedding(len(v_srl), d_tag, mask_zero=True, name="emb_srl")(srl_in)

enc_concat = Concatenate(name="enc_concat")([emb_tok, emb_ner, emb_srl])
bi_lstm = Bidirectional(LSTM(units, return_sequences=True, return_state=True),
                        name="encoder_bi_lstm")
enc_seq, f_h, f_c, b_h, b_c = bi_lstm(enc_concat)
enc_h = Concatenate()([f_h, b_h])  # (B, lat_dim)
enc_c = Concatenate()([f_c, b_c])

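# The forward and backward LSTM states are concatenated, so the decoder
# initial state has width lat_dim = 2 * units and matches the decoder LSTM size.
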
# Decoder – QUESTION ----------------------------------------------
q_in = Input((MAX_Q,), dtype="int32", name="q_in")
# mask_zero=False so the Attention layer does not clash with the encoder mask
q_emb = Embedding(len(v_q), d_tok, mask_zero=False, name="q_emb")(q_in)

dec_q_lstm = LSTM(lat_dim, return_sequences=True, return_state=True,
                  name="decoder_q_lstm")
q_seq, q_h, q_c = dec_q_lstm(q_emb, initial_state=[enc_h, enc_c])

enc_proj_q = TimeDistributed(Dense(lat_dim), name="enc_proj_q")(enc_seq)
attn_q = Attention(name="attn_q")([q_seq, enc_proj_q])
q_concat = Concatenate(name="q_concat")([q_seq, attn_q])
q_out = TimeDistributed(Dense(len(v_q), activation="softmax"), name="q_out")(q_concat)

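# keras.layers.Attention computes dot-product (Luong-style) scores: the decoder
# states are the query and the projected encoder sequence serves as both key
# and value; the resulting context vectors are concatenated with the decoder
# states before the softmax projection.  The same wiring is repeated for the
# answer decoder below.
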
# Decoder – ANSWER -------------------------------------------------
a_in = Input((MAX_A,), dtype="int32", name="a_in")
# mask_zero=False here as well, for the same reason
a_emb = Embedding(len(v_a), d_tok, mask_zero=False, name="a_emb")(a_in)

dec_a_lstm = LSTM(lat_dim, return_sequences=True, return_state=True,
                  name="decoder_a_lstm")
a_seq, _, _ = dec_a_lstm(a_emb, initial_state=[q_h, q_c])

enc_proj_a = TimeDistributed(Dense(lat_dim), name="enc_proj_a")(enc_seq)
attn_a = Attention(name="attn_a")([a_seq, enc_proj_a])
a_concat = Concatenate(name="a_concat")([a_seq, attn_a])
a_out = TimeDistributed(Dense(len(v_a), activation="softmax"), name="a_out")(a_concat)

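# Note: the answer decoder is initialised from the final question-decoder
# state (q_h, q_c), so answer generation is conditioned on the question path
# as well as the encoder.
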
# Classifier -------------------------------------------------------
type_dense = Dense(len(v_typ), activation="softmax", name="type_out")(enc_h)

model = Model(inputs=[tok_in, ner_in, srl_in, q_in, a_in],
              outputs=[q_out, a_out, type_dense])
model.summary()

# ----------------------- 6. Compile & train ------------------------------
losses = {
    "q_out": "sparse_categorical_crossentropy",
    "a_out": "sparse_categorical_crossentropy",
    "type_out": "sparse_categorical_crossentropy",
}
loss_weights = {"q_out": 1.0, "a_out": 1.0, "type_out": 0.3}

model.compile(optimizer="adam", loss=losses, loss_weights=loss_weights,
              metrics={"q_out": "sparse_categorical_accuracy",
                       "a_out": "sparse_categorical_accuracy",
                       "type_out": "accuracy"})

history = model.fit(
    [X_tok_ids, X_ner_ids, X_srl_ids, q_in_ids, a_in_ids],
    [q_out_ids, a_out_ids, y_type_ids],
    validation_split=0.1,
    epochs=30,
    batch_size=64,
    callbacks=[tf.keras.callbacks.EarlyStopping(patience=4, restore_best_weights=True)],
    verbose=1,
)

model.save("seq2seq_attn.keras")
print("Model saved to seq2seq_attn.keras")

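# Reloading later (a minimal sketch; the inference sub-models below still have
# to be rebuilt from the reloaded model's layers):
#   model = tf.keras.models.load_model("seq2seq_attn.keras")
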
# ----------------------- 7. Inference sub-models --------------------------
# Encoder model
encoder_model = Model([tok_in, ner_in, srl_in], [enc_seq, enc_h, enc_c])

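# Each decoder step model below consumes a single token id plus the previous
# LSTM state and returns the next-token distribution together with the updated
# state, so inference can run one token at a time while reusing the trained
# layer weights.
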
# Question decoder step model ------------------------------------------------
# Inputs
q_token_in = Input((1,), dtype="int32", name="q_token_in")
enc_seq_in = Input((MAX_SENT, lat_dim), name="enc_seq_in")
enc_proj_q_in = Input((MAX_SENT, lat_dim), name="enc_proj_q_in")
state_h_in = Input((lat_dim,), name="state_h_in")
state_c_in = Input((lat_dim,), name="state_c_in")

# Embedding
q_emb_step = model.get_layer("q_emb")(q_token_in)

# LSTM (reuse weights)
q_lstm_step, h_out, c_out = model.get_layer("decoder_q_lstm")(
    q_emb_step, initial_state=[state_h_in, state_c_in])
# Attention
attn_step = model.get_layer("attn_q")([q_lstm_step, enc_proj_q_in])
q_concat_step = Concatenate()([q_lstm_step, attn_step])
q_logits_step = model.get_layer("q_out")(q_concat_step)

decoder_q_step = Model([q_token_in, enc_seq_in, enc_proj_q_in, state_h_in, state_c_in],
                       [q_logits_step, h_out, c_out])

# Answer decoder step model --------------------------------------------------
a_token_in = Input((1,), dtype="int32", name="a_token_in")
enc_proj_a_in = Input((MAX_SENT, lat_dim), name="enc_proj_a_in")
state_h_a_in = Input((lat_dim,), name="state_h_a_in")
state_c_a_in = Input((lat_dim,), name="state_c_a_in")

# Embedding reuse
a_emb_step = model.get_layer("a_emb")(a_token_in)

# LSTM reuse
a_lstm_step, h_a_out, c_a_out = model.get_layer("decoder_a_lstm")(
    a_emb_step, initial_state=[state_h_a_in, state_c_a_in])
# Attention reuse
attn_a_step = model.get_layer("attn_a")([a_lstm_step, enc_proj_a_in])
a_concat_step = Concatenate()([a_lstm_step, attn_a_step])
a_logits_step = model.get_layer("a_out")(a_concat_step)

decoder_a_step = Model([a_token_in, enc_proj_a_in, state_h_a_in, state_c_a_in],
                       [a_logits_step, h_a_out, c_a_out])

# ----------------------- 8. Decoding helpers ------------------------------

def encode_and_pad(seq, vmap, max_len):
    ids = encode(seq, vmap)
    return ids + [vmap["<pad>"]] * (max_len - len(ids))


def greedy_decode(tokens, ner, srl, max_q=20, max_a=10):
    """Return generated (question_tokens, answer_tokens, q_type_str)."""
    # --- encoder ---------------------------------------------------------
    enc_tok = np.array([encode_and_pad(tokens, v_tok, MAX_SENT)])
    enc_ner = np.array([encode_and_pad(ner, v_ner, MAX_SENT)])
    enc_srl = np.array([encode_and_pad(srl, v_srl, MAX_SENT)])

    enc_seq_val, h, c = encoder_model.predict([enc_tok, enc_ner, enc_srl], verbose=0)
    enc_proj_q_val = model.get_layer("enc_proj_q")(enc_seq_val)
    enc_proj_a_val = model.get_layer("enc_proj_a")(enc_seq_val)

    # --- greedy Question --------------------------------------------------
    q_ids = []
    tgt = np.array([[v_q["<sos>"]]])
    for _ in range(max_q):
        logits, h, c = decoder_q_step.predict([tgt, enc_seq_val, enc_proj_q_val, h, c], verbose=0)
        next_id = int(logits[0, 0].argmax())
        if next_id == v_q["<eos>"]:
            break
        q_ids.append(next_id)
        tgt = np.array([[next_id]])

    # --- greedy Answer: continue from the final question-decoder state ----
    a_ids = []
    tgt_a = np.array([[v_a["<sos>"]]])
    for _ in range(max_a):
        logits_a, h, c = decoder_a_step.predict([tgt_a, enc_proj_a_val, h, c], verbose=0)
        next_a = int(logits_a[0, 0].argmax())
        if next_a == v_a["<eos>"]:
            break
        a_ids.append(next_a)
        tgt_a = np.array([[next_a]])

    # Question type
    typ_logits = model.predict([enc_tok, enc_ner, enc_srl, np.zeros((1, MAX_Q)), np.zeros((1, MAX_A))], verbose=0)[2]
    typ_id = int(typ_logits.argmax())
    q_type = [k for k, v in v_typ.items() if v == typ_id][0]

    question = [iv_q.get(i, "<unk>") for i in q_ids]
    answer = [iv_a.get(i, "<unk>") for i in a_ids]
    return question, answer, q_type


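# Example (sketch, assuming at least one flattened sample exists):
#   s0 = samples[0]
#   q, a, t = greedy_decode(s0["tokens"], s0["ner"], s0["srl"])
#   print(" ".join(q), "|", " ".join(a), "|", t)

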
def beam_decode(tokens, ner, srl, beam_width=5, max_q=20, max_a=10):
    """Beam-search decoding. Return the best (question_tokens, answer_tokens, q_type)."""
    enc_tok = np.array([encode_and_pad(tokens, v_tok, MAX_SENT)])
    enc_ner = np.array([encode_and_pad(ner, v_ner, MAX_SENT)])
    enc_srl = np.array([encode_and_pad(srl, v_srl, MAX_SENT)])
    enc_seq_val, h0, c0 = encoder_model.predict([enc_tok, enc_ner, enc_srl], verbose=0)
    enc_proj_q_val = model.get_layer("enc_proj_q")(enc_seq_val)
    enc_proj_a_val = model.get_layer("enc_proj_a")(enc_seq_val)

    # ----- Beam for Question ----------------------------------------------
    Beam = [([v_q["<sos>"]], 0.0, h0, c0)]  # (sequence, logP, h, c)
    completed_q = []
    for _ in range(max_q):
        new_beam = []
        for seq, logp, h, c in Beam:
            tgt = np.array([[seq[-1]]])
            logits, next_h, next_c = decoder_q_step.predict([tgt, enc_seq_val, enc_proj_q_val, h, c], verbose=0)
            log_probs = np.log(logits[0, 0] + 1e-8)
            top_ids = np.argsort(log_probs)[-beam_width:]
            for nid in top_ids:
                new_seq = seq + [int(nid)]
                new_logp = logp + log_probs[nid]
                new_beam.append((new_seq, new_logp, next_h, next_c))
        # keep the best beam_width hypotheses
        Beam = sorted(new_beam, key=lambda x: x[1], reverse=True)[:beam_width]
        # split into completed and still-growing hypotheses
        Beam, done = [], Beam
        for seq, logp, h, c in done:
            if seq[-1] == v_q["<eos>"] or len(seq) >= max_q:
                completed_q.append((seq, logp, h, c))
            else:
                Beam.append((seq, logp, h, c))
        if not Beam:
            break
    if completed_q:
        best_q = max(completed_q, key=lambda x: x[1])
    else:
        best_q = max(Beam, key=lambda x: x[1])

    q_seq_ids, _, h_q, c_q = best_q
    q_ids = [i for i in q_seq_ids[1:] if i != v_q["<eos>"]]

    # ----- Beam for Answer --------------------------------------------------
    Beam = [([v_a["<sos>"]], 0.0, h_q, c_q)]
    completed_a = []
    for _ in range(max_a):
        new_beam = []
        for seq, logp, h, c in Beam:
            tgt = np.array([[seq[-1]]])
            logits, next_h, next_c = decoder_a_step.predict([tgt, enc_proj_a_val, h, c], verbose=0)
            log_probs = np.log(logits[0, 0] + 1e-8)
            top_ids = np.argsort(log_probs)[-beam_width:]
            for nid in top_ids:
                new_seq = seq + [int(nid)]
                new_logp = logp + log_probs[nid]
                new_beam.append((new_seq, new_logp, next_h, next_c))
        Beam = sorted(new_beam, key=lambda x: x[1], reverse=True)[:beam_width]
        Beam, done = [], Beam
        for seq, logp, h, c in done:
            if seq[-1] == v_a["<eos>"] or len(seq) >= max_a:
                completed_a.append((seq, logp))
            else:
                Beam.append((seq, logp, h, c))
        if not Beam:
            break
    if completed_a:
        best_a_seq, _ = max(completed_a, key=lambda x: x[1])
    else:
        best_a_seq = max(Beam, key=lambda x: x[1])[0]
    a_ids = [i for i in best_a_seq[1:] if i != v_a["<eos>"]]

    # Question type classification
    typ_logits = model.predict([enc_tok, enc_ner, enc_srl, np.zeros((1, MAX_Q)), np.zeros((1, MAX_A))], verbose=0)[2]
    typ_id = int(typ_logits.argmax())
    q_type = [k for k, v in v_typ.items() if v == typ_id][0]

    question = [iv_q.get(i, "<unk>") for i in q_ids]
    answer = [iv_a.get(i, "<unk>") for i in a_ids]

    return question, answer, q_type

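# Note: beam_decode ranks hypotheses by summed log-probability without length
# normalisation, which tends to favour shorter sequences; dividing the score
# by sequence length is a common variant if that becomes an issue.
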
# ----------------------- 9. BLEU evaluation -------------------------------

def evaluate_bleu(split_ratio=0.1, beam=False):
    """Compute corpus BLEU-4 on a random subset of the flattened samples."""
    n_total = len(samples)
    n_val = int(n_total * split_ratio)
    idxs = np.random.choice(n_total, n_val, replace=False)

    refs_q, hyps_q = [], []
    refs_a, hyps_a = [], []

    for i in idxs:
        s = samples[i]
        question_pred, answer_pred, _ = (beam_decode if beam else greedy_decode)(
            s["tokens"], s["ner"], s["srl"],
        )
        refs_q.append([s["q_toks"][:-1]])  # exclude <eos>
        hyps_q.append(question_pred)
        refs_a.append([s["a_toks"][:-1]])
        hyps_a.append(answer_pred)

    bleu_q = corpus_bleu(refs_q, hyps_q)
    bleu_a = corpus_bleu(refs_a, hyps_a)
    print(f"BLEU-4 Question: {bleu_q:.3f}\nBLEU-4 Answer  : {bleu_a:.3f}")

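# Note: corpus_bleu defaults to BLEU-4; with very short hypotheses (common for
# answers) it can return 0 when there is no 4-gram overlap.  nltk's
# SmoothingFunction can be passed if that becomes a problem, e.g. (sketch):
#   from nltk.translate.bleu_score import SmoothingFunction
#   corpus_bleu(refs_a, hyps_a, smoothing_function=SmoothingFunction().method1)
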
# Example usage:
evaluate_bleu(beam=False)
evaluate_bleu(beam=True)