feat: adjustment on the model and adding some dataset
This commit is contained in:
parent
4694f0eb9c
commit
3a04f94fb3
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,13 @@
|
|||
import json
|
||||
|
||||
with open("dataset/dataset_ner_srl.json", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
|
||||
with open("dataset/dataset_ner_srl.tsv", "w", encoding="utf-8") as f:
|
||||
for entry in data:
|
||||
for tok, ner, srl in zip(
|
||||
entry["tokens"], entry["labels_ner"], entry["labels_srl"]
|
||||
):
|
||||
f.write(f"{tok}\t{ner}\t{srl}\n")
|
||||
f.write("\n") # Separate sentences
|
|
@ -0,0 +1,142 @@
|
|||
"""
|
||||
Train multi‑task LSTM / BiLSTM untuk NER + SRL
|
||||
———————————————
|
||||
• Dataset : ../dataset/dataset_ner_srl.json
|
||||
• Split : 80 % train | 20 % test
|
||||
• Model : Shared LSTM ➜ 2 head (NER, SRL)
|
||||
• Output : multi_task_lstm_ner_srl_model.keras
|
||||
"""
|
||||
|
||||
import json, pickle, numpy as np, tensorflow as tf
|
||||
from sklearn.model_selection import train_test_split
|
||||
from tensorflow.keras.layers import (
|
||||
Input,
|
||||
Embedding,
|
||||
LSTM,
|
||||
Bidirectional,
|
||||
TimeDistributed,
|
||||
Dense,
|
||||
)
|
||||
from tensorflow.keras.models import Model
|
||||
from tensorflow.keras.preprocessing.sequence import pad_sequences
|
||||
from tensorflow.keras.utils import to_categorical
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# 1. Muat data
|
||||
# ------------------------------------------------------------------------------
|
||||
with open("../dataset/dataset_ner_srl.json", encoding="utf-8") as f:
|
||||
DATA = json.load(f) # list[dict]
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# 2. Buat vocab & tag map
|
||||
# ------------------------------------------------------------------------------
|
||||
vocab = {"PAD": 0, "UNK": 1}
|
||||
for sample in DATA:
|
||||
for tok in sample["tokens"]:
|
||||
vocab.setdefault(tok.lower(), len(vocab))
|
||||
|
||||
|
||||
def build_tag_map(key):
|
||||
m = {"PAD": 0}
|
||||
for s in DATA:
|
||||
for t in s[key]:
|
||||
m.setdefault(t, len(m))
|
||||
return m
|
||||
|
||||
|
||||
ner2idx = build_tag_map("labels_ner")
|
||||
srl2idx = build_tag_map("labels_srl")
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# 3. Encoding token & label ke indeks ➜ pad
|
||||
# ------------------------------------------------------------------------------
|
||||
MAXLEN = 50
|
||||
|
||||
|
||||
def encode_tokens(tokens):
|
||||
ids = [vocab.get(tok.lower(), vocab["UNK"]) for tok in tokens]
|
||||
return pad_sequences([ids], maxlen=MAXLEN, padding="post", value=vocab["PAD"])[0]
|
||||
|
||||
|
||||
def encode_labels(labels, tag2idx):
|
||||
ids = [tag2idx[l] for l in labels]
|
||||
return pad_sequences([ids], maxlen=MAXLEN, padding="post", value=tag2idx["PAD"])[0]
|
||||
|
||||
|
||||
X = np.array([encode_tokens(s["tokens"]) for s in DATA])
|
||||
y_ner = np.array([encode_labels(s["labels_ner"], ner2idx) for s in DATA])
|
||||
y_srl = np.array([encode_labels(s["labels_srl"], srl2idx) for s in DATA])
|
||||
|
||||
# one‑hot (jika pakai categorical_crossentropy)
|
||||
y_ner = to_categorical(y_ner, num_classes=len(ner2idx))
|
||||
y_srl = to_categorical(y_srl, num_classes=len(srl2idx))
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# 4. Train / test split 80 : 20
|
||||
# ------------------------------------------------------------------------------
|
||||
X_tr, X_te, ytr_ner, yte_ner, ytr_srl, yte_srl = train_test_split(
|
||||
X, y_ner, y_srl, test_size=0.20, random_state=42, shuffle=True
|
||||
)
|
||||
|
||||
print(f"TRAIN : {X_tr.shape[0]} | TEST : {X_te.shape[0]}")
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# 5. Definisi model
|
||||
# ------------------------------------------------------------------------------
|
||||
EMB_DIM = 128
|
||||
RNN_UNITS = 128
|
||||
BILSTM = True # ganti False jika mau LSTM biasa
|
||||
|
||||
inp = Input(shape=(MAXLEN,))
|
||||
emb = Embedding(len(vocab), EMB_DIM, mask_zero=True)(inp)
|
||||
rnn = (
|
||||
Bidirectional(LSTM(RNN_UNITS, return_sequences=True))
|
||||
if BILSTM
|
||||
else LSTM(RNN_UNITS, return_sequences=True)
|
||||
)(emb)
|
||||
|
||||
out_ner = TimeDistributed(Dense(len(ner2idx), activation="softmax"), name="ner_output")(
|
||||
rnn
|
||||
)
|
||||
out_srl = TimeDistributed(Dense(len(srl2idx), activation="softmax"), name="srl_output")(
|
||||
rnn
|
||||
)
|
||||
|
||||
model = Model(inp, [out_ner, out_srl])
|
||||
model.compile(
|
||||
optimizer="adam",
|
||||
loss={
|
||||
"ner_output": "categorical_crossentropy",
|
||||
"srl_output": "categorical_crossentropy",
|
||||
},
|
||||
metrics={"ner_output": "accuracy", "srl_output": "accuracy"},
|
||||
)
|
||||
|
||||
model.summary()
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# 6. Training
|
||||
# ------------------------------------------------------------------------------
|
||||
EPOCHS = 15
|
||||
BATCH_SIZE = 32
|
||||
|
||||
history = model.fit(
|
||||
X_tr,
|
||||
{"ner_output": ytr_ner, "srl_output": ytr_srl},
|
||||
validation_data=(X_te, {"ner_output": yte_ner, "srl_output": yte_srl}),
|
||||
epochs=EPOCHS,
|
||||
batch_size=BATCH_SIZE,
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# 7. Simpan artefak
|
||||
# ------------------------------------------------------------------------------
|
||||
model.save("multi_task_lstm_ner_srl_model.keras")
|
||||
with open("word2idx.pkl", "wb") as f:
|
||||
pickle.dump(vocab, f)
|
||||
with open("tag2idx_ner.pkl", "wb") as f:
|
||||
pickle.dump(ner2idx, f)
|
||||
with open("tag2idx_srl.pkl", "wb") as f:
|
||||
pickle.dump(srl2idx, f)
|
||||
|
||||
print("✓ Model & mapping tersimpan — siap dipakai fungsi predict_sentence()!")
|
Loading…
Reference in New Issue