59 lines
1.5 KiB
Python
59 lines
1.5 KiB
Python
# Length argument passed to encode() for the context, NER-tag and SRL-tag
# sequences in greedy_decode().
MAX_CTX_LEN = 50
|
|
|
|
|
|
# -- dummy placeholders for your NER/SRL models --------------------------------
|
|
def predict_ner(tokens):  # replace with your own implementation
    """Placeholder NER tagger: return one "O" tag per input token.

    Args:
        tokens: Sized sequence of context tokens; only its length is used.

    Returns:
        A list of "O" strings with the same length as ``tokens``.
    """
    return ["O"] * len(tokens)
|
|
|
|
|
|
def predict_srl(tokens):  # replace with your own implementation
    """Placeholder SRL tagger: return one "O" tag per input token.

    Args:
        tokens: Sized sequence of context tokens; only its length is used.

    Returns:
        A list of "O" strings with the same length as ``tokens``.
    """
    return ["O"] * len(tokens)
|
|
|
|
|
|
# ------------------------------------------------------------------------------
|
|
|
|
|
|
def greedy_decode(context_tokens):
    """Generate a single question for the given context (greedy decoding).

    The context is tagged with ``predict_ner`` and ``predict_srl``; the
    tokens and both tag sequences are passed through ``encode`` with
    ``MAX_CTX_LEN``. The decoder input starts as ``[<bos>]`` and grows one
    token per step: the arg-max of the model output at the last filled
    decoder position is appended, until ``<eos>`` is predicted or the
    sequence holds ``MAX_Q_LEN`` ids (``<bos>`` included).

    Relies on names that are not defined in this block and must exist at
    module level: ``encode``, ``model``, ``np``, ``w2i_ctx``, ``t2i_ner``,
    ``t2i_srl``, ``w2i_q``, ``i2w_q`` and ``MAX_Q_LEN``.

    Args:
        context_tokens: Sequence of context tokens.

    Returns:
        The generated question tokens joined by single spaces; ``<bos>``
        and ``<eos>`` are not part of the result.
    """
    # 6.1 Tagging
    ner_tags = predict_ner(context_tokens)
    srl_tags = predict_srl(context_tokens)

    # 6.2 Encode
    # NOTE(review): `[None]` presumably prepends a batch axis; this assumes
    # encode() returns a NumPy array -- confirm against its definition.
    ctx_ids = encode(context_tokens, w2i_ctx, MAX_CTX_LEN)[None]
    ner_ids = encode(ner_tags, t2i_ner, MAX_CTX_LEN)[None]
    srl_ids = encode(srl_tags, t2i_srl, MAX_CTX_LEN)[None]

    dec_seq = [w2i_q["<bos>"]]
    pad_id = w2i_q["<pad>"]  # loop-invariant, so look it up only once
    for _ in range(MAX_Q_LEN - 1):
        # Right-pad the partial question so the decoder input always has
        # exactly MAX_Q_LEN ids.
        dec_pad = dec_seq + [pad_id] * (MAX_Q_LEN - len(dec_seq))
        pred = model.predict(
            [ctx_ids, ner_ids, srl_ids, np.array([dec_pad])], verbose=0
        )
        # Read the prediction at the position of the last real (non-pad)
        # decoder token.
        next_id = int(pred[0, len(dec_seq) - 1].argmax())
        if i2w_q[next_id] == "<eos>":
            break
        dec_seq.append(next_id)

    # Skip the leading <bos> id when converting ids back to tokens.
    return " ".join(i2w_q[t] for t in dec_seq[1:])
|
|
|
|
|
|
if __name__ == "__main__":
    # Demo run: decode one question for a fixed, pre-tokenized sample
    # sentence and print the context followed by the generated question.
    sample = [
        "Keberagaman",
        "potensi",
        "sumber",
        "daya",
        "alam",
        "Indonesia",
        "tidak",
        "lepas",
        "dari",
        "proses",
        "geografis",
        ".",
    ]
    print("\n[CTX]", " ".join(sample))
    print("[Q] ", greedy_decode(sample))
|