import pickle

import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing.sequence import pad_sequences


class QuizGeneratorService:
    def __init__(
        self,
        model_path="lstm_multi_output_model.keras",
        tokenizer_path="tokenizer.pkl",
        max_length=100,
    ):
        # Load the tokenizer
        with open(tokenizer_path, "rb") as handle:
            self.tokenizer = pickle.load(handle)

        # Load the trained model
        self.model = tf.keras.models.load_model(model_path)
        self.max_length = max_length

    def sequence_to_text(self, sequence):
        """
        Convert a sequence of indices to text using the tokenizer's index_word mapping.
        Skips any padding token (0).
        """
        return [
            self.tokenizer.index_word.get(idx, "<OOV>") for idx in sequence if idx != 0
        ]

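    # Illustration only (hypothetical indices; the real index_word mapping comes from the
    # fitted tokenizer): sequence_to_text([4, 9, 0, 0]) drops the padding zeros and returns
    # the words stored at indices 4 and 9, or "<OOV>" for any index the tokenizer never saw.
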
    def generate_quiz(self, context):
        """
        Given a raw context string, this method tokenizes the input, performs inference
        using the loaded model, and returns the generated question, answer, and question type.

        Parameters:
            context (str): The raw context text.

        Returns:
            dict: A dictionary containing the generated question, answer, and question type.
        """
        # Tokenize the context directly (assuming no extra preprocessing is needed)
        sequence = self.tokenizer.texts_to_sequences([context])
        print(sequence)  # Debug output: inspect the tokenized context
        padded_sequence = pad_sequences(
            sequence, maxlen=self.max_length, padding="post", truncating="post"
        )

        # Use the same padded sequence for both the context and the question decoder input.
        pred_question, pred_answer, pred_qtype = self.model.predict(
            [padded_sequence, padded_sequence]
        )

        # Convert predicted sequences to text (using argmax for each timestep)
        question_tokens = self.sequence_to_text(np.argmax(pred_question[0], axis=-1))
        answer_tokens = self.sequence_to_text(np.argmax(pred_answer[0], axis=-1))
        qtype = int(np.argmax(pred_qtype[0]))

        return {
            "generated_question": " ".join(question_tokens),
            "generated_answer": " ".join(answer_tokens),
            "question_type": qtype,  # You can map this integer to a descriptive label if needed
        }

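
# Sketch only: the model's "question_type" output is a class index, and the comment above
# suggests mapping it to a readable label. QTYPE_LABELS and its entries are hypothetical;
# the real label set depends on how question types were encoded during training.
QTYPE_LABELS = {0: "factual", 1: "definition", 2: "causal"}


def describe_question_type(qtype_index):
    """Return a readable label for a predicted question-type index (assumed mapping above)."""
    return QTYPE_LABELS.get(qtype_index, f"type_{qtype_index}")

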
# Example usage:
if __name__ == "__main__":
    quiz_service = QuizGeneratorService()
    # Indonesian context, roughly: "In 1619, Jan Pieterszoon Coen conquered Jayakarta and
    # renamed it Batavia, which became the center of VOC power in the archipelago."
    context_input = "Pada tahun 1619, Jan Pieterszoon Coen menaklukkan Jayakarta dan menggantinya dengan nama Batavia, yang menjadi pusat kekuasaan VOC di Nusantara."
    result = quiz_service.generate_quiz(context_input)
    print(result)