# ner_srl_multitask.py
# ----------------------------------------------------------
# Train a multi-task (Bi)LSTM that predicts NER + SRL tags
# ----------------------------------------------------------
import json

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import (Input, Embedding, LSTM, Bidirectional,
                                     TimeDistributed, Dense)
from tensorflow.keras.models import Model
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.utils import to_categorical
from sklearn.model_selection import train_test_split
from seqeval.metrics import classification_report

# ----------------------------------------------------------
# 1. Load and prepare data
# ----------------------------------------------------------
# Each sample provides the parallel lists "tokens", "labels_ner" and
# "labels_srl" (these are the fields accessed throughout this script).
# FIX: the original `json.load(open(...))` never closed the file handle;
# a context manager closes it deterministically.
with open("../dataset/dataset_ner_srl.json", "r", encoding="utf8") as fh:
    DATA = json.load(fh)
# --- token vocabulary -------------------------------------------------
# Index 0 is reserved for padding and 1 for unknown words; every other
# lower-cased token receives the next free id on first appearance.
vocab = {"PAD": 0, "UNK": 1}
for sentence in DATA:
    for token in sentence["tokens"]:
        word = token.lower()
        if word not in vocab:
            vocab[word] = len(vocab)
# --- label maps -------------------------------------------------------
def build_label_map(key, data=None):
    """Build a tag -> index map for the label field *key*.

    Index 0 is reserved for the "PAD" tag; every other tag is numbered
    in order of first appearance across the samples.

    Args:
        key:  name of the per-sample label list ("labels_ner" / "labels_srl").
        data: iterable of sample dicts; defaults to the module-level DATA
              so existing call sites are unchanged (parameter added to
              remove the hidden global dependency and allow testing).

    Returns:
        dict mapping tag string -> integer index, with {"PAD": 0} always
        present.
    """
    samples = DATA if data is None else data
    tags = {"PAD": 0}  # keep 0 for padding
    for sample in samples:
        for tag in sample[key]:
            tags.setdefault(tag, len(tags))
    return tags
ner2idx = build_label_map("labels_ner")
srl2idx = build_label_map("labels_srl")
# Inverse lookups (index -> tag), used when decoding model output.
idx2ner = {index: tag for tag, index in ner2idx.items()}
idx2srl = {index: tag for tag, index in srl2idx.items()}
# --- sequences --------------------------------------------------------
# Every sentence is padded to the length of the longest one.
MAXLEN = max(len(sample["tokens"]) for sample in DATA)

# Token ids (unknown words map to UNK) and integer label sequences.
X = [[vocab.get(token.lower(), vocab["UNK"]) for token in sample["tokens"]]
     for sample in DATA]
y_ner = [[ner2idx[tag] for tag in sample["labels_ner"]] for sample in DATA]
y_srl = [[srl2idx[tag] for tag in sample["labels_srl"]] for sample in DATA]

# Post-padding keeps real tokens left-aligned; PAD ids fill the tail.
X = pad_sequences(X, maxlen=MAXLEN, padding="post", value=vocab["PAD"])
y_ner = pad_sequences(y_ner, maxlen=MAXLEN, padding="post", value=ner2idx["PAD"])
y_srl = pad_sequences(y_srl, maxlen=MAXLEN, padding="post", value=srl2idx["PAD"])

# --- one-hot for softmax ---------------------------------------------
y_ner = to_categorical(y_ner, num_classes=len(ner2idx))
y_srl = to_categorical(y_srl, num_classes=len(srl2idx))
# ----------------------------------------------------------
# 2. Train / validation split
# ----------------------------------------------------------
# Passing every array through a single train_test_split call keeps the
# rows aligned; the return order is (train, test) for each array in turn.
split_arrays = train_test_split(X, y_ner, y_srl, test_size=0.15, random_state=42)
X_tr, X_val, y_tr_ner, y_val_ner, y_tr_srl, y_val_srl = split_arrays
# ----------------------------------------------------------
# 3. Model definition
# ----------------------------------------------------------
EMB_DIM = 128
LSTM_UNITS = 128

# Shared trunk: embedding (mask_zero=True masks the PAD index) feeding a
# BiLSTM that emits one vector per time step.
tokens_in = Input(shape=(MAXLEN,))
embedded = Embedding(len(vocab), EMB_DIM, mask_zero=True)(tokens_in)
shared = Bidirectional(LSTM(LSTM_UNITS, return_sequences=True))(embedded)

# Two task heads: a per-token softmax over each tag inventory.
ner_out = TimeDistributed(Dense(len(ner2idx), activation="softmax"),
                          name="ner")(shared)
srl_out = TimeDistributed(Dense(len(srl2idx), activation="softmax"),
                          name="srl")(shared)

model = Model(tokens_in, [ner_out, srl_out])
model.compile(
    optimizer="adam",
    loss={"ner": "categorical_crossentropy",
          "srl": "categorical_crossentropy"},
    metrics={"ner": "accuracy",
             "srl": "accuracy"},
)
model.summary()
# ----------------------------------------------------------
# 4. Train
# ----------------------------------------------------------
# Targets are keyed by the output-layer names given in the model above.
train_targets = {"ner": y_tr_ner, "srl": y_tr_srl}
val_targets = {"ner": y_val_ner, "srl": y_val_srl}

history = model.fit(
    X_tr,
    train_targets,
    validation_data=(X_val, val_targets),
    epochs=15,
    batch_size=32,
    verbose=2,
)
# ----------------------------------------------------------
|
||
# 5. Helper: decode with a mask (so lens always match)
|
||
# ----------------------------------------------------------
|
||
def decode(pred, idx2tag, mask):
|
||
"""
|
||
pred : [n, MAXLEN, n_tags] (one‑hot or probabilities)
|
||
mask : [n, MAXLEN] (True for real tokens, False for PAD)
|
||
"""
|
||
out = []
|
||
for seq, m in zip(pred, mask):
|
||
tags = [idx2tag[np.argmax(tok)] for tok, keep in zip(seq, m) if keep]
|
||
out.append(tags)
|
||
return out
|
||
|
||
# ----------------------------------------------------------
|
||
# 6. Evaluation
|
||
# ----------------------------------------------------------
|
||
y_pred_ner, y_pred_srl = model.predict(X_val, verbose=0)
|
||
|
||
mask_val = (X_val != vocab["PAD"]) # True for real tokens
|
||
|
||
true_ner = decode(y_val_ner , idx2ner, mask_val)
|
||
pred_ner = decode(y_pred_ner, idx2ner, mask_val)
|
||
true_srl = decode(y_val_srl , idx2srl, mask_val)
|
||
pred_srl = decode(y_pred_srl, idx2srl, mask_val)
|
||
|
||
print("\n📊 NER report")
|
||
print(classification_report(true_ner, pred_ner))
|
||
|
||
print("\n📊 SRL report")
|
||
print(classification_report(true_srl, pred_srl))
|
||
|
||
# # ----------------------------------------------------------
|
||
# # 7. Quick inference function
|
||
# # ----------------------------------------------------------
|
||
# def predict_sentence(sentence: str):
|
||
# tokens = sentence.strip().split()
|
||
# ids = [vocab.get(w.lower(), vocab["UNK"]) for w in tokens]
|
||
# ids = pad_sequences([ids], maxlen=MAXLEN, padding="post",
|
||
# value=vocab["PAD"])
|
||
# mask = (ids != vocab["PAD"])
|
||
# p_ner, p_srl = model.predict(ids, verbose=0)
|
||
# ner_tags = decode(p_ner , idx2ner , mask)[0]
|
||
# srl_tags = decode(p_srl , idx2srl , mask)[0]
|
||
# return list(zip(tokens, ner_tags, srl_tags))
|
||
|
||
# # ---- demo ------------------------------------------------
|
||
# if __name__ == "__main__":
|
||
# print("\n🔍 Demo:")
|
||
# for tok, ner, srl in predict_sentence(
|
||
# "Keanekaragaman hayati Indonesia sangat dipengaruhi faktor iklim."):
|
||
# print(f"{tok:15} {ner:10} {srl}")
|