# ner_srl_multitask.py
# ----------------------------------------------------------
# Train a multi-task (Bi)LSTM that predicts NER + SRL tags
# ----------------------------------------------------------
import json

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import (Input, Embedding, LSTM, Bidirectional,
                                     TimeDistributed, Dense)
from tensorflow.keras.models import Model
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.utils import to_categorical
from sklearn.model_selection import train_test_split
from seqeval.metrics import classification_report

# ----------------------------------------------------------
# 1. Load and prepare data
# ----------------------------------------------------------
# Fix: use a context manager so the file handle is closed deterministically.
# The original `json.load(open(...))` leaked the handle until GC.
with open("../dataset/dataset_ner_srl.json", "r", encoding="utf8") as fh:
    DATA = json.load(fh)

# --- token vocabulary -------------------------------------------------
# Index 0 is reserved for padding, index 1 for out-of-vocabulary tokens.
vocab = {"PAD": 0, "UNK": 1}
for sample in DATA:
    for tok in sample["tokens"]:
        vocab.setdefault(tok.lower(), len(vocab))


# --- label maps -------------------------------------------------------
def build_label_map(key):
    """Build a tag -> index map over DATA for the given sample key.

    key : either "labels_ner" or "labels_srl".
    Index 0 is reserved for the PAD tag so it lines up with the
    padding value used by pad_sequences below.
    """
    tags = {"PAD": 0}  # keep 0 for padding
    for s in DATA:
        for t in s[key]:
            tags.setdefault(t, len(tags))
    return tags


ner2idx = build_label_map("labels_ner")
srl2idx = build_label_map("labels_srl")
# Inverse maps, used later to turn predicted indices back into tag strings.
idx2ner = {i: t for t, i in ner2idx.items()}
idx2srl = {i: t for t, i in srl2idx.items()}

# --- sequences --------------------------------------------------------
# Pad everything to the longest sentence in the corpus.
MAXLEN = max(len(x["tokens"]) for x in DATA)

X = [[vocab.get(tok.lower(), vocab["UNK"]) for tok in s["tokens"]] for s in DATA]
y_ner = [[ner2idx[t] for t in s["labels_ner"]] for s in DATA]
y_srl = [[srl2idx[t] for t in s["labels_srl"]] for s in DATA]

# "post" padding keeps real tokens left-aligned; value 0 matches
# mask_zero=True in the Embedding layer defined later.
X = pad_sequences(X, maxlen=MAXLEN, padding="post", value=vocab["PAD"])
y_ner = pad_sequences(y_ner, maxlen=MAXLEN, padding="post", value=ner2idx["PAD"])
y_srl = pad_sequences(y_srl, maxlen=MAXLEN, padding="post", value=srl2idx["PAD"])
# --- one-hot for softmax ----------------------------------------------
y_ner = to_categorical(y_ner, num_classes=len(ner2idx))
y_srl = to_categorical(y_srl, num_classes=len(srl2idx))

# ----------------------------------------------------------
# 2. Train / validation split
# ----------------------------------------------------------
# Passing every array to a single train_test_split call keeps the rows
# aligned; the return order is (train, test) for each array in turn.
X_tr, X_val, y_tr_ner, y_val_ner, y_tr_srl, y_val_srl = train_test_split(
    X, y_ner, y_srl, test_size=0.15, random_state=42
)

# ----------------------------------------------------------
# 3. Model definition
# ----------------------------------------------------------
EMB_DIM = 128
LSTM_UNITS = 128

# Shared encoder: embedding (PAD index 0 is masked out) feeding a BiLSTM
# that returns one hidden state per time step.
tokens_in = Input(shape=(MAXLEN,))
shared = Embedding(len(vocab), EMB_DIM, mask_zero=True)(tokens_in)
shared = Bidirectional(LSTM(LSTM_UNITS, return_sequences=True))(shared)

# One softmax classification head per task, applied at every time step.
ner_head = TimeDistributed(
    Dense(len(ner2idx), activation="softmax"), name="ner")(shared)
srl_head = TimeDistributed(
    Dense(len(srl2idx), activation="softmax"), name="srl")(shared)

model = Model(tokens_in, [ner_head, srl_head])
model.compile(
    optimizer="adam",
    loss={"ner": "categorical_crossentropy",
          "srl": "categorical_crossentropy"},
    metrics={"ner": "accuracy", "srl": "accuracy"},
)
model.summary()

# ----------------------------------------------------------
# 4. Train
# ----------------------------------------------------------
history = model.fit(
    X_tr,
    {"ner": y_tr_ner, "srl": y_tr_srl},
    validation_data=(X_val, {"ner": y_val_ner, "srl": y_val_srl}),
    epochs=15,
    batch_size=32,
    verbose=2,
)
# ----------------------------------------------------------
# 5. Helper: decode with a mask (so lens always match)
# ----------------------------------------------------------
def decode(pred, idx2tag, mask):
    """Turn per-token predictions back into lists of tag strings.

    pred : [n, MAXLEN, n_tags]  (one-hot or probabilities)
    mask : [n, MAXLEN]          (True for real tokens, False for PAD)

    Returns one list of tags per sequence; PAD positions are dropped so
    every decoded sequence has the length of the original sentence.
    """
    out = []
    for row, keep_row in zip(pred, mask):
        ids = np.argmax(row, axis=-1)  # best tag index per time step
        out.append([idx2tag[i] for i, keep in zip(ids, keep_row) if keep])
    return out


# ----------------------------------------------------------
# 6. Evaluation
# ----------------------------------------------------------
y_pred_ner, y_pred_srl = model.predict(X_val, verbose=0)
mask_val = (X_val != vocab["PAD"])  # True for real tokens

true_ner = decode(y_val_ner, idx2ner, mask_val)
pred_ner = decode(y_pred_ner, idx2ner, mask_val)
true_srl = decode(y_val_srl, idx2srl, mask_val)
pred_srl = decode(y_pred_srl, idx2srl, mask_val)

print("\n📊 NER report")
print(classification_report(true_ner, pred_ner))
print("\n📊 SRL report")
print(classification_report(true_srl, pred_srl))

# # ----------------------------------------------------------
# # 7. Quick inference function
# # ----------------------------------------------------------
# def predict_sentence(sentence: str):
#     tokens = sentence.strip().split()
#     ids = [vocab.get(w.lower(), vocab["UNK"]) for w in tokens]
#     ids = pad_sequences([ids], maxlen=MAXLEN, padding="post",
#                         value=vocab["PAD"])
#     mask = (ids != vocab["PAD"])
#     p_ner, p_srl = model.predict(ids, verbose=0)
#     ner_tags = decode(p_ner, idx2ner, mask)[0]
#     srl_tags = decode(p_srl, idx2srl, mask)[0]
#     return list(zip(tokens, ner_tags, srl_tags))
#
# # ---- demo ------------------------------------------------
# if __name__ == "__main__":
#     print("\n🔍 Demo:")
#     for tok, ner, srl in predict_sentence(
#             "Keanekaragaman hayati Indonesia sangat dipengaruhi faktor iklim."):
#         print(f"{tok:15} {ner:10} {srl}")