# TIF_E41211115_lstm-quiz-gen.../NER_SRL/test_model.py
#
# 53 lines
# 1.4 KiB
# Python

import json
import numpy as np
import pickle
from keras.models import load_model
from keras.preprocessing.sequence import pad_sequences
def _load_pickle(path):
    """Deserialize and return the object stored in the pickle file at *path*."""
    # NOTE(security): unpickling can execute arbitrary code. Only load
    # artifacts that were produced by a trusted source.
    with open(path, "rb") as f:
        return pickle.load(f)


# Trained model loaded from disk; predict_sentence reads two outputs from
# model.predict (index 0 for NER, index 1 for SRL).
model = load_model("multi_task_bilstm_model.keras")

# Lookup tables: word -> index, and tag -> index for each task.
word2idx = _load_pickle("word2idx.pkl")
tag2idx_ner = _load_pickle("tag2idx_ner.pkl")
tag2idx_srl = _load_pickle("tag2idx_srl.pkl")

# Inverse mappings (index -> tag) used to turn predicted ids back into labels.
idx2tag_ner = {i: t for t, i in tag2idx_ner.items()}
idx2tag_srl = {i: t for t, i in tag2idx_srl.items()}

# Sequence length used when padding inputs (predict_sentence pads to 50).
# Previously bound to the name `max`, which shadowed the builtin max().
MAX_LEN = 50
def _print_tags(header, tokens, label_ids, idx2tag):
    """Print *header*, then one ``token<TAB>tag`` line per (token, id) pair."""
    print(header)
    for token, label_idx in zip(tokens, label_ids):
        print(f"{token}\t{idx2tag[int(label_idx)]}")


def predict_sentence(sentence, max_len=50):
    """Print the predicted NER and SRL tag for each token of *sentence*.

    The sentence is lower-cased and split on whitespace, each token is mapped
    to its vocabulary index (unknown words fall back to ``word2idx["UNK"]``),
    the index sequence is padded at the end with ``word2idx["PAD"]`` up to
    *max_len*, and the result is passed to the module-level ``model``.

    Args:
        sentence: Raw input text.
        max_len: Length the index sequence is padded to; tokens beyond this
            length are ignored. Defaults to 50, the value that was previously
            hard-coded. NOTE(review): presumably this must equal the sequence
            length the model was trained with -- confirm before changing it.

    Returns:
        None. The token list and the per-token tags are written to stdout.

    Raises:
        KeyError: If ``"UNK"`` or ``"PAD"`` is missing from ``word2idx``, or
            a predicted label id is missing from the index-to-tag mapping.
    """
    tokens = sentence.strip().lower().split()
    print(tokens)

    # pad_sequences truncates from the FRONT by default, so for an over-long
    # sentence the model would see the last tokens while the loops below pair
    # its predictions with the first tokens. Cut the tail here instead so that
    # token i always lines up with prediction i.
    kept = tokens[:max_len]

    # Tokens are already lower-cased above; no second .lower() is needed.
    unk_idx = word2idx["UNK"]
    encoded = [word2idx.get(w, unk_idx) for w in kept]
    x = pad_sequences(
        [encoded], maxlen=max_len, padding="post", value=word2idx["PAD"]
    )

    preds = model.predict(x)
    # argmax over the last axis picks the best label id per position; [0]
    # selects the single sentence in the batch (assumes each output is
    # (batch, seq, n_tags) -- TODO confirm against the model definition).
    pred_labels_ner = np.argmax(preds[0], axis=-1)[0]
    pred_labels_srl = np.argmax(preds[1], axis=-1)[0]

    # Slice off the predictions that belong to padding positions.
    n = len(kept)
    _print_tags("Hasil prediksi NER:", kept, pred_labels_ner[:n], idx2tag_ner)
    _print_tags("\nHasil prediksi SRL:", kept, pred_labels_srl[:n], idx2tag_srl)
if __name__ == "__main__":
    # Script entry point: run one demo prediction. A Ctrl+C during the run is
    # caught so the script ends with a short farewell instead of a traceback.
    try:
        predict_sentence("aku lahir di indonesia")
    except KeyboardInterrupt:
        print("\n\nSelesai.")