feat: adding dataset and working on ner and srl model

This commit is contained in:
akhdanre 2025-04-20 10:26:19 +07:00
parent 7fa361e02d
commit 42816580aa
8 changed files with 3785 additions and 2354 deletions

View File

@ -1,203 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 8,
"id": "fcdce269",
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"import numpy as np\n",
"from keras.models import Model\n",
"from keras.layers import Input, Embedding, Bidirectional, LSTM, TimeDistributed, Dense\n",
"from keras.utils import to_categorical\n",
"from keras.preprocessing.sequence import pad_sequences\n",
"from sklearn.model_selection import train_test_split\n",
"from seqeval.metrics import classification_report\n",
"import pickle"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "d568e8f2",
"metadata": {},
"outputs": [],
"source": [
"# === LOAD DATA ===\n",
"with open(\"../dataset/dataset_ner_srl.json\", \"r\", encoding=\"utf-8\") as f:\n",
" data = json.load(f)\n",
"\n",
"sentences = [[token.lower() for token in item[\"tokens\"]] for item in data]\n",
"ner_labels = [item[\"labels_ner\"] for item in data]\n",
"srl_labels = [item[\"labels_srl\"] for item in data]"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "e9653d99",
"metadata": {},
"outputs": [],
"source": [
"# === VOCABULARY ===\n",
"words = list(set(word for sentence in sentences for word in sentence))\n",
"word2idx = {word: idx + 2 for idx, word in enumerate(words)}\n",
"word2idx[\"PAD\"] = 0\n",
"word2idx[\"UNK\"] = 1\n",
"\n",
"all_ner_tags = sorted(set(tag for seq in ner_labels for tag in seq))\n",
"all_srl_tags = sorted(set(tag for seq in srl_labels for tag in seq))\n",
"tag2idx_ner = {tag: idx for idx, tag in enumerate(all_ner_tags)}\n",
"tag2idx_srl = {tag: idx for idx, tag in enumerate(all_srl_tags)}\n",
"idx2tag_ner = {i: t for t, i in tag2idx_ner.items()}\n",
"idx2tag_srl = {i: t for t, i in tag2idx_srl.items()}"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9d3a37b3",
"metadata": {},
"outputs": [
{
"ename": "KeyError",
"evalue": "'O'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[11], line 9\u001b[0m\n\u001b[1;32m 7\u001b[0m X \u001b[38;5;241m=\u001b[39m pad_sequences(X, maxlen\u001b[38;5;241m=\u001b[39mmaxlen, padding\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpost\u001b[39m\u001b[38;5;124m\"\u001b[39m, value\u001b[38;5;241m=\u001b[39mword2idx[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mPAD\u001b[39m\u001b[38;5;124m\"\u001b[39m])\n\u001b[1;32m 8\u001b[0m y_ner \u001b[38;5;241m=\u001b[39m pad_sequences(y_ner, maxlen\u001b[38;5;241m=\u001b[39mmaxlen, padding\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpost\u001b[39m\u001b[38;5;124m\"\u001b[39m, value\u001b[38;5;241m=\u001b[39mtag2idx_ner[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mO\u001b[39m\u001b[38;5;124m\"\u001b[39m])\n\u001b[0;32m----> 9\u001b[0m y_srl \u001b[38;5;241m=\u001b[39m pad_sequences(y_srl, maxlen\u001b[38;5;241m=\u001b[39mmaxlen, padding\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpost\u001b[39m\u001b[38;5;124m\"\u001b[39m, value\u001b[38;5;241m=\u001b[39m\u001b[43mtag2idx_srl\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mO\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m)\n\u001b[1;32m 10\u001b[0m y_ner_cat \u001b[38;5;241m=\u001b[39m [to_categorical(seq, num_classes\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mlen\u001b[39m(tag2idx_ner)) \u001b[38;5;28;01mfor\u001b[39;00m seq \u001b[38;5;129;01min\u001b[39;00m y_ner]\n\u001b[1;32m 11\u001b[0m y_srl_cat \u001b[38;5;241m=\u001b[39m [to_categorical(seq, num_classes\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mlen\u001b[39m(tag2idx_srl)) \u001b[38;5;28;01mfor\u001b[39;00m seq \u001b[38;5;129;01min\u001b[39;00m y_srl]\n",
"\u001b[0;31mKeyError\u001b[0m: 'O'"
]
}
],
"source": [
"\n",
"# === ENCODING ===\n",
"X = [[word2idx.get(w, word2idx[\"UNK\"]) for w in s] for s in sentences]\n",
"y_ner = [[tag2idx_ner[t] for t in ts] for ts in ner_labels]\n",
"y_srl = [[tag2idx_srl[t] for t in ts] for ts in srl_labels]\n",
"\n",
"maxlen = max(len(x) for x in X)\n",
"X = pad_sequences(X, maxlen=maxlen, padding=\"post\", value=word2idx[\"PAD\"])\n",
"y_ner = pad_sequences(y_ner, maxlen=maxlen, padding=\"post\", value=tag2idx_ner[\"O\"])\n",
"y_srl = pad_sequences(y_srl, maxlen=maxlen, padding=\"post\", value=tag2idx_srl[\"O\"])\n",
"y_ner_cat = [to_categorical(seq, num_classes=len(tag2idx_ner)) for seq in y_ner]\n",
"y_srl_cat = [to_categorical(seq, num_classes=len(tag2idx_srl)) for seq in y_srl]\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a5c264df",
"metadata": {},
"outputs": [],
"source": [
"# split dataset \n",
"X_temp, X_test, y_ner_temp, y_ner_test, y_srl_temp, y_srl_test = train_test_split(\n",
" X, y_ner_cat, y_srl_cat, test_size=0.1, random_state=42\n",
")\n",
"X_train, X_val, y_ner_train, y_ner_val, y_srl_train, y_srl_val = train_test_split(\n",
" X_temp, y_ner_temp, y_srl_temp, test_size=0.1111, random_state=42 # ~10% of total\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "712c1789",
"metadata": {},
"outputs": [],
"source": [
"#training model\n",
"input_layer = Input(shape=(maxlen,))\n",
"embedding = Embedding(input_dim=len(word2idx), output_dim=64)(input_layer)\n",
"bilstm = Bidirectional(LSTM(units=64, return_sequences=True))(embedding)\n",
"out_ner = TimeDistributed(Dense(len(tag2idx_ner), activation=\"softmax\"), name=\"ner_output\")(bilstm)\n",
"out_srl = TimeDistributed(Dense(len(tag2idx_srl), activation=\"softmax\"), name=\"srl_output\")(bilstm)\n",
"\n",
"model = Model(inputs=input_layer, outputs=[out_ner, out_srl])\n",
"model.compile(\n",
" optimizer=\"adam\",\n",
" loss={\"ner_output\": \"categorical_crossentropy\", \"srl_output\": \"categorical_crossentropy\"},\n",
" metrics={\"ner_output\": \"accuracy\", \"srl_output\": \"accuracy\"}\n",
")\n",
"\n",
"model.summary()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "98feee87",
"metadata": {},
"outputs": [],
"source": [
"\n",
"# === TRAINING ===\n",
"history = model.fit(\n",
" X_train,\n",
" {\"ner_output\": np.array(y_ner_train), \"srl_output\": np.array(y_srl_train)},\n",
" validation_data=(X_val, {\"ner_output\": np.array(y_ner_val), \"srl_output\": np.array(y_srl_val)}),\n",
" batch_size=2,\n",
" epochs=10\n",
")\n",
"\n",
"# === SAVE ===\n",
"model.save(\"NER_SRL/multi_task_bilstm_model.keras\")\n",
"with open(\"NER_SRL/word2idx.pkl\", \"wb\") as f:\n",
" pickle.dump(word2idx, f)\n",
"with open(\"NER_SRL/tag2idx_ner.pkl\", \"wb\") as f:\n",
" pickle.dump(tag2idx_ner, f)\n",
"with open(\"NER_SRL/tag2idx_srl.pkl\", \"wb\") as f:\n",
" pickle.dump(tag2idx_srl, f)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "aeef32c1",
"metadata": {},
"outputs": [],
"source": [
"# evaluation\n",
"y_pred_ner, y_pred_srl = model.predict(X_test)\n",
"\n",
"y_true_ner = [[idx2tag_ner[np.argmax(tok)] for tok in seq] for seq in y_ner_test]\n",
"y_pred_ner = [[idx2tag_ner[np.argmax(tok)] for tok in seq] for seq in y_pred_ner]\n",
"\n",
"y_true_srl = [[idx2tag_srl[np.argmax(tok)] for tok in seq] for seq in y_srl_test]\n",
"y_pred_srl = [[idx2tag_srl[np.argmax(tok)] for tok in seq] for seq in y_pred_srl]\n",
"\n",
"print(\"\\n📊 [NER] Test Set Classification Report:\")\n",
"print(classification_report(y_true_ner, y_pred_ner))\n",
"\n",
"print(\"\\n📊 [SRL] Test Set Classification Report:\")\n",
"print(classification_report(y_true_srl, y_pred_srl))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "myenv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@ -15,6 +15,8 @@ with open("NER/tag2idx.pkl", "rb") as f:
idx2tag = {i: t for t, i in tag2idx.items()} idx2tag = {i: t for t, i in tag2idx.items()}
print(idx2tag)
maxlen = 100 maxlen = 100

2696
NER_SRL/lstm_ner_srl.ipynb Normal file

File diff suppressed because one or more lines are too long

Binary file not shown.

BIN
NER_SRL/tag2idx_ner.pkl Normal file

Binary file not shown.

BIN
NER_SRL/tag2idx_srl.pkl Normal file

Binary file not shown.

BIN
NER_SRL/word2idx.pkl Normal file

Binary file not shown.

File diff suppressed because it is too large Load Diff