TIF_E41211115_lstm-quiz-gen.../NER_SRL/train.py

108 lines
3.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json, pickle
import numpy as np
from keras.models import Model
from keras.layers import Input, Embedding, Bidirectional, LSTM, TimeDistributed, Dense
from keras.preprocessing.sequence import pad_sequences
from keras.utils import to_categorical
from seqeval.metrics import classification_report
# ---------- 1. Load data ----------
# Each dataset item carries parallel "tokens", "labels_ner" and "labels_srl" lists.
with open("dataset/dataset_ner_srl.json", encoding="utf-8") as f:
    data = json.load(f)

sentences, labels_ner, labels_srl = [], [], []
for item in data:
    sentences.append([tok.lower() for tok in item["tokens"]])
    labels_ner.append(item["labels_ner"])
    labels_srl.append(item["labels_srl"])

# Sanity check: report any NER sequence containing a literal "V" tag
# (NOTE(review): presumably an SRL tag leaked into the NER labels — confirm).
for i, label_seq in enumerate(labels_ner):
    if "V" in label_seq:
        print(f"Label 'V' ditemukan di index {i}: {label_seq}")
# ---------- 2. Build vocab & label maps ----------
# Index 0 is reserved for padding, 1 for out-of-vocabulary tokens;
# real words get ids 2.. in sorted order.
vocab = sorted({token for sent in sentences for token in sent})
word2idx = {"PAD": 0, "UNK": 1}
word2idx.update((token, idx) for idx, token in enumerate(vocab, start=2))

ner_tags = sorted({tag for seq in labels_ner for tag in seq})
srl_tags = sorted({tag for seq in labels_srl for tag in seq})

tag2idx_ner = {tag: idx for idx, tag in enumerate(ner_tags)}
tag2idx_srl = {tag: idx for idx, tag in enumerate(srl_tags)}
idx2tag_ner = dict(enumerate(ner_tags))
idx2tag_srl = dict(enumerate(srl_tags))
# ---------- 3. Encode tokens & labels ----------
# Map tokens to ids (unknown words fall back to UNK) and tags to ids.
unk_id = word2idx["UNK"]
X = [[word2idx.get(tok, unk_id) for tok in sent] for sent in sentences]
y_ner = [[tag2idx_ner[tag] for tag in seq] for seq in labels_ner]
y_srl = [[tag2idx_srl[tag] for tag in seq] for seq in labels_srl]

# Pad everything to the longest sentence; label padding uses the "O" tag id.
maxlen = max(map(len, X))
X = pad_sequences(X, maxlen=maxlen, padding="post", value=word2idx["PAD"])
y_ner = pad_sequences(y_ner, maxlen=maxlen, padding="post", value=tag2idx_ner["O"])
y_srl = pad_sequences(y_srl, maxlen=maxlen, padding="post", value=tag2idx_srl["O"])

# One-hot encode the label ids and cast to ndarrays for Keras.
X = np.array(X)
y_ner = np.array([to_categorical(seq, num_classes=len(tag2idx_ner)) for seq in y_ner])
y_srl = np.array([to_categorical(seq, num_classes=len(tag2idx_srl)) for seq in y_srl])
# ---------- 4. Multitask BiLSTM architecture ----------
# One shared embedding + BiLSTM encoder feeding two per-token softmax heads
# (one for NER tags, one for SRL tags).
tokens_in = Input(shape=(maxlen,))
encoded = Bidirectional(LSTM(64, return_sequences=True))(
    Embedding(len(word2idx), 64)(tokens_in)
)

ner_output = TimeDistributed(
    Dense(len(tag2idx_ner), activation="softmax"), name="ner_output"
)(encoded)
srl_output = TimeDistributed(
    Dense(len(tag2idx_srl), activation="softmax"), name="srl_output"
)(encoded)

model = Model(inputs=tokens_in, outputs=[ner_output, srl_output])

# Both heads are trained jointly with categorical cross-entropy.
model.compile(
    optimizer="adam",
    loss={
        "ner_output": "categorical_crossentropy",
        "srl_output": "categorical_crossentropy",
    },
    metrics={"ner_output": "accuracy", "srl_output": "accuracy"},
)
model.summary()
# ---------- 5. Training ----------
# Fit on the full dataset (no validation split); targets keyed by head name.
targets = {"ner_output": y_ner, "srl_output": y_srl}
model.fit(X, targets, batch_size=2, epochs=10, verbose=1)
# ---------- 6. Persist artifacts ----------
# Save the model plus the three lookup tables needed to reproduce the encoding
# at inference time.
model.save("NER_SRL/multi_task_bilstm_model.keras")
for path, mapping in (
    ("NER_SRL/word2idx.pkl", word2idx),
    ("NER_SRL/tag2idx_ner.pkl", tag2idx_ner),
    ("NER_SRL/tag2idx_srl.pkl", tag2idx_srl),
):
    with open(path, "wb") as f:
        pickle.dump(mapping, f)
# ---------- 7. Evaluasi ----------
# Predict on the training set itself (no held-out split exists in this script),
# yielding one (n_samples, maxlen, n_tags) probability array per head.
y_pred_ner, y_pred_srl = model.predict(X, verbose=0)
def decode(pred, true, idx2tag):
    """Map one-hot / probability sequences back to tag-string sequences.

    Args:
        pred: per-token class scores, shape (batch, seq, n_tags)-like.
        true: one-hot gold labels with the same layout.
        idx2tag: index -> tag-string lookup.

    Returns:
        (true_tags, pred_tags) — note the order is reversed relative to the
        argument order; each is a list of per-sentence tag lists.
    """
    def to_tags(batch):
        return [[idx2tag[int(i)] for i in np.argmax(seq, axis=-1)] for seq in batch]

    return to_tags(true), to_tags(pred)
# seqeval expects list-of-list tag sequences. Padding positions (tagged "O")
# are included here, which inflates the reported scores — NOTE(review):
# consider masking pads before scoring.
gold_ner, hyp_ner = decode(y_pred_ner, y_ner, idx2tag_ner)
gold_srl, hyp_srl = decode(y_pred_srl, y_srl, idx2tag_srl)

print("\n📊 [NER] Classification Report:")
print(classification_report(gold_ner, hyp_ner))
print("\n📊 [SRL] Classification Report:")
print(classification_report(gold_srl, hyp_srl))