In [8]:
import json
import numpy as np
from keras.models import Model
from keras.layers import Input, Embedding, Bidirectional, LSTM, TimeDistributed, Dense
from keras.utils import to_categorical
from keras.preprocessing.sequence import pad_sequences
from sklearn.model_selection import train_test_split
from seqeval.metrics import classification_report
import pickle

In [9]:
# === LOAD DATA ===
with open("../dataset/dataset_ner_srl.json", "r", encoding="utf-8") as f:
    data = json.load(f)

sentences = [[token.lower() for token in item["tokens"]] for item in data]
ner_labels = [item["labels_ner"] for item in data]
srl_labels = [item["labels_srl"] for item in data]

In [10]:
# === VOCABULARY ===
words = list(set(word for sentence in sentences for word in sentence))
word2idx = {word: idx + 2 for idx, word in enumerate(words)}
word2idx["PAD"] = 0
word2idx["UNK"] = 1

all_ner_tags = sorted(set(tag for seq in ner_labels for tag in seq))
all_srl_tags = sorted(set(tag for seq in srl_labels for tag in seq))
tag2idx_ner = {tag: idx for idx, tag in enumerate(all_ner_tags)}
tag2idx_srl = {tag: idx for idx, tag in enumerate(all_srl_tags)}
idx2tag_ner = {i: t for t, i in tag2idx_ner.items()}
idx2tag_srl = {i: t for t, i in tag2idx_srl.items()}

In [None]:

# === ENCODING ===
X = [[word2idx.get(w, word2idx["UNK"]) for w in s] for s in sentences]
y_ner = [[tag2idx_ner[t] for t in ts] for ts in ner_labels]
y_srl = [[tag2idx_srl[t] for t in ts] for ts in srl_labels]

maxlen = max(len(x) for x in X)
X = pad_sequences(X, maxlen=maxlen, padding="post", value=word2idx["PAD"])
y_ner = pad_sequences(y_ner, maxlen=maxlen, padding="post", value=tag2idx_ner["O"])
y_srl = pad_sequences(y_srl, maxlen=maxlen, padding="post", value=tag2idx_srl["O"])
y_ner_cat = [to_categorical(seq, num_classes=len(tag2idx_ner)) for seq in y_ner]
y_srl_cat = [to_categorical(seq, num_classes=len(tag2idx_srl)) for seq in y_srl]


KeyError: 'O'

In [None]:
# split dataset 
X_temp, X_test, y_ner_temp, y_ner_test, y_srl_temp, y_srl_test = train_test_split(
    X, y_ner_cat, y_srl_cat, test_size=0.1, random_state=42
)
X_train, X_val, y_ner_train, y_ner_val, y_srl_train, y_srl_val = train_test_split(
    X_temp, y_ner_temp, y_srl_temp, test_size=0.1111, random_state=42  # ~10% of total
)

In [None]:
#training model
input_layer = Input(shape=(maxlen,))
embedding = Embedding(input_dim=len(word2idx), output_dim=64)(input_layer)
bilstm = Bidirectional(LSTM(units=64, return_sequences=True))(embedding)
out_ner = TimeDistributed(Dense(len(tag2idx_ner), activation="softmax"), name="ner_output")(bilstm)
out_srl = TimeDistributed(Dense(len(tag2idx_srl), activation="softmax"), name="srl_output")(bilstm)

model = Model(inputs=input_layer, outputs=[out_ner, out_srl])
model.compile(
    optimizer="adam",
    loss={"ner_output": "categorical_crossentropy", "srl_output": "categorical_crossentropy"},
    metrics={"ner_output": "accuracy", "srl_output": "accuracy"}
)

model.summary()

In [None]:

# === TRAINING ===
history = model.fit(
    X_train,
    {"ner_output": np.array(y_ner_train), "srl_output": np.array(y_srl_train)},
    validation_data=(X_val, {"ner_output": np.array(y_ner_val), "srl_output": np.array(y_srl_val)}),
    batch_size=2,
    epochs=10
)

# === SAVE ===
model.save("NER_SRL/multi_task_bilstm_model.keras")
with open("NER_SRL/word2idx.pkl", "wb") as f:
    pickle.dump(word2idx, f)
with open("NER_SRL/tag2idx_ner.pkl", "wb") as f:
    pickle.dump(tag2idx_ner, f)
with open("NER_SRL/tag2idx_srl.pkl", "wb") as f:
    pickle.dump(tag2idx_srl, f)

In [None]:
# evaluation
y_pred_ner, y_pred_srl = model.predict(X_test)

y_true_ner = [[idx2tag_ner[np.argmax(tok)] for tok in seq] for seq in y_ner_test]
y_pred_ner = [[idx2tag_ner[np.argmax(tok)] for tok in seq] for seq in y_pred_ner]

y_true_srl = [[idx2tag_srl[np.argmax(tok)] for tok in seq] for seq in y_srl_test]
y_pred_srl = [[idx2tag_srl[np.argmax(tok)] for tok in seq] for seq in y_pred_srl]

print("\nðŸ“Š [NER] Test Set Classification Report:")
print(classification_report(y_true_ner, y_pred_ner))

print("\nðŸ“Š [SRL] Test Set Classification Report:")
print(classification_report(y_true_srl, y_pred_srl))