#!/usr/bin/env python3
# ===============================================================
# Question-Generation seq-to-seq (tokens + NER + SRL → Q/A/type)
# - revised version 2025-05-11
# ===============================================================

import json, pickle, random
from pathlib import Path
from itertools import chain

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import (
    Input, Embedding, LSTM, Concatenate,
    Dense, TimeDistributed
)
from tensorflow.keras.models import Model
from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction
from rouge_score import rouge_scorer, scoring

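# Expected input schema (inferred from the flattening below): each JSON item
# holds parallel "tokens" / "ner" / "srl" lists plus a "quiz_posibility"
# list of {"type", "question", "answer"} dicts.
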
# -----------------------------------------------------------------
# 0. LOAD & FLATTEN DATA
# -----------------------------------------------------------------
RAW = json.loads(Path("../dataset/dev_dataset_qg.json").read_text())

samples = []
for item in RAW:
    for qp in item["quiz_posibility"]:
        samples.append({
            "tokens": item["tokens"],
            "ner":    item["ner"],
            "srl":    item["srl"],
            "q_type": qp["type"],                 # isian / opsi / benar_salah
            "q_toks": qp["question"] + ["<eos>"],
            "a_toks": (qp["answer"] if isinstance(qp["answer"], list)
                       else [qp["answer"]]) + ["<eos>"]
        })

print("flattened samples :", len(samples))

# -----------------------------------------------------------------
# 1. VOCABULARIES
# -----------------------------------------------------------------
def build_vocab(seq_iter, reserved=("<pad>", "<unk>", "<sos>", "<eos>")):
    vocab = {tok: idx for idx, tok in enumerate(reserved)}
    for tok in chain.from_iterable(seq_iter):
        vocab.setdefault(tok, len(vocab))
    return vocab

vocab_tok = build_vocab(s["tokens"] for s in samples)
vocab_ner = build_vocab((s["ner"] for s in samples), reserved=("<pad>", "<unk>"))
vocab_srl = build_vocab((s["srl"] for s in samples), reserved=("<pad>", "<unk>"))
vocab_q   = build_vocab(s["q_toks"] for s in samples)
vocab_a   = build_vocab(s["a_toks"] for s in samples)
vocab_typ = {"isian": 0, "opsi": 1, "benar_salah": 2}

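# Sanity check (optional): report vocabulary sizes before building the model.
print(f"vocab sizes: tok={len(vocab_tok)} ner={len(vocab_ner)} "
      f"srl={len(vocab_srl)} q={len(vocab_q)} a={len(vocab_a)}")
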
# -----------------------------------------------------------------
# 2. ENCODING & PADDING
# -----------------------------------------------------------------
def enc(seq, v): return [v.get(t, v["<unk>"]) for t in seq]

MAX_SENT = max(len(s["tokens"]) for s in samples)
MAX_Q    = max(len(s["q_toks"]) for s in samples)
MAX_A    = max(len(s["a_toks"]) for s in samples)

def pad_batch(seqs, vmap, maxlen):
    return tf.keras.preprocessing.sequence.pad_sequences(
        [enc(s, vmap) for s in seqs], maxlen=maxlen, padding="post"
    )

X_tok = pad_batch((s["tokens"] for s in samples), vocab_tok, MAX_SENT)
X_ner = pad_batch((s["ner"] for s in samples), vocab_ner, MAX_SENT)
X_srl = pad_batch((s["srl"] for s in samples), vocab_srl, MAX_SENT)

# Teacher forcing: decoder input is <sos> + target[:-1]; the target keeps <eos>.
dec_q_in  = pad_batch([["<sos>"] + s["q_toks"][:-1] for s in samples], vocab_q, MAX_Q)
dec_q_out = pad_batch((s["q_toks"] for s in samples), vocab_q, MAX_Q)

dec_a_in  = pad_batch([["<sos>"] + s["a_toks"][:-1] for s in samples], vocab_a, MAX_A)
dec_a_out = pad_batch((s["a_toks"] for s in samples), vocab_a, MAX_A)

y_type = np.array([vocab_typ[s["q_type"]] for s in samples])

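# Optional shape check: encoder inputs are (num_samples, MAX_SENT); decoder
# tensors are (num_samples, MAX_Q) and (num_samples, MAX_A).
print("X_tok:", X_tok.shape, "dec_q_in:", dec_q_in.shape, "dec_a_in:", dec_a_in.shape)
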
# -----------------------------------------------------------------
# 3. MODEL
# -----------------------------------------------------------------
d_tok, d_tag, units = 128, 32, 256
pad_tok, pad_q, pad_a = vocab_tok["<pad>"], vocab_q["<pad>"], vocab_a["<pad>"]

# ---- Encoder ----------------------------------------------------
inp_tok = Input((MAX_SENT,), name="tok_in")
inp_ner = Input((MAX_SENT,), name="ner_in")
inp_srl = Input((MAX_SENT,), name="srl_in")

emb_tok = Embedding(len(vocab_tok), d_tok, mask_zero=True, name="emb_tok")(inp_tok)
emb_ner = Embedding(len(vocab_ner), d_tag, mask_zero=True, name="emb_ner")(inp_ner)
emb_srl = Embedding(len(vocab_srl), d_tag, mask_zero=True, name="emb_srl")(inp_srl)

enc_concat = Concatenate()([emb_tok, emb_ner, emb_srl])
enc_out, state_h, state_c = LSTM(units, return_state=True, name="enc_lstm")(enc_concat)

# ---- Decoder : Question ----------------------------------------
dec_q_inp = Input((MAX_Q,), name="dec_q_in")
dec_emb_q = Embedding(len(vocab_q), d_tok, mask_zero=True, name="emb_q")(dec_q_inp)
dec_q_seq, _, _ = LSTM(units, return_sequences=True, return_state=True,
                       name="lstm_q")(dec_emb_q, initial_state=[state_h, state_c])
q_out = TimeDistributed(Dense(len(vocab_q), activation="softmax"), name="q_out")(dec_q_seq)

# ---- Decoder : Answer ------------------------------------------
dec_a_inp = Input((MAX_A,), name="dec_a_in")
dec_emb_a = Embedding(len(vocab_a), d_tok, mask_zero=True, name="emb_a")(dec_a_inp)
dec_a_seq, _, _ = LSTM(units, return_sequences=True, return_state=True,
                       name="lstm_a")(dec_emb_a, initial_state=[state_h, state_c])
a_out = TimeDistributed(Dense(len(vocab_a), activation="softmax"), name="a_out")(dec_a_seq)

# ---- Classifier -------------------------------------------------
type_out = Dense(len(vocab_typ), activation="softmax", name="type_out")(enc_out)

model = Model(
    [inp_tok, inp_ner, inp_srl, dec_q_inp, dec_a_inp],
    [q_out, a_out, type_out]
)

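# Architecture recap: a single shared encoder produces final states (h, c)
# that initialise both decoders under teacher forcing, while the question-type
# classifier reads the encoder's last hidden state directly.
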
# ---- Masked loss helpers ---------------------------------------
scce = tf.keras.losses.SparseCategoricalCrossentropy(reduction="none")
def masked_loss_factory(pad_id):
    def loss(y_true, y_pred):
        l = scce(y_true, y_pred)
        mask = tf.cast(tf.not_equal(y_true, pad_id), tf.float32)
        # guard against an all-pad slice (avoids division by zero)
        return tf.reduce_sum(l * mask) / tf.maximum(tf.reduce_sum(mask), 1.0)
    return loss

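# Minimal illustration (toy values, assumed pad_id=0): for a target row
# [5, 3, 0] the mask zeroes the final pad position, so the loss averages
# over 2 tokens instead of 3, e.g.:
#   masked_loss_factory(0)(tf.constant([[5, 3, 0]]),
#                          tf.nn.softmax(tf.random.uniform((1, 3, 10))))
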
model.compile(
    optimizer="adam",
    loss={"q_out": masked_loss_factory(pad_q),
          "a_out": masked_loss_factory(pad_a),
          "type_out": "sparse_categorical_crossentropy"},
    loss_weights={"q_out": 1.0, "a_out": 1.0, "type_out": 0.3},
    metrics={"q_out": "sparse_categorical_accuracy",
             "a_out": "sparse_categorical_accuracy",
             "type_out": tf.keras.metrics.SparseCategoricalAccuracy(name="type_acc")}
)
model.summary()

# -----------------------------------------------------------------
# 4. TRAIN
# -----------------------------------------------------------------
history = model.fit(
    [X_tok, X_ner, X_srl, dec_q_in, dec_a_in],
    [dec_q_out, dec_a_out, y_type],
    validation_split=0.1,
    epochs=30,
    batch_size=64,
    callbacks=[tf.keras.callbacks.EarlyStopping(patience=4, restore_best_weights=True)],
    verbose=2
)
model.save("full_seq2seq.keras")

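# Note: the model was compiled with closure-based custom losses, so reloading
# it later typically requires skipping compilation (a hedged sketch):
#   reloaded = tf.keras.models.load_model("full_seq2seq.keras", compile=False)
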
# -----------------------------------------------------------------
# 5. SAVE VOCABS (.pkl keeps python dict intact)
# -----------------------------------------------------------------
def save_vocab(v, name):
    with open(name, "wb") as f:        # context manager closes the file
        pickle.dump(v, f)

save_vocab(vocab_tok, "vocab_tok.pkl"); save_vocab(vocab_ner, "vocab_ner.pkl")
save_vocab(vocab_srl, "vocab_srl.pkl"); save_vocab(vocab_q,   "vocab_q.pkl")
save_vocab(vocab_a,   "vocab_a.pkl");   save_vocab(vocab_typ, "vocab_typ.pkl")

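# To restore in a separate inference script (illustrative):
#   with open("vocab_tok.pkl", "rb") as f:
#       vocab_tok = pickle.load(f)
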
# -----------------------------------------------------------------
# 6. INFERENCE MODELS (encoder & decoders)
# -----------------------------------------------------------------
def build_inference_models(trained):
    # encoder
    t_in = Input((MAX_SENT,), name="t_in")
    n_in = Input((MAX_SENT,), name="n_in")
    s_in = Input((MAX_SENT,), name="s_in")
    e_t = trained.get_layer("emb_tok")(t_in)
    e_n = trained.get_layer("emb_ner")(n_in)
    e_s = trained.get_layer("emb_srl")(s_in)
    concat = Concatenate()([e_t, e_n, e_s])
    enc_vec, h, c = trained.get_layer("enc_lstm")(concat)
    enc_model = Model([t_in, n_in, s_in], [h, c])

    # question decoder (single step: one token in, next-token distribution out)
    dq_in = Input((1,), name="dq_tok")
    dh = Input((units,), name="dh"); dc = Input((units,), name="dc")
    dq_emb = trained.get_layer("emb_q")(dq_in)
    dq_lstm, nh, nc = trained.get_layer("lstm_q")(dq_emb, initial_state=[dh, dc])
    dq_out = trained.get_layer("q_out").layer(dq_lstm)   # inner Dense of the TimeDistributed wrapper
    dec_q_model = Model([dq_in, dh, dc], [dq_out, nh, nc])

    # answer decoder
    da_in = Input((1,), name="da_tok")
    ah = Input((units,), name="ah"); ac = Input((units,), name="ac")
    da_emb = trained.get_layer("emb_a")(da_in)
    da_lstm, nh2, nc2 = trained.get_layer("lstm_a")(da_emb, initial_state=[ah, ac])
    da_out = trained.get_layer("a_out").layer(da_lstm)
    dec_a_model = Model([da_in, ah, ac], [da_out, nh2, nc2])

    # type classifier: reuse the encoder's final hidden state
    type_dense = trained.get_layer("type_out")
    type_model = Model([t_in, n_in, s_in], type_dense(enc_vec))

    return enc_model, dec_q_model, dec_a_model, type_model

encoder_model, decoder_q, decoder_a, classifier_model = build_inference_models(model)

inv_q = {v: k for k, v in vocab_q.items()}
inv_a = {v: k for k, v in vocab_a.items()}

def enc_pad(seq, vmap, maxlen):
    x = [vmap.get(t, vmap["<unk>"]) for t in seq][:maxlen]   # truncate over-long inputs
    return x + [vmap["<pad>"]] * (maxlen - len(x))

def greedy_decode(tokens, ner, srl, max_q=20, max_a=10):
    et = np.array([enc_pad(tokens, vocab_tok, MAX_SENT)])
    en = np.array([enc_pad(ner,    vocab_ner, MAX_SENT)])
    es = np.array([enc_pad(srl,    vocab_srl, MAX_SENT)])

    h, c = encoder_model.predict([et, en, es], verbose=0)

    # --- question
    q_ids = []
    tgt = np.array([[vocab_q["<sos>"]]])
    for _ in range(max_q):
        logits, h, c = decoder_q.predict([tgt, h, c], verbose=0)
        nxt = int(logits[0, -1].argmax())
        if nxt == vocab_q["<eos>"]:
            break
        q_ids.append(nxt)
        tgt = np.array([[nxt]])

    # --- answer (restart from fresh encoder states)
    h, c = encoder_model.predict([et, en, es], verbose=0)
    a_ids = []
    tgt = np.array([[vocab_a["<sos>"]]])
    for _ in range(max_a):
        logits, h, c = decoder_a.predict([tgt, h, c], verbose=0)
        nxt = int(logits[0, -1].argmax())
        if nxt == vocab_a["<eos>"]:
            break
        a_ids.append(nxt)
        tgt = np.array([[nxt]])

    # --- type
    t_id = int(classifier_model.predict([et, en, es], verbose=0).argmax())

    return ([inv_q[i] for i in q_ids], [inv_a[i] for i in a_ids],
            [k for k, v in vocab_typ.items() if v == t_id][0])

# -----------------------------------------------------------------
# 7. QUICK DEMO
# -----------------------------------------------------------------
test_tokens = ["soekarno", "membacakan", "teks", "proklamasi", "pada",
               "17", "agustus", "1945"]
test_ner    = ["B-PER", "O", "O", "O", "O", "B-DATE", "I-DATE", "I-DATE"]
test_srl    = ["ARG0", "V", "ARG1", "ARG1", "O", "ARGM-TMP", "ARGM-TMP", "ARGM-TMP"]

q, a, t = greedy_decode(test_tokens, test_ner, test_srl, max_q=MAX_Q, max_a=MAX_A)
print("\nDEMO\n----")
print("Q :", " ".join(q))
print("A :", " ".join(a))
print("T :", t)

# -----------------------------------------------------------------
# 8. EVALUATION (corpus-level BLEU + ROUGE-1/-L)
# -----------------------------------------------------------------
smooth   = SmoothingFunction().method4
r_scorer = rouge_scorer.RougeScorer(["rouge1", "rougeL"], use_stemmer=True)

def strip_special(seq, pad_id, eos_id):
    return [x for x in seq if x not in (pad_id, eos_id)]

def ids_to_text(ids, inv):
    return " ".join(inv[i] for i in ids)

def evaluate(n=200):
    n = min(n, len(samples))                 # never sample more than we have
    idxs = random.sample(range(len(samples)), n)
    refs, hyps = [], []
    agg = scoring.BootstrapAggregator()

    for i in idxs:
        gt_ids = strip_special(dec_q_out[i], pad_q, vocab_q["<eos>"])
        ref  = ids_to_text(gt_ids, inv_q)
        pred = " ".join(greedy_decode(
            samples[i]["tokens"],
            samples[i]["ner"],
            samples[i]["srl"]
        )[0])
        refs.append([ref.split()])
        hyps.append(pred.split())
        agg.add_scores(r_scorer.score(ref, pred))

    bleu   = corpus_bleu(refs, hyps, smoothing_function=smooth)
    result = agg.aggregate()                 # aggregate once, read both metrics
    r1, rL = result["rouge1"].mid, result["rougeL"].mid

    print(f"\nEVAL (n={n})")
    print(f"BLEU-4  : {bleu:.4f}")
    print(f"ROUGE-1 : P={r1.precision:.3f} R={r1.recall:.3f} F1={r1.fmeasure:.3f}")
    print(f"ROUGE-L : P={rL.precision:.3f} R={rL.recall:.3f} F1={rL.fmeasure:.3f}")

evaluate(2)   # quick smoke test on 2 samples; raise n (e.g. 150) for a real eval
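
# For reproducible sampling in evaluate(), seed Python's RNG first, e.g.:
#   random.seed(42); evaluate(150)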