# Paste metadata: 108 lines, 3.8 KiB, Python
import json
import os
import pickle

import numpy as np
from keras.models import Model
from keras.layers import Input, Embedding, Bidirectional, LSTM, TimeDistributed, Dense
from keras.preprocessing.sequence import pad_sequences
from keras.utils import to_categorical
from seqeval.metrics import classification_report

# ---------- 1. Load data ----------
with open("dataset/dataset_ner_srl.json", encoding="utf-8") as f:
    data = json.load(f)

# Fail early with a clear message; an empty dataset would otherwise crash much
# later in max() when computing maxlen.
if not data:
    raise ValueError("dataset/dataset_ner_srl.json contains no examples")

# Tokens are lower-cased so the vocabulary is case-insensitive.
sentences = [[tok.lower() for tok in item["tokens"]] for item in data]
labels_ner = [item["labels_ner"] for item in data]
labels_srl = [item["labels_srl"] for item in data]

# Tokens and both label sequences must line up one-to-one. maxlen (step 3) is
# computed from the tokens only, so a mismatch would be silently truncated or
# padded by pad_sequences and misalign labels with words.
for i, (toks, ner_seq, srl_seq) in enumerate(zip(sentences, labels_ner, labels_srl)):
    if not (len(toks) == len(ner_seq) == len(srl_seq)):
        raise ValueError(
            f"Example {i}: {len(toks)} tokens but {len(ner_seq)} NER labels "
            f"and {len(srl_seq)} SRL labels"
        )

# Data sanity check: report NER label sequences that contain the tag 'V'.
# NOTE(review): 'V' presumably belongs to the SRL tag set (predicate), so its
# presence in NER labels likely indicates an annotation error — verify.
for i, label_seq in enumerate(labels_ner):
    if "V" in label_seq:
        print(f"Label 'V' ditemukan di index {i}: {label_seq}")

# ---------- 2. Build vocabulary & label maps ----------
# Everything is sorted so index assignment is deterministic across runs.
words = sorted({tok for sent in sentences for tok in sent})
ner_tags = sorted({tag for tags in labels_ner for tag in tags})
srl_tags = sorted({tag for tags in labels_srl for tag in tags})

# Real words start at index 2; 0 is reserved for padding and 1 for
# out-of-vocabulary words.
word2idx = {tok: idx for idx, tok in enumerate(words, start=2)}
word2idx["PAD"] = 0
word2idx["UNK"] = 1

# Forward (tag -> index) and inverse (index -> tag) maps for each task.
tag2idx_ner = {tag: idx for idx, tag in enumerate(ner_tags)}
tag2idx_srl = {tag: idx for idx, tag in enumerate(srl_tags)}
idx2tag_ner = dict(enumerate(ner_tags))
idx2tag_srl = dict(enumerate(srl_tags))

# ---------- 3. Encode tokens & labels ----------
# Words missing from the vocabulary map to the UNK id.
unk_id = word2idx["UNK"]
X = [[word2idx.get(tok, unk_id) for tok in sent] for sent in sentences]
y_ner = [[tag2idx_ner[tag] for tag in tags] for tags in labels_ner]
y_srl = [[tag2idx_srl[tag] for tag in tags] for tags in labels_srl]

# Every sequence is right-padded to the length of the longest sentence.
maxlen = max(map(len, X))

X = pad_sequences(X, maxlen=maxlen, padding="post", value=word2idx["PAD"])
# Padded label positions are filled with the "O" tag.
y_ner = pad_sequences(y_ner, maxlen=maxlen, padding="post", value=tag2idx_ner["O"])
y_srl = pad_sequences(y_srl, maxlen=maxlen, padding="post", value=tag2idx_srl["O"])

# One-hot encode each padded label sequence, then stack them into arrays of
# shape (num_sentences, maxlen, num_tags) as the categorical losses expect.
y_ner = np.array([to_categorical(seq, num_classes=len(tag2idx_ner)) for seq in y_ner])
y_srl = np.array([to_categorical(seq, num_classes=len(tag2idx_srl)) for seq in y_srl])
X = np.array(X)

# ---------- 4. Multi-task BiLSTM architecture ----------
# A shared embedding + BiLSTM encoder feeds two per-token softmax heads:
# one over the NER tag set and one over the SRL tag set.
input_layer = Input(shape=(maxlen,))
# mask_zero=True treats index 0 as padding (word2idx["PAD"] is 0), so the
# timesteps added by pad_sequences are masked out of the BiLSTM and of both
# losses/metrics instead of being learned and scored as extra "O" tokens.
embed = Embedding(len(word2idx), 64, mask_zero=True)(input_layer)
bilstm = Bidirectional(LSTM(64, return_sequences=True))(embed)

ner_output = TimeDistributed(
    Dense(len(tag2idx_ner), activation="softmax"), name="ner_output"
)(bilstm)
srl_output = TimeDistributed(
    Dense(len(tag2idx_srl), activation="softmax"), name="srl_output"
)(bilstm)

model = Model(inputs=input_layer, outputs=[ner_output, srl_output])
# Losses and metrics are keyed by the output layer names above.
model.compile(
    optimizer="adam",
    loss={
        "ner_output": "categorical_crossentropy",
        "srl_output": "categorical_crossentropy",
    },
    metrics={"ner_output": "accuracy", "srl_output": "accuracy"},
)
model.summary()

# ---------- 5. Training ----------
# Both heads train jointly; targets are keyed by output layer name.
targets = {"ner_output": y_ner, "srl_output": y_srl}
model.fit(X, targets, batch_size=2, epochs=10, verbose=1)

# ---------- 6. Save artifacts ----------
ARTIFACT_DIR = "NER_SRL"
# model.save() and open(..., "wb") both fail if the directory is missing.
os.makedirs(ARTIFACT_DIR, exist_ok=True)

model.save(os.path.join(ARTIFACT_DIR, "multi_task_bilstm_model.keras"))

# Persist the vocabulary and tag maps next to the model so the same
# encoding can be reproduced when the model is loaded elsewhere.
for filename, mapping in (
    ("word2idx.pkl", word2idx),
    ("tag2idx_ner.pkl", tag2idx_ner),
    ("tag2idx_srl.pkl", tag2idx_srl),
):
    with open(os.path.join(ARTIFACT_DIR, filename), "wb") as f:
        pickle.dump(mapping, f)

# ---------- 7. Evaluation ----------
# NOTE(review): this predicts on the same X the model was trained on, so the
# reports below measure training-set fit rather than generalisation.
predictions = model.predict(X, verbose=0)
# Outputs come back in the order passed to Model(outputs=[...]): NER, then SRL.
y_pred_ner, y_pred_srl = predictions

def decode(pred, true, idx2tag):
    """Map per-token score vectors back to tag strings.

    Each token's tag is the index of its largest score (argmax), looked up in
    ``idx2tag``.

    Args:
        pred: Iterable of predicted sequences; each sequence is an iterable of
            per-token score vectors.
        true: Iterable of gold sequences in the same layout (e.g. one-hot).
        idx2tag: Mapping from class index to tag string.

    Returns:
        A ``(true_tags, pred_tags)`` tuple of lists of lists of tag strings.
    """
    def _to_tags(batch):
        tagged = []
        for seq in batch:
            tagged.append([idx2tag[np.argmax(scores)] for scores in seq])
        return tagged

    true_tags = _to_tags(true)
    pred_tags = _to_tags(pred)
    return true_tags, pred_tags

# Padded positions are not real tokens. Cut every gold/predicted sequence back
# to its sentence length before decoding, so that predictions on padding
# cannot be scored (e.g. as spurious predicted entities).
sent_lengths = [len(sent) for sent in sentences]


def _strip_padding(batch, lengths):
    """Return each sequence in *batch* truncated to its unpadded length."""
    return [seq[:n] for seq, n in zip(batch, lengths)]


true_ner, pred_ner = decode(
    _strip_padding(y_pred_ner, sent_lengths),
    _strip_padding(y_ner, sent_lengths),
    idx2tag_ner,
)
true_srl, pred_srl = decode(
    _strip_padding(y_pred_srl, sent_lengths),
    _strip_padding(y_srl, sent_lengths),
    idx2tag_srl,
)

print("\n📊 [NER] Classification Report:")
print(classification_report(true_ner, pred_ner))

print("\n📊 [SRL] Classification Report:")
print(classification_report(true_srl, pred_srl))