"""Inference script for the question-generation model.

Given one sentence annotated with tokens, NER tags and SRL labels, predicts a
question, an answer and a question type using the pickled tokenizers and the
saved Keras model.
"""

import pickle
from functools import lru_cache

from tensorflow.keras.preprocessing.sequence import pad_sequences
from keras.models import load_model
import numpy as np

DEFAULT_TOKENIZERS_PATH = "tokenizers.pkl"
DEFAULT_MODEL_PATH = "new_model_lstm_qg.keras"


@lru_cache(maxsize=4)
def _load_artifacts(tokenizers_path, model_path):
    """Load the pickled tokenizer dict and the Keras model for a path pair.

    Cached so repeated predictions do not re-read the pickle and rebuild the
    model on every call.

    Returns:
        ``(tokenizers, model)``.
    """
    # SECURITY NOTE: pickle.load can execute arbitrary code embedded in the
    # file; only point this at a tokenizers file from a trusted source.
    with open(tokenizers_path, "rb") as f:
        tokenizers = pickle.load(f)
    model = load_model(model_path)
    return tokenizers, model


def _encode(tokenizer, items, maxlen):
    """Convert one list of tokens/tags to ids and post-pad it to ``maxlen``.

    Returns a 2-D array holding a single padded sequence.
    """
    ids = tokenizer.texts_to_sequences([items])[0]
    # NOTE(review): pad_sequences truncates from the front ("pre") by default,
    # so inputs longer than maxlen lose their FIRST items even though padding
    # is "post" - confirm this matches the training-time preprocessing.
    return pad_sequences([ids], maxlen=maxlen, padding="post")


def _decode(tokenizer, ids):
    """Map predicted ids back to words and join them with single spaces.

    Id 0 (padding) is dropped. Ids that have no entry in the tokenizer's
    ``word_index`` are skipped instead of raising ``KeyError``.
    """
    index2word = {index: word for word, index in tokenizer.word_index.items()}
    words = [index2word.get(int(i)) for i in ids if i != 0]
    return " ".join(w for w in words if w is not None)


def infer_from_input(
    input_data,
    maxlen=50,
    tokenizers_path=DEFAULT_TOKENIZERS_PATH,
    model_path=DEFAULT_MODEL_PATH,
):
    """Predict a question, its answer and the question type for one sentence.

    Args:
        input_data: dict with lists under ``"tokens"``, ``"ner"`` and
            ``"srl"``; presumably one NER tag and one SRL label per token
            (as in the example below) - the lengths are not checked here.
        maxlen: length every input sequence is padded/truncated to;
            presumably must match the length the model was trained with.
        tokenizers_path: pickle holding a dict of fitted tokenizers keyed by
            ``"token"``, ``"ner"``, ``"srl"``, ``"question"``, ``"answer"``
            and ``"type"``.
        model_path: saved Keras model taking the token, NER and SRL id
            arrays (in that order) and returning question, answer and type
            predictions (in that order).

    Returns:
        dict with ``"question"`` and ``"answer"`` (space-joined decoded
        words) and ``"type"`` (the predicted label, or ``"unknown"`` if the
        predicted class id has no label).

    Raises:
        KeyError: if ``input_data`` or the tokenizer dict lacks a required key.
    """
    tokenizers, model = _load_artifacts(tokenizers_path, model_path)

    x_tok = _encode(tokenizers["token"], input_data["tokens"], maxlen)
    x_ner = _encode(tokenizers["ner"], input_data["ner"], maxlen)
    x_srl = _encode(tokenizers["srl"], input_data["srl"], maxlen)

    pred_q, pred_a, pred_type = model.predict([x_tok, x_ner, x_srl])

    # Per-position argmax for the two sequence outputs (assumes pred_q[0] and
    # pred_a[0] are (positions, vocab) - TODO confirm); one argmax for type.
    pred_q_ids = np.argmax(pred_q[0], axis=-1)
    pred_a_ids = np.argmax(pred_a[0], axis=-1)
    pred_type_id = int(np.argmax(pred_type[0]))

    # Class ids are mapped as word_index - 1, presumably because the type
    # labels were encoded from 0 during training while word_index starts at 1.
    index2type = {
        index - 1: label for label, index in tokenizers["type"].word_index.items()
    }
    decoded_type = index2type.get(pred_type_id, "unknown")

    return {
        "question": _decode(tokenizers["question"], pred_q_ids),
        "answer": _decode(tokenizers["answer"], pred_a_ids),
        "type": decoded_type,
    }


if __name__ == "__main__":
    # Example input: "Mars is also called the red planet because its surface
    # contains a lot of iron."
    input_data = {
        "tokens": [
            "Mars",
            "disebut",
            "juga",
            "sebagai",
            "planet",
            "merah",
            "karena",
            "permukaannya",
            "banyak",
            "mengandung",
            "zat",
            "besi",
            ".",
        ],
        "ner": ["B-LOC", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O"],
        "srl": [
            "ARG0",
            "V",
            "O",
            "O",
            "ARG1",
            "ARG1",
            "ARGM-CAU",
            "ARG1",
            "ARGM-MNR",
            "ARGM-MNR",
            "ARG1",
            "ARG1",
            "O",
        ],
    }

    # Alternative example: "The Proclamation of Indonesian Independence was
    # proclaimed on 17 August 1945 in Jakarta." Pass it to infer_from_input
    # instead of input_data to try it.
    alt_input_data = {
        "tokens": [
            "Proklamasi",
            "Kemerdekaan",
            "Indonesia",
            "diproklamasikan",
            "pada",
            "17",
            "Agustus",
            "1945",
            "di",
            "Jakarta",
            ".",
        ],
        "ner": [
            "B-EVENT",
            "I-EVENT",
            "I-EVENT",
            "O",
            "O",
            "B-DATE",
            "I-DATE",
            "I-DATE",
            "O",
            "B-LOC",
            "O",
        ],
        "srl": [
            "ARG1",
            "ARG1",
            "ARG1",
            "V",
            "O",
            "ARGM-TMP",
            "ARGM-TMP",
            "ARGM-TMP",
            "O",
            "ARGM-LOC",
            "O",
        ],
    }

    # Predict
    result = infer_from_input(input_data)
    print(result)