TIF_E41211115_lstm-quiz-gen.../old/testing.py

71 lines
2.6 KiB
Python

import tensorflow as tf
import pickle
import numpy as np
from tensorflow.keras.preprocessing.sequence import pad_sequences
class QuizGeneratorService:
    """Inference wrapper around a saved multi-output Keras model and its tokenizer.

    The constructor loads both artifacts from disk. ``generate_quiz`` then turns
    a raw context string into a decoded question, a decoded answer and an
    integer question-type id.
    """

    def __init__(
        self,
        model_path="lstm_multi_output_model.keras",
        tokenizer_path="tokenizer.pkl",
        max_length=100,
    ):
        """Load the tokenizer and the trained model.

        Parameters:
            model_path (str): Path handed to ``tf.keras.models.load_model``.
            tokenizer_path (str): Path of the pickled tokenizer object.
            max_length (int): Length every input sequence is padded or
                truncated to before inference.
        """
        # SECURITY: unpickling executes arbitrary code embedded in the file.
        # Only point tokenizer_path at files produced by a trusted pipeline.
        with open(tokenizer_path, "rb") as handle:
            self.tokenizer = pickle.load(handle)
        # Load the trained model
        self.model = tf.keras.models.load_model(model_path)
        self.max_length = max_length

    def sequence_to_text(self, sequence) -> list:
        """Map a sequence of token indices to words.

        Uses the tokenizer's ``index_word`` mapping. Index 0 (padding) is
        skipped; any other index missing from the mapping becomes ``"<OOV>"``.

        Parameters:
            sequence: Iterable of integer token indices.
        Returns:
            list: The decoded words, in order.
        """
        index_word = self.tokenizer.index_word
        return [index_word.get(idx, "<OOV>") for idx in sequence if idx != 0]

    def generate_quiz(self, context) -> dict:
        """Generate a question, answer and question type for a context.

        The context is tokenized, padded/truncated (both "post") to
        ``self.max_length`` and fed to the model; each output is decoded by
        taking the argmax along its last axis.

        Parameters:
            context (str): The raw context text.
        Returns:
            dict: ``generated_question`` and ``generated_answer`` as
            space-joined words, and ``question_type`` as an int.
        """
        # Tokenize the context directly (assuming no extra preprocessing is needed)
        token_ids = self.tokenizer.texts_to_sequences([context])
        padded_sequence = pad_sequences(
            token_ids, maxlen=self.max_length, padding="post", truncating="post"
        )
        # The same padded sequence is used for both model inputs (context and
        # question decoder input).
        # NOTE(review): presumably matches how the model was exported - confirm.
        # verbose=0 keeps Keras from drawing a progress bar on every request.
        pred_question, pred_answer, pred_qtype = self.model.predict(
            [padded_sequence, padded_sequence], verbose=0
        )
        # Index [0] selects the single sample in the batch; argmax over the
        # last axis picks the highest-scoring entry at each position.
        question_tokens = self.sequence_to_text(np.argmax(pred_question[0], axis=-1))
        answer_tokens = self.sequence_to_text(np.argmax(pred_answer[0], axis=-1))
        qtype = int(np.argmax(pred_qtype[0]))
        return {
            "generated_question": " ".join(question_tokens),
            "generated_answer": " ".join(answer_tokens),
            # Raw class index; callers may map it to a descriptive label.
            "question_type": qtype,
        }
# Demo: run this module directly to generate one quiz item from a sample context.
if __name__ == "__main__":
    # Sample Indonesian history passage used as the model input.
    sample_context = "Pada tahun 1619, Jan Pieterszoon Coen menaklukkan Jayakarta dan menggantinya dengan nama Batavia, yang menjadi pusat kekuasaan VOC di Nusantara."
    # Build the service with its default artifact paths and print the result dict.
    print(QuizGeneratorService().generate_quiz(sample_context))