# TIF_E41211115_lstm-quiz-gen.../question_generation/question_predict.py

import json
import re

import numpy as np
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.preprocessing.text import tokenizer_from_json
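

# Expected layout of the tokenizer JSON file, inferred from the keys read in
# __init__ below (an assumption; the exact payloads depend on how the training
# script serialized its tokenizers, typically via Tokenizer.to_json()):
#
#   {
#       "word_tokenizer":   <tokenizer JSON>,
#       "ner_tokenizer":    <tokenizer JSON>,
#       "srl_tokenizer":    <tokenizer JSON>,
#       "q_type_tokenizer": <tokenizer JSON>,
#       "max_context_len": int, "max_answer_len": int,
#       "max_question_len": int, "max_token_len": int
#   }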
class QuestionPredictionModel:
    def __init__(self, model_path, tokenizer_path):
        """
        Initialize the question prediction model with a pre-trained model
        and its fitted tokenizers.
        """
        # Load the trained Keras model
        self.model = load_model(model_path)
        # Load the serialized tokenizers
        with open(tokenizer_path, 'r') as f:
            tokenizer_data = json.load(f)
        # Reconstruct the tokenizers
        self.word_tokenizer = tokenizer_from_json(tokenizer_data['word_tokenizer'])
        self.ner_tokenizer = tokenizer_from_json(tokenizer_data['ner_tokenizer'])
        self.srl_tokenizer = tokenizer_from_json(tokenizer_data['srl_tokenizer'])
        self.q_type_tokenizer = tokenizer_from_json(tokenizer_data['q_type_tokenizer'])
        # Maximum sequence lengths used at training time
        self.max_context_len = tokenizer_data['max_context_len']
        self.max_answer_len = tokenizer_data['max_answer_len']
        self.max_question_len = tokenizer_data['max_question_len']
        self.max_token_len = tokenizer_data['max_token_len']
        # Vocabulary sizes (+1 for the reserved padding index 0)
        self.vocab_size = len(self.word_tokenizer.word_index) + 1
        self.q_type_vocab_size = len(self.q_type_tokenizer.word_index) + 1

    def preprocess_text(self, text):
        """Basic text preprocessing: lowercase and collapse whitespace."""
        text = text.lower()
        text = re.sub(r"\s+", " ", text).strip()
        return text

    def predict_question(self, context, answer, tokens, ner, srl, q_type):
        """
        Predict a question from the given context, answer, tokens, NER tags,
        SRL tags, and question type.

        Args:
            context (str): The context text
            answer (str): The answer to generate a question for
            tokens (list): List of tokens
            ner (list): List of NER tags corresponding to the tokens
            srl (list): List of SRL tags corresponding to the tokens
            q_type (str): Question type ('isian', 'opsi', or 'true_false')

        Returns:
            str: The predicted question
        """
        # Preprocess inputs
        context = self.preprocess_text(context)
        answer = self.preprocess_text(answer)
        # Convert texts and tag lists to integer sequences
        context_seq = self.word_tokenizer.texts_to_sequences([context])[0]
        answer_seq = self.word_tokenizer.texts_to_sequences([answer])[0]
        tokens_seq = self.word_tokenizer.texts_to_sequences([" ".join(tokens)])[0]
        ner_seq = self.ner_tokenizer.texts_to_sequences([" ".join(ner)])[0]
        srl_seq = self.srl_tokenizer.texts_to_sequences([" ".join(srl)])[0]
        # Pad sequences to the lengths used at training time
        context_padded = pad_sequences([context_seq], maxlen=self.max_context_len, padding="post")
        answer_padded = pad_sequences([answer_seq], maxlen=self.max_answer_len, padding="post")
        tokens_padded = pad_sequences([tokens_seq], maxlen=self.max_token_len, padding="post")
        ner_padded = pad_sequences([ner_seq], maxlen=self.max_token_len, padding="post")
        srl_padded = pad_sequences([srl_seq], maxlen=self.max_token_len, padding="post")
        # One-hot encode the question type
        q_type_idx = self.q_type_tokenizer.word_index.get(q_type, 0)
        q_type_categorical = tf.keras.utils.to_categorical(
            [q_type_idx], num_classes=self.q_type_vocab_size
        )
        # Run the model on all six inputs
        predicted_seq = self.model.predict(
            [context_padded, answer_padded, tokens_padded, ner_padded, srl_padded, q_type_categorical]
        )
        # Greedy decoding: take the highest-probability token at each position
        predicted_indices = np.argmax(predicted_seq[0], axis=1)
        # Reverse word index for mapping indices back to words
        reverse_word_index = {v: k for k, v in self.word_tokenizer.word_index.items()}
        # Convert indices to words, skipping padding (index 0)
        predicted_words = []
        for idx in predicted_indices:
            if idx != 0:
                predicted_words.append(reverse_word_index.get(idx, ''))
        # Join the words into the question string
        predicted_question = ' '.join(predicted_words)
        # Ensure the fill-in blank marker is present, per the dataset convention
        if "___" not in predicted_question:
            predicted_question += " ___"
        return predicted_question
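
    # Minimal usage sketch for predict_question (hypothetical Indonesian
    # inputs; the actual NER/SRL tag sets depend on the taggers used during
    # training):
    #
    #   predictor.predict_question(
    #       context="ibu kota indonesia adalah jakarta",
    #       answer="jakarta",
    #       tokens=["ibu", "kota", "indonesia", "adalah", "jakarta"],
    #       ner=["O", "O", "LOC", "O", "LOC"],
    #       srl=["ARG1", "ARG1", "ARG1", "O", "ARG2"],
    #       q_type="isian",
    #   )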

    def batch_predict_questions(self, data):
        """
        Predict questions for a batch of data.

        Args:
            data (list): List of dictionaries with context, tokens, ner, srl,
                and (optionally) Q&A pairs

        Returns:
            list: List of prediction result dictionaries
        """
        results = []
        for item in data:
            context = item["context"]
            tokens = item["tokens"]
            ner = item["ner"]
            srl = item["srl"]
            if "qas" in item:
                # If there are Q&A pairs, use their answers and keep the
                # ground-truth questions for evaluation
                for qa in item["qas"]:
                    answer = qa["answer"]
                    q_type = qa["type"]
                    ground_truth = qa["question"]
                    predicted_question = self.predict_question(
                        context, answer, tokens, ner, srl, q_type
                    )
                    results.append({
                        "context": context,
                        "answer": answer,
                        "predicted_question": predicted_question,
                        "ground_truth": ground_truth,
                        "question_type": q_type
                    })
            else:
                # Without Q&A pairs, generate one question per question type.
                # For demo purposes a placeholder answer is used; in practice
                # you would extract candidate answers from the context.
                placeholders = {
                    "isian": "placeholder",
                    "true_false": "true",
                    "opsi": "placeholder"
                }
                for q_type in ["isian", "true_false", "opsi"]:
                    predicted_question = self.predict_question(
                        context, placeholders[q_type], tokens, ner, srl, q_type
                    )
                    results.append({
                        "context": context,
                        "predicted_question": predicted_question,
                        "question_type": q_type
                    })
        return results
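
    # The placeholder answers above are only for demonstration. A rough sketch
    # of harvesting candidate answers from the NER tags instead (hypothetical
    # helper, not part of the original pipeline; assumes "O" marks non-entity
    # tokens):
    #
    #   def candidate_answers(tokens, ner):
    #       return [tok for tok, tag in zip(tokens, ner) if tag != "O"]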


# Example usage
if __name__ == "__main__":
    # Load test data
    with open("data_converted.json", "r") as f:
        test_data = json.load(f)
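
    # Each record in data_converted.json is assumed to look like this
    # (inferred from the fields accessed below and in batch prediction):
    #
    #   {
    #       "context": str,
    #       "tokens": [str, ...],
    #       "ner": [str, ...],
    #       "srl": [str, ...],
    #       "qas": [{"question": str, "answer": str, "type": str}, ...]
    #   }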
    # Initialize the model
    question_predictor = QuestionPredictionModel(
        model_path="question_prediction_model_final.h5",
        tokenizer_path="question_prediction_tokenizers.json"
    )
    # Example single prediction
    sample = test_data[0]
    context = sample["context"]
    tokens = sample["tokens"]
    ner = sample["ner"]
    srl = sample["srl"]
    answer = sample["qas"][0]["answer"]
    q_type = sample["qas"][0]["type"]
    predicted_question = question_predictor.predict_question(
        context, answer, tokens, ner, srl, q_type
    )
    print(f"Context: {context}")
    print(f"Answer: {answer}")
    print(f"Question Type: {q_type}")
    print(f"Predicted Question: {predicted_question}")
    print(f"Ground Truth: {sample['qas'][0]['question']}")
    # Batch prediction on the first three records
    results = question_predictor.batch_predict_questions(test_data[:3])
    print("\nBatch Results:")
    for i, result in enumerate(results):
        print(f"\nResult {i+1}:")
        print(f"Context: {result['context']}")
        print(f"Answer: {result.get('answer', 'N/A')}")
        print(f"Question Type: {result['question_type']}")
        print(f"Predicted Question: {result['predicted_question']}")
        if 'ground_truth' in result:
            print(f"Ground Truth: {result['ground_truth']}")