feat: new dataset for srl
This commit is contained in: parent fa116924e4, commit 7fa361e02d

@@ -0,0 +1,203 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "fcdce269",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import numpy as np\n",
    "from keras.models import Model\n",
    "from keras.layers import Input, Embedding, Bidirectional, LSTM, TimeDistributed, Dense\n",
    "from keras.utils import to_categorical\n",
    "from keras.preprocessing.sequence import pad_sequences\n",
    "from sklearn.model_selection import train_test_split\n",
    "from seqeval.metrics import classification_report\n",
    "import pickle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "d568e8f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# === LOAD DATA ===\n",
    "with open(\"../dataset/dataset_ner_srl.json\", \"r\", encoding=\"utf-8\") as f:\n",
    "    data = json.load(f)\n",
    "\n",
    "sentences = [[token.lower() for token in item[\"tokens\"]] for item in data]\n",
    "ner_labels = [item[\"labels_ner\"] for item in data]\n",
    "srl_labels = [item[\"labels_srl\"] for item in data]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "e9653d99",
   "metadata": {},
   "outputs": [],
   "source": [
    "# === VOCABULARY ===\n",
    "words = list(set(word for sentence in sentences for word in sentence))\n",
    "word2idx = {word: idx + 2 for idx, word in enumerate(words)}\n",
    "word2idx[\"PAD\"] = 0\n",
    "word2idx[\"UNK\"] = 1\n",
    "\n",
    "# force \"O\" into both tag sets: the encoding cell pads with tag2idx_srl[\"O\"], which\n",
    "# raised the recorded KeyError: 'O' when the SRL labels never contained that tag\n",
    "all_ner_tags = sorted(set(tag for seq in ner_labels for tag in seq) | {\"O\"})\n",
    "all_srl_tags = sorted(set(tag for seq in srl_labels for tag in seq) | {\"O\"})\n",
    "tag2idx_ner = {tag: idx for idx, tag in enumerate(all_ner_tags)}\n",
    "tag2idx_srl = {tag: idx for idx, tag in enumerate(all_srl_tags)}\n",
    "idx2tag_ner = {i: t for t, i in tag2idx_ner.items()}\n",
    "idx2tag_srl = {i: t for t, i in tag2idx_srl.items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d3a37b3",
   "metadata": {},
   "outputs": [
    {
     "ename": "KeyError",
     "evalue": "'O'",
     "output_type": "error",
     "traceback": [
      "---------------------------------------------------------------------------",
      "KeyError                                  Traceback (most recent call last)",
      "Cell In[11], line 9",
      "      7 X = pad_sequences(X, maxlen=maxlen, padding=\"post\", value=word2idx[\"PAD\"])",
      "      8 y_ner = pad_sequences(y_ner, maxlen=maxlen, padding=\"post\", value=tag2idx_ner[\"O\"])",
      "----> 9 y_srl = pad_sequences(y_srl, maxlen=maxlen, padding=\"post\", value=tag2idx_srl[\"O\"])",
      "     10 y_ner_cat = [to_categorical(seq, num_classes=len(tag2idx_ner)) for seq in y_ner]",
      "     11 y_srl_cat = [to_categorical(seq, num_classes=len(tag2idx_srl)) for seq in y_srl]",
      "KeyError: 'O'"
     ]
    }
   ],
   "source": [
    "# === ENCODING ===\n",
    "X = [[word2idx.get(w, word2idx[\"UNK\"]) for w in s] for s in sentences]\n",
    "y_ner = [[tag2idx_ner[t] for t in ts] for ts in ner_labels]\n",
    "y_srl = [[tag2idx_srl[t] for t in ts] for ts in srl_labels]\n",
    "\n",
    "maxlen = max(len(x) for x in X)\n",
    "X = pad_sequences(X, maxlen=maxlen, padding=\"post\", value=word2idx[\"PAD\"])\n",
    "y_ner = pad_sequences(y_ner, maxlen=maxlen, padding=\"post\", value=tag2idx_ner[\"O\"])\n",
    "y_srl = pad_sequences(y_srl, maxlen=maxlen, padding=\"post\", value=tag2idx_srl[\"O\"])\n",
    "y_ner_cat = [to_categorical(seq, num_classes=len(tag2idx_ner)) for seq in y_ner]\n",
    "y_srl_cat = [to_categorical(seq, num_classes=len(tag2idx_srl)) for seq in y_srl]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a5c264df",
   "metadata": {},
   "outputs": [],
   "source": [
    "# === SPLIT === (80% train / 10% validation / 10% test)\n",
    "X_temp, X_test, y_ner_temp, y_ner_test, y_srl_temp, y_srl_test = train_test_split(\n",
    "    X, y_ner_cat, y_srl_cat, test_size=0.1, random_state=42\n",
    ")\n",
    "X_train, X_val, y_ner_train, y_ner_val, y_srl_train, y_srl_val = train_test_split(\n",
    "    X_temp, y_ner_temp, y_srl_temp, test_size=0.1111, random_state=42  # ~10% of total\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "712c1789",
   "metadata": {},
   "outputs": [],
   "source": [
    "# === MODEL ===\n",
    "input_layer = Input(shape=(maxlen,))\n",
    "embedding = Embedding(input_dim=len(word2idx), output_dim=64)(input_layer)\n",
    "bilstm = Bidirectional(LSTM(units=64, return_sequences=True))(embedding)\n",
    "out_ner = TimeDistributed(Dense(len(tag2idx_ner), activation=\"softmax\"), name=\"ner_output\")(bilstm)\n",
    "out_srl = TimeDistributed(Dense(len(tag2idx_srl), activation=\"softmax\"), name=\"srl_output\")(bilstm)\n",
    "\n",
    "model = Model(inputs=input_layer, outputs=[out_ner, out_srl])\n",
    "model.compile(\n",
    "    optimizer=\"adam\",\n",
    "    loss={\"ner_output\": \"categorical_crossentropy\", \"srl_output\": \"categorical_crossentropy\"},\n",
    "    metrics={\"ner_output\": \"accuracy\", \"srl_output\": \"accuracy\"}\n",
    ")\n",
    "\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "98feee87",
   "metadata": {},
   "outputs": [],
   "source": [
    "# === TRAINING ===\n",
    "history = model.fit(\n",
    "    X_train,\n",
    "    {\"ner_output\": np.array(y_ner_train), \"srl_output\": np.array(y_srl_train)},\n",
    "    validation_data=(X_val, {\"ner_output\": np.array(y_ner_val), \"srl_output\": np.array(y_srl_val)}),\n",
    "    batch_size=2,\n",
    "    epochs=10\n",
    ")\n",
    "\n",
    "# === SAVE ===\n",
    "model.save(\"NER_SRL/multi_task_bilstm_model.keras\")\n",
    "with open(\"NER_SRL/word2idx.pkl\", \"wb\") as f:\n",
    "    pickle.dump(word2idx, f)\n",
    "with open(\"NER_SRL/tag2idx_ner.pkl\", \"wb\") as f:\n",
    "    pickle.dump(tag2idx_ner, f)\n",
    "with open(\"NER_SRL/tag2idx_srl.pkl\", \"wb\") as f:\n",
    "    pickle.dump(tag2idx_srl, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aeef32c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# === EVALUATION ===\n",
    "y_pred_ner, y_pred_srl = model.predict(X_test)\n",
    "\n",
    "y_true_ner = [[idx2tag_ner[np.argmax(tok)] for tok in seq] for seq in y_ner_test]\n",
    "y_pred_ner = [[idx2tag_ner[np.argmax(tok)] for tok in seq] for seq in y_pred_ner]\n",
    "\n",
    "y_true_srl = [[idx2tag_srl[np.argmax(tok)] for tok in seq] for seq in y_srl_test]\n",
    "y_pred_srl = [[idx2tag_srl[np.argmax(tok)] for tok in seq] for seq in y_pred_srl]\n",
    "\n",
    "print(\"\\n📊 [NER] Test Set Classification Report:\")\n",
    "print(classification_report(y_true_ner, y_pred_ner))\n",
    "\n",
    "print(\"\\n📊 [SRL] Test Set Classification Report:\")\n",
    "print(classification_report(y_true_srl, y_pred_srl))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "myenv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
@@ -33,7 +33,7 @@ def predict_sentence(sentence):
 
 if __name__ == "__main__":
     try:
-        sentence = "dani datang ke indonesia"
+        sentence = "korea adalah tempat lahir jun"
         predict_sentence(sentence)
     except KeyboardInterrupt:
         print("\n\nSelesai.")

@@ -0,0 +1,8 @@
{
  "tokens": ["Barack", "Obama", "lahir", "di", "Hawaii", "."],
  "ner": ["B-PER", "I-PER", "O", "O", "B-LOC", "O"],
  "srl": ["B-ARG0", "I-ARG0", "B-V", "B-ARGM-LOC", "I-ARGM-LOC", "O"],
  "question": "___ lahir di Hawaii.",
  "answer": "Barack Obama",
  "type": "isian"
}
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue