# TIF_E41211115_lstm-quiz-gen.../question_generation/answer_predict.py
# (file viewer metadata: 162 lines, 5.8 KiB, Python)
import json
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.preprocessing.text import tokenizer_from_json
import re
import random
# --- Load tokenizers and model configuration saved at training time ---
with open("qa_tokenizers.json", "r", encoding="utf-8") as f:
    tokenizer_data = json.load(f)

# Rebuild each Keras tokenizer from its serialized JSON form.
tokenizer = tokenizer_from_json(tokenizer_data["word_tokenizer"])
ner_tokenizer = tokenizer_from_json(tokenizer_data["ner_tokenizer"])
srl_tokenizer = tokenizer_from_json(tokenizer_data["srl_tokenizer"])
answer_tokenizer = tokenizer_from_json(tokenizer_data["answer_tokenizer"])
q_type_tokenizer = tokenizer_from_json(tokenizer_data["q_type_tokenizer"])

# Padding lengths used during training; inference must pad to the same sizes.
max_context_len = tokenizer_data["max_context_len"]
max_question_len = tokenizer_data["max_question_len"]
max_token_len = tokenizer_data["max_token_len"]

# +1 accounts for the reserved padding index 0.
q_type_vocab_size = len(q_type_tokenizer.word_index) + 1

# Load the trained LSTM question-answering model.
model = load_model("qa_lstm_model_final.keras")
def preprocess_text(text):
    """Lowercase *text* and collapse every whitespace run into one space."""
    collapsed = re.sub(r"\s+", " ", text.lower())
    return collapsed.strip()
def predict_answer(context, question, tokens, ner, srl, q_type):
    """Predict the answer word for a (context, question) pair with the model.

    Args:
        context: Raw context paragraph (string).
        question: Raw question text (string).
        tokens: Token list for the context.
        ner: NER tag per token (parallel to ``tokens``).
        srl: SRL tag per token (parallel to ``tokens``).
        q_type: Question-type label; unseen labels map to index 0.

    Returns:
        The predicted answer word, or ``"Unknown"`` when the argmax index is
        not in the answer vocabulary (e.g. the padding index 0).
    """
    # Encode text inputs with the tokenizers fitted at training time.
    context_seq = tokenizer.texts_to_sequences([preprocess_text(context)])
    question_seq = tokenizer.texts_to_sequences([preprocess_text(question)])
    token_seq = [tokenizer.texts_to_sequences([" ".join(tokens)])[0]]
    ner_seq = [ner_tokenizer.texts_to_sequences([" ".join(ner)])[0]]
    srl_seq = [srl_tokenizer.texts_to_sequences([" ".join(srl)])[0]]

    # One-hot encode the question type (index 0 when the label is unseen).
    q_type_idx = q_type_tokenizer.word_index.get(q_type, 0)
    q_type_cat = tf.keras.utils.to_categorical(
        [q_type_idx], num_classes=q_type_vocab_size
    )

    # Pad to the exact lengths the model was trained with.
    context_pad = pad_sequences(context_seq, maxlen=max_context_len, padding="post")
    question_pad = pad_sequences(question_seq, maxlen=max_question_len, padding="post")
    token_pad = pad_sequences(token_seq, maxlen=max_token_len, padding="post")
    ner_pad = pad_sequences(ner_seq, maxlen=max_token_len, padding="post")
    srl_pad = pad_sequences(srl_seq, maxlen=max_token_len, padding="post")

    prediction = model.predict(
        [context_pad, question_pad, token_pad, ner_pad, srl_pad, q_type_cat],
        verbose=0,
    )
    answer_idx = int(np.argmax(prediction[0]))

    # O(1) reverse lookup via the tokenizer's index->word map, instead of the
    # original O(vocab) linear scan over word_index.items().
    return answer_tokenizer.index_word.get(answer_idx, "Unknown")
def generate_question_answer(context, tokens, ner, srl, question_type="isian"):
    """Generate a (question, answer) pair from annotated context tokens.

    Args:
        context: Original context text (used for the fallback question).
        tokens: Token list.
        ner: NER tag per token; ``"O"`` means no entity.
        srl: SRL tag per token; ``"V"`` marks the predicate, ``"ARG*"`` roles
            are collected as entities.
        question_type: ``"isian"`` (fill-in), ``"true_false"``, or ``"opsi"``
            (multiple choice).

    Returns:
        ``(question, answer)`` for "isian"/"true_false",
        ``(question, options, correct_answer)`` for "opsi", or the generic
        fallback ``(question, context)`` when no suitable entity is present.
    """
    # Group tokens by NER label and SRL role; remember the predicate verb.
    entities = {}
    predicate = ""
    for i, token in enumerate(tokens):
        if ner[i] != "O":
            entities.setdefault(ner[i], []).append(token)
        if srl[i] == "V":
            predicate = token
        elif srl[i].startswith("ARG"):
            entities.setdefault(srl[i], []).append(token)

    subject = " ".join(entities.get("ARG0", [""]))

    if question_type == "isian":
        if "LOC" in entities:
            location = " ".join(entities["LOC"])
            return f"Dimana {subject} {predicate} ___", location
        elif "DATE" in entities:
            date = " ".join(entities["DATE"])
            return f"Kapan {subject} {predicate} ___", date
    elif question_type == "true_false":
        if "DATE" in entities:
            original_date = " ".join(entities["DATE"])
            try:
                # Shift the year (last DATE token) so the statement is false.
                modified_year = str(int(entities["DATE"][-1]) + random.randint(1, 5))
                modified_date = (
                    f"{entities['DATE'][0]} {entities['DATE'][1]} {modified_year}"
                )
            except (ValueError, IndexError):
                # Date not in "<day> <month> <year>" shape or year not numeric.
                modified_date = original_date
            # Bug fix: if the date could not be altered the statement still
            # matches the text, so the correct label is "true", not "false".
            label = "false" if modified_date != original_date else "true"
            return f"{subject} {predicate} pada {modified_date} ___", label
    elif question_type == "opsi":
        if "LOC" in entities:
            correct_location = " ".join(entities["LOC"])
            distractors = ["singasari", "kuta", "banten", "kediri", "makassar"]
            distractors = [d for d in distractors if d != correct_location]
            options = random.sample(distractors, 3) + [correct_location]
            random.shuffle(options)
            return f"Dimana {subject} {predicate} ___", options, correct_location

    # Fallback when no entity matches the requested question type.
    return "Apa yang terjadi dalam teks ini ___", context
# Example usage: evaluate on a randomly sampled item from the training set.
if __name__ == "__main__":
    with open("../dataset/stable_qg_qa_train_dataset.json", "r") as f:
        data = json.load(f)

    # Randomly select a (context, QA) pair for testing.
    test_item = random.choice(data)
    test_qa = random.choice(test_item["qas"])

    predicted_answer = predict_answer(
        test_item["context"],
        test_qa["question"],
        test_item["tokens"],
        test_item["ner"],
        test_item["srl"],
        test_qa["type"],
    )
    print(f"Context: {test_item['context']}")
    print(f"Question: {test_qa['question']}")
    print(f"True Answer: {test_qa['answer']}")
    print(f"Predicted Answer: {predicted_answer}")

    # Generate a random question example from the same item.
    example_context = test_item["context"]
    example_tokens = test_item["tokens"]
    example_ner = test_item["ner"]
    example_srl = test_item["srl"]
    random_question_type = random.choice(["isian", "true_false", "opsi"])
    result = generate_question_answer(
        example_context, example_tokens, example_ner, example_srl, random_question_type
    )
    print("\nGenerated Question Example:")
    print(f"Context: {example_context}")
    print(f"Question Type: {random_question_type}")
    # Bug fix: "opsi" falls back to a 2-tuple when the context has no LOC
    # entity, so guard on the tuple length before unpacking three values.
    if random_question_type == "opsi" and len(result) == 3:
        question, options, correct_answer = result
        print(f"Generated Question: {question}")
        print(f"Options: {options}")
        print(f"Correct Answer: {correct_answer}")
    else:
        question, answer = result
        print(f"Generated Question: {question}")
        print(f"Answer: {answer}")