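"""Inference helper for the QC LSTM question-generation model.

Loads the fitted tokenizers and the trained model from the QC/ directory,
encodes a tagged input sentence, and generates a question, an answer, and a
question type.
"""
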
import pickle

import numpy as np
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.sequence import pad_sequences


def infer_from_input(input_data, maxlen=50):
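    """Generate a question, an answer, and a question type from tagged input.

    ``input_data`` is expected to hold parallel "tokens", "ner", and "srl"
    lists. They are encoded with the saved tokenizers, padded to ``maxlen``,
    fed to the trained model, and the predictions are decoded back to words.
    """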
with open("QC/tokenizers.pkl", "rb") as f:
|
|
tokenizers = pickle.load(f)
|
|
|
|
model = load_model("QC/new_model_lstm_qg.keras")
|
|
|
|
tok_token = tokenizers["token"]
|
|
tok_ner = tokenizers["ner"]
|
|
tok_srl = tokenizers["srl"]
|
|
tok_q = tokenizers["question"]
|
|
tok_a = tokenizers["answer"]
|
|
tok_type = tokenizers["type"]
|
|
|
|
    # Prepare input
    tokens = input_data["tokens"]
    ner = input_data["ner"]
    srl = input_data["srl"]

    x_tok = pad_sequences(
        [tok_token.texts_to_sequences([tokens])[0]], maxlen=maxlen, padding="post"
    )
    x_ner = pad_sequences(
        [tok_ner.texts_to_sequences([ner])[0]], maxlen=maxlen, padding="post"
    )
    x_srl = pad_sequences(
        [tok_srl.texts_to_sequences([srl])[0]], maxlen=maxlen, padding="post"
    )

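    # After padding, each of x_tok, x_ner, and x_srl has shape (1, maxlen).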
    # Predict
    pred_q, pred_a, pred_type = model.predict([x_tok, x_ner, x_srl])
    pred_q_ids = np.argmax(pred_q[0], axis=-1)
    pred_a_ids = np.argmax(pred_a[0], axis=-1)
    pred_type_id = np.argmax(pred_type[0])

    # Decode: map predicted indices back to words, skipping padding (index 0)
    index2word_q = {v: k for k, v in tok_q.word_index.items()}
    index2word_a = {v: k for k, v in tok_a.word_index.items()}
    index2word_q[0] = "<PAD>"
    index2word_a[0] = "<PAD>"

    decoded_q = [index2word_q[i] for i in pred_q_ids if i != 0]
    decoded_a = [index2word_a[i] for i in pred_a_ids if i != 0]

    # Type indices appear to be shifted by one relative to the tokenizer's
    # word_index (the classifier output has no padding class)
    index2type = {v - 1: k for k, v in tok_type.word_index.items()}
    decoded_type = index2type.get(pred_type_id, "unknown")

    return {
        "question": " ".join(decoded_q),
        "answer": " ".join(decoded_a),
        "type": decoded_type,
    }


if __name__ == "__main__":
|
|
|
|
# Example input
|
|
input_data = {
|
|
"tokens": ["Nama", "lengkap", "saya", "adalah", "Bayu", "Prabowo", "."],
|
|
"ner": ["O", "O", "O", "O", "B-PER", "I-PER", "O"],
|
|
"srl": ["ARG1", "ARG1", "ARG2", "V", "ARG0", "ARG0", "O"],
|
|
}
|
|
|
|
    # Alternative example: "Proklamasi Kemerdekaan Indonesia diproklamasikan
    # pada 17 Agustus 1945 di Jakarta."
    # ("The Proclamation of Indonesian Independence was declared on
    # 17 August 1945 in Jakarta.")
    # input_data = {
    #     "tokens": ["Proklamasi", "Kemerdekaan", "Indonesia", "diproklamasikan",
    #                "pada", "17", "Agustus", "1945", "di", "Jakarta", "."],
    #     "ner": ["B-EVENT", "I-EVENT", "I-EVENT", "O", "O", "B-DATE", "I-DATE",
    #             "I-DATE", "O", "B-LOC", "O"],
    #     "srl": ["ARG1", "ARG1", "ARG1", "V", "O", "ARGM-TMP", "ARGM-TMP",
    #             "ARGM-TMP", "O", "ARGM-LOC", "O"],
    # }

    # Predict and print the {"question", "answer", "type"} dict
    result = infer_from_input(input_data)
    print(result)