# ner_srl_multitask.py
# ----------------------------------------------------------
# Train a multitask (Bi)LSTM that predicts NER + SRL tags
# ----------------------------------------------------------
import json
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import (Input, Embedding, LSTM, Bidirectional,
                                     TimeDistributed, Dense)
from tensorflow.keras.models import Model
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.utils import to_categorical
from sklearn.model_selection import train_test_split
from seqeval.metrics import classification_report
# ----------------------------------------------------------
# 1. Load and prepare data
# ----------------------------------------------------------
with open("../dataset/dataset_ner_srl.json", "r", encoding="utf8") as f:
    DATA = json.load(f)
# --- token vocabulary -------------------------------------------------
vocab = {"PAD": 0, "UNK": 1}
for sample in DATA:
    for tok in sample["tokens"]:
        vocab.setdefault(tok.lower(), len(vocab))
# --- label maps -------------------------------------------------------
def build_label_map(key):
    tags = {"PAD": 0}  # keep 0 for padding
    for s in DATA:
        for t in s[key]:
            tags.setdefault(t, len(tags))
    return tags
ner2idx = build_label_map("labels_ner")
srl2idx = build_label_map("labels_srl")
idx2ner = {i: t for t, i in ner2idx.items()}
idx2srl = {i: t for t, i in srl2idx.items()}
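# Both maps reserve index 0 for PAD, matching vocab["PAD"] == 0, so the
# padded inputs and the one-hot targets built below stay aligned.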
# --- sequences --------------------------------------------------------
MAXLEN = max(len(x["tokens"]) for x in DATA)
X = [[vocab.get(tok.lower(), vocab["UNK"]) for tok in s["tokens"]]
     for s in DATA]
y_ner = [[ner2idx[t] for t in s["labels_ner"]]
         for s in DATA]
y_srl = [[srl2idx[t] for t in s["labels_srl"]]
         for s in DATA]
X = pad_sequences(X, maxlen=MAXLEN, padding="post", value=vocab["PAD"])
y_ner = pad_sequences(y_ner, maxlen=MAXLEN, padding="post", value=ner2idx["PAD"])
y_srl = pad_sequences(y_srl, maxlen=MAXLEN, padding="post", value=srl2idx["PAD"])
# --- one-hot targets for softmax --------------------------------------
y_ner = to_categorical(y_ner, num_classes=len(ner2idx))
y_srl = to_categorical(y_srl, num_classes=len(srl2idx))
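# Sanity check (illustrative): shapes should line up before the split:
# X -> (n_samples, MAXLEN), y_ner -> (n_samples, MAXLEN, n_ner_tags),
# y_srl -> (n_samples, MAXLEN, n_srl_tags).
print("shapes:", X.shape, y_ner.shape, y_srl.shape)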
# ----------------------------------------------------------
# 2. Train / validation split
# ----------------------------------------------------------
# *All* arrays must be passed to train_test_split in one call so they
# stay aligned. Order of return = train, test for each array.
X_tr, X_val, y_tr_ner, y_val_ner, y_tr_srl, y_val_srl = train_test_split(
    X, y_ner, y_srl, test_size=0.15, random_state=42
)
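# Quick check that the split kept everything aligned (same number of
# sequences across the inputs and both label tensors).
assert len(X_tr) == len(y_tr_ner) == len(y_tr_srl)
print(f"train: {len(X_tr)} sequences | val: {len(X_val)} sequences")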
# ----------------------------------------------------------
# 3. Model definition
# ----------------------------------------------------------
EMB_DIM = 128
LSTM_UNITS = 128
inp = Input(shape=(MAXLEN,))
emb = Embedding(len(vocab), EMB_DIM, mask_zero=True)(inp)
bilstm = Bidirectional(LSTM(LSTM_UNITS, return_sequences=True))(emb)
ner_out = TimeDistributed(
    Dense(len(ner2idx), activation="softmax"), name="ner")(bilstm)
srl_out = TimeDistributed(
    Dense(len(srl2idx), activation="softmax"), name="srl")(bilstm)
model = Model(inp, [ner_out, srl_out])
model.compile(
    optimizer="adam",
    loss={"ner": "categorical_crossentropy",
          "srl": "categorical_crossentropy"},
    metrics={"ner": "accuracy",
             "srl": "accuracy"},
)
model.summary()
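# Note on masking: mask_zero=True in the Embedding layer propagates a
# padding mask through the BiLSTM and both TimeDistributed heads, so PAD
# timesteps should be excluded from the losses and accuracies reported
# during training.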
# ----------------------------------------------------------
# 4. Train
# ----------------------------------------------------------
history = model.fit(
    X_tr,
    {"ner": y_tr_ner, "srl": y_tr_srl},
    validation_data=(X_val, {"ner": y_val_ner, "srl": y_val_srl}),
    epochs=15,
    batch_size=32,
    verbose=2,
)
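# (Optional sketch, not part of the original pipeline) Persist the model and
# the label maps so inference can run later without retraining. The file
# names below are placeholders.
# model.save("ner_srl_multitask.keras")
# with open("label_maps.json", "w", encoding="utf8") as f:
#     json.dump({"vocab": vocab, "ner2idx": ner2idx, "srl2idx": srl2idx}, f)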
# ----------------------------------------------------------
# 5. Helper: decode predictions with a mask (so lengths always match)
# ----------------------------------------------------------
def decode(pred, idx2tag, mask):
    """
    pred : [n, MAXLEN, n_tags]  (one-hot or probabilities)
    mask : [n, MAXLEN]          (True for real tokens, False for PAD)
    """
    out = []
    for seq, m in zip(pred, mask):
        tags = [idx2tag[np.argmax(tok)] for tok, keep in zip(seq, m) if keep]
        out.append(tags)
    return out
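# Example (illustrative): with idx2tag = {0: "PAD", 1: "O", 2: "B-LOC"},
# decode(np.array([[[0, 0, 1], [0, 1, 0], [1, 0, 0]]]), idx2tag,
#        np.array([[True, True, False]]))
# returns [["B-LOC", "O"]]: the masked third position is dropped.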
# ----------------------------------------------------------
# 6. Evaluation
# ----------------------------------------------------------
y_pred_ner, y_pred_srl = model.predict(X_val, verbose=0)
mask_val = (X_val != vocab["PAD"]) # True for real tokens
true_ner = decode(y_val_ner,  idx2ner, mask_val)
pred_ner = decode(y_pred_ner, idx2ner, mask_val)
true_srl = decode(y_val_srl,  idx2srl, mask_val)
pred_srl = decode(y_pred_srl, idx2srl, mask_val)
print("\n📊 NER report")
print(classification_report(true_ner, pred_ner))
print("\n📊 SRL report")
print(classification_report(true_srl, pred_srl))
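# Note: seqeval scores at the span level (it assumes BIO-style tags), so the
# reports above penalise boundary errors that plain token accuracy would hide.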
# # ----------------------------------------------------------
# # 7. Quick inference function
# # ----------------------------------------------------------
# def predict_sentence(sentence: str):
#     tokens = sentence.strip().split()
#     ids = [vocab.get(w.lower(), vocab["UNK"]) for w in tokens]
#     ids = pad_sequences([ids], maxlen=MAXLEN, padding="post",
#                         value=vocab["PAD"])
#     mask = (ids != vocab["PAD"])
#     p_ner, p_srl = model.predict(ids, verbose=0)
#     ner_tags = decode(p_ner, idx2ner, mask)[0]
#     srl_tags = decode(p_srl, idx2srl, mask)[0]
#     return list(zip(tokens, ner_tags, srl_tags))
# # ---- demo ------------------------------------------------
# if __name__ == "__main__":
#     print("\n🔍 Demo:")
#     for tok, ner, srl in predict_sentence(
#             "Keanekaragaman hayati Indonesia sangat dipengaruhi faktor iklim."):
#         print(f"{tok:15} {ner:10} {srl}")