feat: adding dataset and working on ner and srl model
This commit is contained in:
parent
7fa361e02d
commit
42816580aa
|
@ -1,203 +0,0 @@
|
||||||
{
|
|
||||||
"cells": [
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 8,
|
|
||||||
"id": "fcdce269",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"import json\n",
|
|
||||||
"import numpy as np\n",
|
|
||||||
"from keras.models import Model\n",
|
|
||||||
"from keras.layers import Input, Embedding, Bidirectional, LSTM, TimeDistributed, Dense\n",
|
|
||||||
"from keras.utils import to_categorical\n",
|
|
||||||
"from keras.preprocessing.sequence import pad_sequences\n",
|
|
||||||
"from sklearn.model_selection import train_test_split\n",
|
|
||||||
"from seqeval.metrics import classification_report\n",
|
|
||||||
"import pickle"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 9,
|
|
||||||
"id": "d568e8f2",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"# === LOAD DATA ===\n",
|
|
||||||
"with open(\"../dataset/dataset_ner_srl.json\", \"r\", encoding=\"utf-8\") as f:\n",
|
|
||||||
" data = json.load(f)\n",
|
|
||||||
"\n",
|
|
||||||
"sentences = [[token.lower() for token in item[\"tokens\"]] for item in data]\n",
|
|
||||||
"ner_labels = [item[\"labels_ner\"] for item in data]\n",
|
|
||||||
"srl_labels = [item[\"labels_srl\"] for item in data]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 10,
|
|
||||||
"id": "e9653d99",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"# === VOCABULARY ===\n",
|
|
||||||
"words = list(set(word for sentence in sentences for word in sentence))\n",
|
|
||||||
"word2idx = {word: idx + 2 for idx, word in enumerate(words)}\n",
|
|
||||||
"word2idx[\"PAD\"] = 0\n",
|
|
||||||
"word2idx[\"UNK\"] = 1\n",
|
|
||||||
"\n",
|
|
||||||
"all_ner_tags = sorted(set(tag for seq in ner_labels for tag in seq))\n",
|
|
||||||
"all_srl_tags = sorted(set(tag for seq in srl_labels for tag in seq))\n",
|
|
||||||
"tag2idx_ner = {tag: idx for idx, tag in enumerate(all_ner_tags)}\n",
|
|
||||||
"tag2idx_srl = {tag: idx for idx, tag in enumerate(all_srl_tags)}\n",
|
|
||||||
"idx2tag_ner = {i: t for t, i in tag2idx_ner.items()}\n",
|
|
||||||
"idx2tag_srl = {i: t for t, i in tag2idx_srl.items()}"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"id": "9d3a37b3",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"ename": "KeyError",
|
|
||||||
"evalue": "'O'",
|
|
||||||
"output_type": "error",
|
|
||||||
"traceback": [
|
|
||||||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
|
||||||
"\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)",
|
|
||||||
"Cell \u001b[0;32mIn[11], line 9\u001b[0m\n\u001b[1;32m 7\u001b[0m X \u001b[38;5;241m=\u001b[39m pad_sequences(X, maxlen\u001b[38;5;241m=\u001b[39mmaxlen, padding\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpost\u001b[39m\u001b[38;5;124m\"\u001b[39m, value\u001b[38;5;241m=\u001b[39mword2idx[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mPAD\u001b[39m\u001b[38;5;124m\"\u001b[39m])\n\u001b[1;32m 8\u001b[0m y_ner \u001b[38;5;241m=\u001b[39m pad_sequences(y_ner, maxlen\u001b[38;5;241m=\u001b[39mmaxlen, padding\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpost\u001b[39m\u001b[38;5;124m\"\u001b[39m, value\u001b[38;5;241m=\u001b[39mtag2idx_ner[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mO\u001b[39m\u001b[38;5;124m\"\u001b[39m])\n\u001b[0;32m----> 9\u001b[0m y_srl \u001b[38;5;241m=\u001b[39m pad_sequences(y_srl, maxlen\u001b[38;5;241m=\u001b[39mmaxlen, padding\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpost\u001b[39m\u001b[38;5;124m\"\u001b[39m, value\u001b[38;5;241m=\u001b[39m\u001b[43mtag2idx_srl\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mO\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m)\n\u001b[1;32m 10\u001b[0m y_ner_cat \u001b[38;5;241m=\u001b[39m [to_categorical(seq, num_classes\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mlen\u001b[39m(tag2idx_ner)) \u001b[38;5;28;01mfor\u001b[39;00m seq \u001b[38;5;129;01min\u001b[39;00m y_ner]\n\u001b[1;32m 11\u001b[0m y_srl_cat \u001b[38;5;241m=\u001b[39m [to_categorical(seq, num_classes\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mlen\u001b[39m(tag2idx_srl)) \u001b[38;5;28;01mfor\u001b[39;00m seq \u001b[38;5;129;01min\u001b[39;00m y_srl]\n",
|
|
||||||
"\u001b[0;31mKeyError\u001b[0m: 'O'"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"\n",
|
|
||||||
"# === ENCODING ===\n",
|
|
||||||
"X = [[word2idx.get(w, word2idx[\"UNK\"]) for w in s] for s in sentences]\n",
|
|
||||||
"y_ner = [[tag2idx_ner[t] for t in ts] for ts in ner_labels]\n",
|
|
||||||
"y_srl = [[tag2idx_srl[t] for t in ts] for ts in srl_labels]\n",
|
|
||||||
"\n",
|
|
||||||
"maxlen = max(len(x) for x in X)\n",
|
|
||||||
"X = pad_sequences(X, maxlen=maxlen, padding=\"post\", value=word2idx[\"PAD\"])\n",
|
|
||||||
"y_ner = pad_sequences(y_ner, maxlen=maxlen, padding=\"post\", value=tag2idx_ner[\"O\"])\n",
|
|
||||||
"y_srl = pad_sequences(y_srl, maxlen=maxlen, padding=\"post\", value=tag2idx_srl[\"O\"])\n",
|
|
||||||
"y_ner_cat = [to_categorical(seq, num_classes=len(tag2idx_ner)) for seq in y_ner]\n",
|
|
||||||
"y_srl_cat = [to_categorical(seq, num_classes=len(tag2idx_srl)) for seq in y_srl]\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"id": "a5c264df",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"# split dataset \n",
|
|
||||||
"X_temp, X_test, y_ner_temp, y_ner_test, y_srl_temp, y_srl_test = train_test_split(\n",
|
|
||||||
" X, y_ner_cat, y_srl_cat, test_size=0.1, random_state=42\n",
|
|
||||||
")\n",
|
|
||||||
"X_train, X_val, y_ner_train, y_ner_val, y_srl_train, y_srl_val = train_test_split(\n",
|
|
||||||
" X_temp, y_ner_temp, y_srl_temp, test_size=0.1111, random_state=42 # ~10% of total\n",
|
|
||||||
")"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"id": "712c1789",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"#training model\n",
|
|
||||||
"input_layer = Input(shape=(maxlen,))\n",
|
|
||||||
"embedding = Embedding(input_dim=len(word2idx), output_dim=64)(input_layer)\n",
|
|
||||||
"bilstm = Bidirectional(LSTM(units=64, return_sequences=True))(embedding)\n",
|
|
||||||
"out_ner = TimeDistributed(Dense(len(tag2idx_ner), activation=\"softmax\"), name=\"ner_output\")(bilstm)\n",
|
|
||||||
"out_srl = TimeDistributed(Dense(len(tag2idx_srl), activation=\"softmax\"), name=\"srl_output\")(bilstm)\n",
|
|
||||||
"\n",
|
|
||||||
"model = Model(inputs=input_layer, outputs=[out_ner, out_srl])\n",
|
|
||||||
"model.compile(\n",
|
|
||||||
" optimizer=\"adam\",\n",
|
|
||||||
" loss={\"ner_output\": \"categorical_crossentropy\", \"srl_output\": \"categorical_crossentropy\"},\n",
|
|
||||||
" metrics={\"ner_output\": \"accuracy\", \"srl_output\": \"accuracy\"}\n",
|
|
||||||
")\n",
|
|
||||||
"\n",
|
|
||||||
"model.summary()"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"id": "98feee87",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"\n",
|
|
||||||
"# === TRAINING ===\n",
|
|
||||||
"history = model.fit(\n",
|
|
||||||
" X_train,\n",
|
|
||||||
" {\"ner_output\": np.array(y_ner_train), \"srl_output\": np.array(y_srl_train)},\n",
|
|
||||||
" validation_data=(X_val, {\"ner_output\": np.array(y_ner_val), \"srl_output\": np.array(y_srl_val)}),\n",
|
|
||||||
" batch_size=2,\n",
|
|
||||||
" epochs=10\n",
|
|
||||||
")\n",
|
|
||||||
"\n",
|
|
||||||
"# === SAVE ===\n",
|
|
||||||
"model.save(\"NER_SRL/multi_task_bilstm_model.keras\")\n",
|
|
||||||
"with open(\"NER_SRL/word2idx.pkl\", \"wb\") as f:\n",
|
|
||||||
" pickle.dump(word2idx, f)\n",
|
|
||||||
"with open(\"NER_SRL/tag2idx_ner.pkl\", \"wb\") as f:\n",
|
|
||||||
" pickle.dump(tag2idx_ner, f)\n",
|
|
||||||
"with open(\"NER_SRL/tag2idx_srl.pkl\", \"wb\") as f:\n",
|
|
||||||
" pickle.dump(tag2idx_srl, f)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"id": "aeef32c1",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"# evaluation\n",
|
|
||||||
"y_pred_ner, y_pred_srl = model.predict(X_test)\n",
|
|
||||||
"\n",
|
|
||||||
"y_true_ner = [[idx2tag_ner[np.argmax(tok)] for tok in seq] for seq in y_ner_test]\n",
|
|
||||||
"y_pred_ner = [[idx2tag_ner[np.argmax(tok)] for tok in seq] for seq in y_pred_ner]\n",
|
|
||||||
"\n",
|
|
||||||
"y_true_srl = [[idx2tag_srl[np.argmax(tok)] for tok in seq] for seq in y_srl_test]\n",
|
|
||||||
"y_pred_srl = [[idx2tag_srl[np.argmax(tok)] for tok in seq] for seq in y_pred_srl]\n",
|
|
||||||
"\n",
|
|
||||||
"print(\"\\n📊 [NER] Test Set Classification Report:\")\n",
|
|
||||||
"print(classification_report(y_true_ner, y_pred_ner))\n",
|
|
||||||
"\n",
|
|
||||||
"print(\"\\n📊 [SRL] Test Set Classification Report:\")\n",
|
|
||||||
"print(classification_report(y_true_srl, y_pred_srl))"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"metadata": {
|
|
||||||
"kernelspec": {
|
|
||||||
"display_name": "myenv",
|
|
||||||
"language": "python",
|
|
||||||
"name": "python3"
|
|
||||||
},
|
|
||||||
"language_info": {
|
|
||||||
"codemirror_mode": {
|
|
||||||
"name": "ipython",
|
|
||||||
"version": 3
|
|
||||||
},
|
|
||||||
"file_extension": ".py",
|
|
||||||
"mimetype": "text/x-python",
|
|
||||||
"name": "python",
|
|
||||||
"nbconvert_exporter": "python",
|
|
||||||
"pygments_lexer": "ipython3",
|
|
||||||
"version": "3.10.16"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"nbformat": 4,
|
|
||||||
"nbformat_minor": 5
|
|
||||||
}
|
|
|
@ -15,6 +15,8 @@ with open("NER/tag2idx.pkl", "rb") as f:
|
||||||
|
|
||||||
idx2tag = {i: t for t, i in tag2idx.items()}
|
idx2tag = {i: t for t, i in tag2idx.items()}
|
||||||
|
|
||||||
|
print(idx2tag)
|
||||||
|
|
||||||
maxlen = 100
|
maxlen = 100
|
||||||
|
|
||||||
|
|
||||||
|
|
File diff suppressed because one or more lines are too long
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue