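"""Inference and question-generation script for an Indonesian LSTM QA model.

Loads the fitted tokenizers and the trained Keras model, predicts a one-word
answer for a (context, question) pair, and generates fill-in-the-blank
("isian"), true/false ("true_false"), and multiple-choice ("opsi") questions
from token-level NER and SRL annotations.
"""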
import json
import random
import re

import numpy as np
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.preprocessing.text import tokenizer_from_json

# Load tokenizers and model configurations
with open("qa_tokenizers.json", "r") as f:
    tokenizer_data = json.load(f)
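
# qa_tokenizers.json is assumed (from the keys read below) to bundle five
# serialized Keras tokenizers plus the padding lengths used at training time:
# {"word_tokenizer": "...", "ner_tokenizer": "...", "srl_tokenizer": "...",
#  "answer_tokenizer": "...", "q_type_tokenizer": "...",
#  "max_context_len": ..., "max_question_len": ..., "max_token_len": ...}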
tokenizer = tokenizer_from_json(tokenizer_data["word_tokenizer"])
ner_tokenizer = tokenizer_from_json(tokenizer_data["ner_tokenizer"])
srl_tokenizer = tokenizer_from_json(tokenizer_data["srl_tokenizer"])
answer_tokenizer = tokenizer_from_json(tokenizer_data["answer_tokenizer"])
q_type_tokenizer = tokenizer_from_json(tokenizer_data["q_type_tokenizer"])

max_context_len = tokenizer_data["max_context_len"]
max_question_len = tokenizer_data["max_question_len"]
max_token_len = tokenizer_data["max_token_len"]
q_type_vocab_size = len(q_type_tokenizer.word_index) + 1  # +1: index 0 is reserved

# Load trained model
model = load_model("qa_lstm_model_final.h5")
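
# The network is assumed to take six inputs (padded context, question, token,
# NER-tag, and SRL-tag sequences plus a one-hot question-type vector) and to
# emit a probability distribution over the answer vocabulary; this matches the
# model.predict(...) call in predict_answer() below.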

def preprocess_text(text):
    """Lowercase text and collapse whitespace runs to single spaces."""
    text = text.lower()
    text = re.sub(r"\s+", " ", text).strip()
    return text
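
# e.g. preprocess_text("  Kerajaan   Majapahit\n") -> "kerajaan majapahit"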

def predict_answer(context, question, tokens, ner, srl, q_type):
    """Predict a single answer word for a (context, question) pair.

    tokens, ner, and srl are per-token annotation lists for the context;
    q_type is the question-type label (e.g. "isian").
    """
    # Encode the text and annotation inputs with their respective tokenizers
    context_seq = tokenizer.texts_to_sequences([preprocess_text(context)])
    question_seq = tokenizer.texts_to_sequences([preprocess_text(question)])
    token_seq = [tokenizer.texts_to_sequences([" ".join(tokens)])[0]]
    ner_seq = [ner_tokenizer.texts_to_sequences([" ".join(ner)])[0]]
    srl_seq = [srl_tokenizer.texts_to_sequences([" ".join(srl)])[0]]

    # One-hot encode the question type (index 0 for unseen labels)
    q_type_idx = q_type_tokenizer.word_index.get(q_type, 0)
    q_type_cat = tf.keras.utils.to_categorical([q_type_idx], num_classes=q_type_vocab_size)

    # Pad sequences to the lengths used at training time
    context_pad = pad_sequences(context_seq, maxlen=max_context_len, padding="post")
    question_pad = pad_sequences(question_seq, maxlen=max_question_len, padding="post")
    token_pad = pad_sequences(token_seq, maxlen=max_token_len, padding="post")
    ner_pad = pad_sequences(ner_seq, maxlen=max_token_len, padding="post")
    srl_pad = pad_sequences(srl_seq, maxlen=max_token_len, padding="post")

    # Predict and take the highest-probability answer index
    prediction = model.predict([context_pad, question_pad, token_pad, ner_pad, srl_pad, q_type_cat], verbose=0)
    answer_idx = np.argmax(prediction[0])

    # Map the index back to its answer word
    for word, idx in answer_tokenizer.word_index.items():
        if idx == answer_idx:
            return word

    return "Unknown"
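
# Sketch of a direct call with hypothetical hand-tagged inputs (real fields
# come from data_converted.json in the __main__ block below):
# predict_answer(
#     "Majapahit berdiri pada tahun 1293 di Trowulan",
#     "Kapan Majapahit berdiri",
#     ["majapahit", "berdiri", "pada", "tahun", "1293", "di", "trowulan"],
#     ["ORG", "O", "O", "O", "DATE", "O", "LOC"],
#     ["ARG0", "V", "O", "O", "ARGM-TMP", "O", "ARGM-LOC"],
#     "isian",
# )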

def generate_question_answer(context, tokens, ner, srl, question_type="isian"):
    """Generate a question/answer pair from NER and SRL annotations.

    question_type is one of "isian" (fill-in-the-blank), "true_false",
    or "opsi" (multiple choice).
    """
    entities = {}
    predicate = ""

    # Collect NER entities and SRL arguments; the SRL tag "V" marks the predicate
    for i, token in enumerate(tokens):
        if ner[i] != "O":
            entities.setdefault(ner[i], []).append(token)
        if srl[i] == "V":
            predicate = token
        elif srl[i].startswith("ARG"):
            entities.setdefault(srl[i], []).append(token)

    subject = " ".join(entities.get("ARG0", [""]))

    if question_type == "isian":
        if "LOC" in entities:
            location = " ".join(entities["LOC"])
            return f"Dimana {subject} {predicate} ___", location  # "Where ..."
        elif "DATE" in entities:
            date = " ".join(entities["DATE"])
            return f"Kapan {subject} {predicate} ___", date  # "When ..."

    elif question_type == "true_false":
        if "DATE" in entities:
            original_date = " ".join(entities["DATE"])
            try:
                # Shift the year by 1-5 so the generated statement is false
                modified_year = str(int(entities["DATE"][-1]) + random.randint(1, 5))
                modified_date = f"{entities['DATE'][0]} {entities['DATE'][1]} {modified_year}"
                return f"{subject} {predicate} pada {modified_date} ___", "false"
            except (ValueError, IndexError):
                # Date could not be parsed as "<day> <month> <year>"; keep the
                # original date, which makes the statement true
                return f"{subject} {predicate} pada {original_date} ___", "true"

    elif question_type == "opsi":
        if "LOC" in entities:
            correct_location = " ".join(entities["LOC"])
            distractors = ["singasari", "kuta", "banten", "kediri", "makassar"]
            distractors = [d for d in distractors if d != correct_location]
            options = random.sample(distractors, 3) + [correct_location]
            random.shuffle(options)
            return f"Dimana {subject} {predicate} ___", options, correct_location

    # Fallback when the requested type cannot be built from the annotations
    return "Apa yang terjadi dalam teks ini ___", context  # "What happens in this text"
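
# Note: "opsi" returns (question, options, correct_answer); the other types
# return (question, answer). The __main__ block below unpacks accordingly.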

# ✅ Example Usage with Random Sampling
if __name__ == "__main__":
    with open("data_converted.json", "r") as f:
        data = json.load(f)
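
    # data_converted.json is assumed (from the fields accessed below) to be a
    # list of items shaped like:
    # {"context": str, "tokens": [str], "ner": [str], "srl": [str],
    #  "qas": [{"question": str, "answer": str, "type": str}]}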

    # Randomly select an example for testing
    test_item = random.choice(data)
    test_qa = random.choice(test_item["qas"])

    predicted_answer = predict_answer(
        test_item["context"],
        test_qa["question"],
        test_item["tokens"],
        test_item["ner"],
        test_item["srl"],
        test_qa["type"],
    )

    print(f"Context: {test_item['context']}")
    print(f"Question: {test_qa['question']}")
    print(f"True Answer: {test_qa['answer']}")
    print(f"Predicted Answer: {predicted_answer}")

    # Generate a random question example
    example_context = test_item["context"]
    example_tokens = test_item["tokens"]
    example_ner = test_item["ner"]
    example_srl = test_item["srl"]

    random_question_type = random.choice(["isian", "true_false", "opsi"])

    result = generate_question_answer(
        example_context, example_tokens, example_ner, example_srl, random_question_type
    )

    print("\nGenerated Question Example:")
    print(f"Context: {example_context}")
    print(f"Question Type: {random_question_type}")

    # "opsi" yields a 3-tuple; the other types yield a (question, answer) pair
    if random_question_type == "opsi":
        question, options, correct_answer = result
        print(f"Generated Question: {question}")
        print(f"Options: {options}")
        print(f"Correct Answer: {correct_answer}")
    else:
        question, answer = result
        print(f"Generated Question: {question}")
        print(f"Answer: {answer}")