TIF_E41211115_lstm-quiz-gen.../QC/test_model_qc.py

59 lines
1.5 KiB
Python

# Fixed sequence length handed to encode() for the context, NER and SRL inputs.
MAX_CTX_LEN: int = 50
# -- dummy placeholders for your NER/SRL models ------------------------------
def predict_ner(tokens):  # replace with your own implementation
    """Placeholder NER tagger: assign the tag "O" to every token.

    Stand-in for a real NER model. Returns a new list holding one "O" per
    element of ``tokens`` (presumably the BIO "outside" label).
    """
    n_tokens = len(tokens)
    return ["O" for _ in range(n_tokens)]
def predict_srl(tokens):  # replace with your own implementation
    """Placeholder SRL tagger: assign the tag "O" to every token.

    Stand-in for a real semantic-role-labelling model. Returns a new list
    with exactly ``len(tokens)`` entries, all equal to "O".
    """
    token_count = len(tokens)
    return token_count * ["O"]
# ------------------------------------------------------------------------------
def greedy_decode(context_tokens):
    """Generate one question from ``context_tokens`` by greedy decoding.

    The context is tagged with ``predict_ner`` / ``predict_srl``. The tokens
    and both tag sequences are encoded with ``encode(..., MAX_CTX_LEN)``.
    The model is then queried step by step, always taking the arg-max token.

    Args:
        context_tokens: sequence of context word tokens.

    Returns:
        The generated question as a space-joined string. It excludes the
        leading ``<bos>`` and stops before the first ``<eos>`` or ``<pad>``
        prediction, or after ``MAX_Q_LEN - 1`` generated steps.
    """
    # Tagging
    ner_tags = predict_ner(context_tokens)
    srl_tags = predict_srl(context_tokens)

    # Encode encoder inputs; ``[None]`` prepends a batch axis of size 1
    # (assumes encode() returns a NumPy array -- confirm).
    ctx_ids = encode(context_tokens, w2i_ctx, MAX_CTX_LEN)[None]
    ner_ids = encode(ner_tags, t2i_ner, MAX_CTX_LEN)[None]
    srl_ids = encode(srl_tags, t2i_srl, MAX_CTX_LEN)[None]

    pad_id = w2i_q["<pad>"]  # loop-invariant lookup
    dec_seq = [w2i_q["<bos>"]]
    for _ in range(MAX_Q_LEN - 1):
        # The decoder input always has length MAX_Q_LEN: tokens so far,
        # right-padded with <pad>.
        dec_pad = dec_seq + [pad_id] * (MAX_Q_LEN - len(dec_seq))
        pred = model.predict(
            [ctx_ids, ner_ids, srl_ids, np.array([dec_pad])], verbose=0
        )
        # Take the prediction at the last filled decoder position as the
        # next token.
        next_id = int(pred[0, len(dec_seq) - 1].argmax())
        # Stop on <eos>. Also stop on <pad>: otherwise it would be fed back
        # as a real decoder token and leak into the returned question text.
        if i2w_q[next_id] in ("<eos>", "<pad>"):
            break
        dec_seq.append(next_id)

    # Drop the leading <bos> before converting ids back to words.
    return " ".join(i2w_q[t] for t in dec_seq[1:])
if __name__ == "__main__":
    # Smoke run: decode one question from a hand-tokenised sample sentence.
    demo_tokens = [
        "Keberagaman", "potensi", "sumber", "daya",
        "alam", "Indonesia", "tidak", "lepas",
        "dari", "proses", "geografis", ".",
    ]
    print("\n[CTX]", " ".join(demo_tokens))
    generated = greedy_decode(demo_tokens)
    print("[Q] ", generated)