TIF_E41211115_lstm-quiz-gen.../old/train_multitask_lstm.py

"""
Train multitask LSTM / BiLSTM untuk NER + SRL
———————————————
• Dataset : ../dataset/dataset_ner_srl.json
• Split : 80 % train | 20 % test
• Model : Shared LSTM ➜ 2 head (NER, SRL)
• Output : multi_task_lstm_ner_srl_model.keras
"""
import json
import pickle

import numpy as np
from sklearn.model_selection import train_test_split
from tensorflow.keras.layers import (
    Input,
    Embedding,
    LSTM,
    Bidirectional,
    TimeDistributed,
    Dense,
)
from tensorflow.keras.models import Model
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.utils import to_categorical
# ------------------------------------------------------------------------------
# 1. Load data
# ------------------------------------------------------------------------------
with open("../dataset/dataset_ner_srl.json", encoding="utf-8") as f:
    DATA = json.load(f)  # list[dict]
# ------------------------------------------------------------------------------
# 2. Build vocab & tag maps
# ------------------------------------------------------------------------------
vocab = {"PAD": 0, "UNK": 1}
for sample in DATA:
    for tok in sample["tokens"]:
        vocab.setdefault(tok.lower(), len(vocab))

def build_tag_map(key):
    """Map every tag found under `key` to a stable integer index (PAD = 0)."""
    m = {"PAD": 0}
    for s in DATA:
        for t in s[key]:
            m.setdefault(t, len(m))
    return m

ner2idx = build_tag_map("labels_ner")
srl2idx = build_tag_map("labels_srl")
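# Quick sanity check (illustrative; the exact counts depend on the dataset):
print(f"vocab: {len(vocab)} tokens | NER tags: {len(ner2idx)} | SRL tags: {len(srl2idx)}")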
# ------------------------------------------------------------------------------
# 3. Encode tokens & labels to indices ➜ pad
# ------------------------------------------------------------------------------
MAXLEN = 50

def encode_tokens(tokens):
    ids = [vocab.get(tok.lower(), vocab["UNK"]) for tok in tokens]
    return pad_sequences([ids], maxlen=MAXLEN, padding="post", value=vocab["PAD"])[0]

def encode_labels(labels, tag2idx):
    ids = [tag2idx[l] for l in labels]
    return pad_sequences([ids], maxlen=MAXLEN, padding="post", value=tag2idx["PAD"])[0]

X = np.array([encode_tokens(s["tokens"]) for s in DATA])
y_ner = np.array([encode_labels(s["labels_ner"], ner2idx) for s in DATA])
y_srl = np.array([encode_labels(s["labels_srl"], srl2idx) for s in DATA])

# One-hot encode (needed for categorical_crossentropy)
y_ner = to_categorical(y_ner, num_classes=len(ner2idx))
y_srl = to_categorical(y_srl, num_classes=len(srl2idx))
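# Shape check (sketch): padding fixes the time axis, one-hot adds a class axis.
assert X.shape == (len(DATA), MAXLEN)
assert y_ner.shape == (len(DATA), MAXLEN, len(ner2idx))
assert y_srl.shape == (len(DATA), MAXLEN, len(srl2idx))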
# ------------------------------------------------------------------------------
# 4. Train / test split 80 : 20
# ------------------------------------------------------------------------------
X_tr, X_te, ytr_ner, yte_ner, ytr_srl, yte_srl = train_test_split(
    X, y_ner, y_srl, test_size=0.20, random_state=42, shuffle=True
)
print(f"TRAIN : {X_tr.shape[0]} | TEST : {X_te.shape[0]}")
# ------------------------------------------------------------------------------
# 5. Model definition
# ------------------------------------------------------------------------------
EMB_DIM = 128
RNN_UNITS = 128
BILSTM = True  # set False for a plain (unidirectional) LSTM

inp = Input(shape=(MAXLEN,))
# mask_zero=True masks the PAD index so padded timesteps are ignored downstream
emb = Embedding(len(vocab), EMB_DIM, mask_zero=True)(inp)
rnn = (
    Bidirectional(LSTM(RNN_UNITS, return_sequences=True))
    if BILSTM
    else LSTM(RNN_UNITS, return_sequences=True)
)(emb)
# The shared trunk feeds two per-token softmax heads, one per task
out_ner = TimeDistributed(Dense(len(ner2idx), activation="softmax"), name="ner_output")(rnn)
out_srl = TimeDistributed(Dense(len(srl2idx), activation="softmax"), name="srl_output")(rnn)

model = Model(inp, [out_ner, out_srl])
model.compile(
    optimizer="adam",
    loss={
        "ner_output": "categorical_crossentropy",
        "srl_output": "categorical_crossentropy",
    },
    metrics={"ner_output": "accuracy", "srl_output": "accuracy"},
)
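# Optional (sketch): if one task dominates training, Keras' loss_weights argument
# to compile() can rebalance the two losses; the 0.5 below is an assumption, e.g.
#   model.compile(..., loss_weights={"ner_output": 1.0, "srl_output": 0.5})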
model.summary()
# ------------------------------------------------------------------------------
# 6. Training
# ------------------------------------------------------------------------------
EPOCHS = 15
BATCH_SIZE = 32
history = model.fit(
    X_tr,
    {"ner_output": ytr_ner, "srl_output": ytr_srl},
    validation_data=(X_te, {"ner_output": yte_ner, "srl_output": yte_srl}),
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
)
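# Optional (sketch): early stopping on validation loss; the patience value is an
# assumption, not tuned for this dataset.
#   from tensorflow.keras.callbacks import EarlyStopping
#   history = model.fit(..., callbacks=[EarlyStopping(monitor="val_loss",
#                                                     patience=3,
#                                                     restore_best_weights=True)])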
# ------------------------------------------------------------------------------
# 7. Save artifacts
# ------------------------------------------------------------------------------
model.save("multi_task_lstm_ner_srl_model.keras")
with open("word2idx.pkl", "wb") as f:
    pickle.dump(vocab, f)
with open("tag2idx_ner.pkl", "wb") as f:
    pickle.dump(ner2idx, f)
with open("tag2idx_srl.pkl", "wb") as f:
    pickle.dump(srl2idx, f)
print("✓ Model & mappings saved - ready for predict_sentence()!")
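# ------------------------------------------------------------------------------
# 8. Inference sketch: a minimal predict_sentence() matching the artifacts saved
#    above. The whitespace tokenizer is an assumption; mirror the real
#    preprocessing used to build the dataset.
# ------------------------------------------------------------------------------
idx2ner = {i: t for t, i in ner2idx.items()}
idx2srl = {i: t for t, i in srl2idx.items()}

def predict_sentence(sentence):
    tokens = sentence.split()  # assumed tokenizer
    x = encode_tokens(tokens)[None, :]  # shape (1, MAXLEN)
    p_ner, p_srl = model.predict(x, verbose=0)
    n = min(len(tokens), MAXLEN)  # drop predictions for padded positions
    ner_tags = [idx2ner[i] for i in p_ner[0].argmax(-1)[:n]]
    srl_tags = [idx2srl[i] for i in p_srl[0].argmax(-1)[:n]]
    return list(zip(tokens, ner_tags, srl_tags))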