515 lines
125 KiB
Plaintext
515 lines
125 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "263af9e9",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"\n",
|
|
"import pickle\n",
|
|
"import tensorflow as tf\n",
|
|
"from tensorflow.keras.models import Model\n",
|
|
"from tensorflow.keras.layers import Input, Embedding, Bidirectional, LSTM, TimeDistributed, Dense\n",
|
|
"from tensorflow.keras.preprocessing.sequence import pad_sequences\n",
|
|
"from tensorflow.keras.utils import to_categorical\n",
|
|
"from sklearn.model_selection import train_test_split\n",
|
|
"from seqeval.metrics import classification_report\n",
|
|
"import matplotlib.pyplot as plt"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 16,
|
|
"id": "4fc87f1b",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"data = []\n",
|
|
"with open(\"../dataset/dataset_ner_srl.tsv\", encoding=\"utf-8\") as f:\n",
|
|
" tokens, ner_labels, srl_labels = [], [], []\n",
|
|
" for line in f:\n",
|
|
" line = line.strip()\n",
|
|
" if not line:\n",
|
|
" if tokens:\n",
|
|
" data.append({\n",
|
|
" \"tokens\": tokens,\n",
|
|
" \"labels_ner\": ner_labels,\n",
|
|
" \"labels_srl\": srl_labels\n",
|
|
" })\n",
|
|
" tokens, ner_labels, srl_labels = [], [], []\n",
|
|
" else:\n",
|
|
" token, ner, srl = line.split(\"\\t\")\n",
|
|
" tokens.append(token)\n",
|
|
" ner_labels.append(ner)\n",
|
|
" srl_labels.append(srl)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 17,
|
|
"id": "48553e6b",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"\n",
|
|
"# 2. Preprocessing\n",
|
|
"sentences = [[tok.lower() for tok in item[\"tokens\"]] for item in data]\n",
|
|
"labels_ner = [item[\"labels_ner\"] for item in data]\n",
|
|
"labels_srl = [item[\"labels_srl\"] for item in data]\n",
|
|
"\n",
|
|
"words = sorted({w for s in sentences for w in s})\n",
|
|
"ner_tags = sorted({t for seq in labels_ner for t in seq})\n",
|
|
"srl_tags = sorted({t for seq in labels_srl for t in seq})\n",
|
|
"\n",
|
|
"word2idx = {w: i + 2 for i, w in enumerate(words)}\n",
|
|
"word2idx[\"PAD\"], word2idx[\"UNK\"] = 0, 1\n",
|
|
"\n",
|
|
"tag2idx_ner = {t: i for i, t in enumerate(ner_tags)}\n",
|
|
"tag2idx_srl = {t: i for i, t in enumerate(srl_tags)}\n",
|
|
"idx2tag_ner = {i: t for t, i in tag2idx_ner.items()}\n",
|
|
"idx2tag_srl = {i: t for t, i in tag2idx_srl.items()}\n",
|
|
"\n",
|
|
"X = [[word2idx.get(w, word2idx[\"UNK\"]) for w in s] for s in sentences]\n",
|
|
"y_ner = [[tag2idx_ner[t] for t in seq] for seq in labels_ner]\n",
|
|
"y_srl = [[tag2idx_srl[t] for t in seq] for seq in labels_srl]\n",
|
|
"\n",
|
|
"maxlen = 50\n",
|
|
"X = pad_sequences(X, maxlen=maxlen, padding=\"post\", value=word2idx[\"PAD\"])\n",
|
|
"y_ner = pad_sequences(y_ner, maxlen=maxlen, padding=\"post\", value=tag2idx_ner[\"O\"])\n",
|
|
"y_srl = pad_sequences(y_srl, maxlen=maxlen, padding=\"post\", value=tag2idx_srl[\"O\"])\n",
|
|
"\n",
|
|
"y_ner = to_categorical(y_ner, num_classes=len(tag2idx_ner))\n",
|
|
"y_srl = to_categorical(y_srl, num_classes=len(tag2idx_srl))\n",
|
|
"\n",
|
|
"X_train, X_test, y_ner_train, y_ner_test, y_srl_train, y_srl_test = train_test_split(\n",
|
|
" X, y_ner, y_srl, test_size=0.2, random_state=42, shuffle=True\n",
|
|
")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 18,
|
|
"id": "1b4a1c61",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Model: \"functional_2\"</span>\n",
|
|
"</pre>\n"
|
|
],
|
|
"text/plain": [
|
|
"\u001b[1mModel: \"functional_2\"\u001b[0m\n"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓\n",
|
|
"┃<span style=\"font-weight: bold\"> Layer (type) </span>┃<span style=\"font-weight: bold\"> Output Shape </span>┃<span style=\"font-weight: bold\"> Param # </span>┃<span style=\"font-weight: bold\"> Connected to </span>┃\n",
|
|
"┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩\n",
|
|
"│ input_layer_2 │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">50</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ - │\n",
|
|
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">InputLayer</span>) │ │ │ │\n",
|
|
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
|
"│ embedding_2 │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">50</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">46,016</span> │ input_layer_2[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>]… │\n",
|
|
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Embedding</span>) │ │ │ │\n",
|
|
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
|
"│ bidirectional_2 │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">50</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">66,048</span> │ embedding_2[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
|
|
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Bidirectional</span>) │ │ │ │\n",
|
|
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
|
"│ ner_output │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">50</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">25</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">3,225</span> │ bidirectional_2[<span style=\"color: #00af00; text-decoration-color: #00af00\">…</span> │\n",
|
|
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">TimeDistributed</span>) │ │ │ │\n",
|
|
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
|
"│ srl_output │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">50</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">18</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">2,322</span> │ bidirectional_2[<span style=\"color: #00af00; text-decoration-color: #00af00\">…</span> │\n",
|
|
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">TimeDistributed</span>) │ │ │ │\n",
|
|
"└─────────────────────┴───────────────────┴────────────┴───────────────────┘\n",
|
|
"</pre>\n"
|
|
],
|
|
"text/plain": [
|
|
"┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓\n",
|
|
"┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mConnected to \u001b[0m\u001b[1m \u001b[0m┃\n",
|
|
"┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩\n",
|
|
"│ input_layer_2 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m50\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n",
|
|
"│ (\u001b[38;5;33mInputLayer\u001b[0m) │ │ │ │\n",
|
|
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
|
"│ embedding_2 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m50\u001b[0m, \u001b[38;5;34m64\u001b[0m) │ \u001b[38;5;34m46,016\u001b[0m │ input_layer_2[\u001b[38;5;34m0\u001b[0m]… │\n",
|
|
"│ (\u001b[38;5;33mEmbedding\u001b[0m) │ │ │ │\n",
|
|
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
|
"│ bidirectional_2 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m50\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m66,048\u001b[0m │ embedding_2[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
|
|
"│ (\u001b[38;5;33mBidirectional\u001b[0m) │ │ │ │\n",
|
|
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
|
"│ ner_output │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m50\u001b[0m, \u001b[38;5;34m25\u001b[0m) │ \u001b[38;5;34m3,225\u001b[0m │ bidirectional_2[\u001b[38;5;34m…\u001b[0m │\n",
|
|
"│ (\u001b[38;5;33mTimeDistributed\u001b[0m) │ │ │ │\n",
|
|
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
|
|
"│ srl_output │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m50\u001b[0m, \u001b[38;5;34m18\u001b[0m) │ \u001b[38;5;34m2,322\u001b[0m │ bidirectional_2[\u001b[38;5;34m…\u001b[0m │\n",
|
|
"│ (\u001b[38;5;33mTimeDistributed\u001b[0m) │ │ │ │\n",
|
|
"└─────────────────────┴───────────────────┴────────────┴───────────────────┘\n"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Total params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">117,611</span> (459.42 KB)\n",
|
|
"</pre>\n"
|
|
],
|
|
"text/plain": [
|
|
"\u001b[1m Total params: \u001b[0m\u001b[38;5;34m117,611\u001b[0m (459.42 KB)\n"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">117,611</span> (459.42 KB)\n",
|
|
"</pre>\n"
|
|
],
|
|
"text/plain": [
|
|
"\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m117,611\u001b[0m (459.42 KB)\n"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Non-trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> (0.00 B)\n",
|
|
"</pre>\n"
|
|
],
|
|
"text/plain": [
|
|
"\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"\n",
|
|
"# 3. Model\n",
|
|
"input_layer = Input(shape=(maxlen,))\n",
|
|
"embedding_layer = Embedding(input_dim=len(word2idx), output_dim=64)(input_layer)\n",
|
|
"bilstm_layer = Bidirectional(LSTM(units=64, return_sequences=True))(embedding_layer)\n",
|
|
"\n",
|
|
"ner_output = TimeDistributed(Dense(len(tag2idx_ner), activation=\"softmax\"), name=\"ner_output\")(bilstm_layer)\n",
|
|
"srl_output = TimeDistributed(Dense(len(tag2idx_srl), activation=\"softmax\"), name=\"srl_output\")(bilstm_layer)\n",
|
|
"\n",
|
|
"model = Model(inputs=input_layer, outputs=[ner_output, srl_output])\n",
|
|
"model.compile(\n",
|
|
" optimizer=\"adam\",\n",
|
|
" loss={\n",
|
|
" \"ner_output\": \"categorical_crossentropy\",\n",
|
|
" \"srl_output\": \"categorical_crossentropy\",\n",
|
|
" },\n",
|
|
" metrics={\n",
|
|
" \"ner_output\": [tf.keras.metrics.CategoricalAccuracy(name=\"accuracy\")],\n",
|
|
" \"srl_output\": [tf.keras.metrics.CategoricalAccuracy(name=\"accuracy\")],\n",
|
|
" }\n",
|
|
")\n",
|
|
"\n",
|
|
"model.summary()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 19,
|
|
"id": "f41d6012",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Epoch 1/10\n",
|
|
"\u001b[1m64/64\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 28ms/step - loss: 3.1783 - ner_output_accuracy: 0.8803 - ner_output_loss: 1.5415 - srl_output_accuracy: 0.7421 - srl_output_loss: 1.6365 - val_loss: 0.7449 - val_ner_output_accuracy: 0.9488 - val_ner_output_loss: 0.2727 - val_srl_output_accuracy: 0.8513 - val_srl_output_loss: 0.4722\n",
|
|
"Epoch 2/10\n",
|
|
"\u001b[1m64/64\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 16ms/step - loss: 0.7761 - ner_output_accuracy: 0.9485 - ner_output_loss: 0.2537 - srl_output_accuracy: 0.8194 - srl_output_loss: 0.5225 - val_loss: 0.6895 - val_ner_output_accuracy: 0.9488 - val_ner_output_loss: 0.2616 - val_srl_output_accuracy: 0.8525 - val_srl_output_loss: 0.4279\n",
|
|
"Epoch 3/10\n",
|
|
"\u001b[1m64/64\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 16ms/step - loss: 0.6680 - ner_output_accuracy: 0.9534 - ner_output_loss: 0.2254 - srl_output_accuracy: 0.8478 - srl_output_loss: 0.4425 - val_loss: 0.6540 - val_ner_output_accuracy: 0.9488 - val_ner_output_loss: 0.2505 - val_srl_output_accuracy: 0.8775 - val_srl_output_loss: 0.4036\n",
|
|
"Epoch 4/10\n",
|
|
"\u001b[1m64/64\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 16ms/step - loss: 0.7246 - ner_output_accuracy: 0.9527 - ner_output_loss: 0.2277 - srl_output_accuracy: 0.8378 - srl_output_loss: 0.4968 - val_loss: 0.6080 - val_ner_output_accuracy: 0.9488 - val_ner_output_loss: 0.2360 - val_srl_output_accuracy: 0.8862 - val_srl_output_loss: 0.3720\n",
|
|
"Epoch 5/10\n",
|
|
"\u001b[1m64/64\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 16ms/step - loss: 0.6014 - ner_output_accuracy: 0.9482 - ner_output_loss: 0.2188 - srl_output_accuracy: 0.8758 - srl_output_loss: 0.3826 - val_loss: 0.5843 - val_ner_output_accuracy: 0.9488 - val_ner_output_loss: 0.2298 - val_srl_output_accuracy: 0.8900 - val_srl_output_loss: 0.3546\n",
|
|
"Epoch 6/10\n",
|
|
"\u001b[1m64/64\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 15ms/step - loss: 0.5891 - ner_output_accuracy: 0.9554 - ner_output_loss: 0.1810 - srl_output_accuracy: 0.8696 - srl_output_loss: 0.4080 - val_loss: 0.5570 - val_ner_output_accuracy: 0.9488 - val_ner_output_loss: 0.2211 - val_srl_output_accuracy: 0.8981 - val_srl_output_loss: 0.3359\n",
|
|
"Epoch 7/10\n",
|
|
"\u001b[1m64/64\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 14ms/step - loss: 0.5565 - ner_output_accuracy: 0.9522 - ner_output_loss: 0.1804 - srl_output_accuracy: 0.8809 - srl_output_loss: 0.3762 - val_loss: 0.5325 - val_ner_output_accuracy: 0.9494 - val_ner_output_loss: 0.2112 - val_srl_output_accuracy: 0.9031 - val_srl_output_loss: 0.3214\n",
|
|
"Epoch 8/10\n",
|
|
"\u001b[1m64/64\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 14ms/step - loss: 0.4922 - ner_output_accuracy: 0.9584 - ner_output_loss: 0.1499 - srl_output_accuracy: 0.8953 - srl_output_loss: 0.3423 - val_loss: 0.5081 - val_ner_output_accuracy: 0.9506 - val_ner_output_loss: 0.2005 - val_srl_output_accuracy: 0.9087 - val_srl_output_loss: 0.3076\n",
|
|
"Epoch 9/10\n",
|
|
"\u001b[1m64/64\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 15ms/step - loss: 0.4629 - ner_output_accuracy: 0.9591 - ner_output_loss: 0.1441 - srl_output_accuracy: 0.9043 - srl_output_loss: 0.3188 - val_loss: 0.4793 - val_ner_output_accuracy: 0.9544 - val_ner_output_loss: 0.1867 - val_srl_output_accuracy: 0.9144 - val_srl_output_loss: 0.2925\n",
|
|
"Epoch 10/10\n",
|
|
"\u001b[1m64/64\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 15ms/step - loss: 0.4199 - ner_output_accuracy: 0.9601 - ner_output_loss: 0.1453 - srl_output_accuracy: 0.9213 - srl_output_loss: 0.2746 - val_loss: 0.4803 - val_ner_output_accuracy: 0.9581 - val_ner_output_loss: 0.1858 - val_srl_output_accuracy: 0.9112 - val_srl_output_loss: 0.2945\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"history = model.fit(\n",
|
|
" X_train,\n",
|
|
" {\"ner_output\": y_ner_train, \"srl_output\": y_srl_train},\n",
|
|
" validation_data=(X_test, {\"ner_output\": y_ner_test, \"srl_output\": y_srl_test}),\n",
|
|
" batch_size=2,\n",
|
|
" epochs=10,\n",
|
|
" verbose=1\n",
|
|
")\n",
|
|
"\n",
|
|
"# 5. Save artifacts\n",
|
|
"model.save(\"multi_task_lstm_ner_srl_model_tf.keras\")\n",
|
|
"with open(\"word2idx.pkl\", \"wb\") as f:\n",
|
|
" pickle.dump(word2idx, f)\n",
|
|
"with open(\"tag2idx_ner.pkl\", \"wb\") as f:\n",
|
|
" pickle.dump(tag2idx_ner, f)\n",
|
|
"with open(\"tag2idx_srl.pkl\", \"wb\") as f:\n",
|
|
" pickle.dump(tag2idx_srl, f)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 20,
|
|
"id": "333745fd",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"image/png": "",
|
|
"text/plain": [
|
|
"<Figure size 1400x600 with 2 Axes>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"\n",
|
|
"def plot_training_history(history):\n",
|
|
" epochs = range(1, len(history['loss']) + 1)\n",
|
|
"\n",
|
|
" plt.figure(figsize=(14, 6))\n",
|
|
"\n",
|
|
" # Plot Loss\n",
|
|
" plt.subplot(1, 2, 1)\n",
|
|
" plt.plot(epochs, history['loss'], label='Training Loss')\n",
|
|
" plt.plot(epochs, history['val_loss'], label='Validation Loss')\n",
|
|
" plt.title('Loss During Training')\n",
|
|
" plt.xlabel('Epochs')\n",
|
|
" plt.ylabel('Loss')\n",
|
|
" plt.legend()\n",
|
|
"\n",
|
|
" # Plot Accuracy\n",
|
|
" plt.subplot(1, 2, 2)\n",
|
|
" plt.plot(epochs, history['ner_output_accuracy'], label='NER Train Acc')\n",
|
|
" plt.plot(epochs, history['val_ner_output_accuracy'], label='NER Val Acc')\n",
|
|
" plt.plot(epochs, history['srl_output_accuracy'], label='SRL Train Acc')\n",
|
|
" plt.plot(epochs, history['val_srl_output_accuracy'], label='SRL Val Acc')\n",
|
|
" plt.title('Accuracy During Training')\n",
|
|
" plt.xlabel('Epochs')\n",
|
|
" plt.ylabel('Accuracy')\n",
|
|
" plt.legend()\n",
|
|
"\n",
|
|
" plt.tight_layout()\n",
|
|
" plt.show()\n",
|
|
" \n",
|
|
"plot_training_history(history.history)\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 21,
|
|
"id": "df36e200",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"loss: 0.48032140731811523\n",
|
|
"compile_metrics: 0.18580493330955505\n",
|
|
"ner_output_loss: 0.2945164740085602\n",
|
|
"srl_output_loss: 0.9581250548362732\n",
|
|
"NER Token Accuracy 95.81%\n",
|
|
"SRL Token Accuracy 91.12%\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"def token_level_accuracy(y_true, y_pred):\n",
|
|
" total, correct = 0, 0\n",
|
|
" for true_seq, pred_seq in zip(y_true, y_pred):\n",
|
|
" for t, p in zip(true_seq, pred_seq):\n",
|
|
" if t.sum() == 0:\n",
|
|
" continue\n",
|
|
" total += 1\n",
|
|
" if t.argmax() == p.argmax():\n",
|
|
" correct += 1\n",
|
|
" return correct / total\n",
|
|
"\n",
|
|
"def decode_predictions(pred, true, idx2tag):\n",
|
|
" true_out, pred_out = [], []\n",
|
|
" for pred_seq, true_seq in zip(pred, true):\n",
|
|
" t_labels, p_labels = [], []\n",
|
|
" for p_tok, t_tok in zip(pred_seq, true_seq):\n",
|
|
" if t_tok.sum() == 0:\n",
|
|
" continue\n",
|
|
" t_labels.append(idx2tag[t_tok.argmax()])\n",
|
|
" p_labels.append(idx2tag[p_tok.argmax()])\n",
|
|
" true_out.append(t_labels)\n",
|
|
" pred_out.append(p_labels)\n",
|
|
" return true_out, pred_out\n",
|
|
"\n",
|
|
"results = model.evaluate(X_test, {\"ner_output\": y_ner_test, \"srl_output\": y_srl_test}, verbose=0)\n",
|
|
"for name, value in zip(model.metrics_names, results):\n",
|
|
" print(f\"{name}: {value}\")\n",
|
|
"\n",
|
|
"y_pred_ner, y_pred_srl = model.predict(X_test, verbose=0)\n",
|
|
"\n",
|
|
"true_ner, pred_ner = decode_predictions(y_pred_ner, y_ner_test, idx2tag_ner)\n",
|
|
"true_srl, pred_srl = decode_predictions(y_pred_srl, y_srl_test, idx2tag_srl)\n",
|
|
"\n",
|
|
"acc_ner = token_level_accuracy(y_ner_test, y_pred_ner)\n",
|
|
"acc_srl = token_level_accuracy(y_srl_test, y_pred_srl)\n",
|
|
"\n",
|
|
"print(f\"NER Token Accuracy {acc_ner:.2%}\")\n",
|
|
"print(f\"SRL Token Accuracy {acc_srl:.2%}\")\n",
|
|
"\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 22,
|
|
"id": "9127cce0",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"[NER] Classification Report:\n",
|
|
" precision recall f1-score support\n",
|
|
"\n",
|
|
" DATE 0.80 0.36 0.50 11\n",
|
|
" EVENT 0.00 0.00 0.00 1\n",
|
|
" LOC 1.00 0.38 0.55 21\n",
|
|
" MIN 0.00 0.00 0.00 3\n",
|
|
" MISC 0.00 0.00 0.00 1\n",
|
|
" ORG 0.00 0.00 0.00 3\n",
|
|
" PER 0.00 0.00 0.00 2\n",
|
|
" RES 0.00 0.00 0.00 2\n",
|
|
" TIME 0.33 0.38 0.35 8\n",
|
|
"\n",
|
|
" micro avg 0.68 0.29 0.41 52\n",
|
|
" macro avg 0.24 0.12 0.16 52\n",
|
|
"weighted avg 0.62 0.29 0.38 52\n",
|
|
"\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"/mnt/disc1/code/thesis_quiz_project/lstm-quiz/myenv/lib64/python3.10/site-packages/seqeval/metrics/v1.py:57: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
|
|
" _warn_prf(average, modifier, msg_start, len(result))\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"print(\"[NER] Classification Report:\")\n",
|
|
"print(classification_report(true_ner, pred_ner, digits=2))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 23,
|
|
"id": "300897b8",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"SRL Classification Resport:\n",
|
|
" precision recall f1-score support\n",
|
|
"\n",
|
|
" FRQ 0.00 0.00 0.00 1\n",
|
|
" LOC 0.27 0.43 0.33 7\n",
|
|
" MNR 0.00 0.00 0.00 3\n",
|
|
" PRP 0.00 0.00 0.00 1\n",
|
|
" RG0 0.38 0.18 0.24 17\n",
|
|
" RG1 0.23 0.21 0.22 47\n",
|
|
" RG2 0.18 0.27 0.21 11\n",
|
|
" RG3 0.00 0.00 0.00 3\n",
|
|
" TMP 0.55 0.61 0.58 18\n",
|
|
" _ 0.40 0.12 0.19 33\n",
|
|
"\n",
|
|
" micro avg 0.31 0.24 0.27 141\n",
|
|
" macro avg 0.20 0.18 0.18 141\n",
|
|
"weighted avg 0.31 0.24 0.25 141\n",
|
|
"\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"/mnt/disc1/code/thesis_quiz_project/lstm-quiz/myenv/lib64/python3.10/site-packages/seqeval/metrics/sequence_labeling.py:171: UserWarning: ARG1 seems not to be NE tag.\n",
|
|
" warnings.warn('{} seems not to be NE tag.'.format(chunk))\n",
|
|
"/mnt/disc1/code/thesis_quiz_project/lstm-quiz/myenv/lib64/python3.10/site-packages/seqeval/metrics/sequence_labeling.py:171: UserWarning: ARG0 seems not to be NE tag.\n",
|
|
" warnings.warn('{} seems not to be NE tag.'.format(chunk))\n",
|
|
"/mnt/disc1/code/thesis_quiz_project/lstm-quiz/myenv/lib64/python3.10/site-packages/seqeval/metrics/sequence_labeling.py:171: UserWarning: V seems not to be NE tag.\n",
|
|
" warnings.warn('{} seems not to be NE tag.'.format(chunk))\n",
|
|
"/mnt/disc1/code/thesis_quiz_project/lstm-quiz/myenv/lib64/python3.10/site-packages/seqeval/metrics/sequence_labeling.py:171: UserWarning: ARGM-TMP seems not to be NE tag.\n",
|
|
" warnings.warn('{} seems not to be NE tag.'.format(chunk))\n",
|
|
"/mnt/disc1/code/thesis_quiz_project/lstm-quiz/myenv/lib64/python3.10/site-packages/seqeval/metrics/sequence_labeling.py:171: UserWarning: ARGM-PRP seems not to be NE tag.\n",
|
|
" warnings.warn('{} seems not to be NE tag.'.format(chunk))\n",
|
|
"/mnt/disc1/code/thesis_quiz_project/lstm-quiz/myenv/lib64/python3.10/site-packages/seqeval/metrics/sequence_labeling.py:171: UserWarning: ARGM-LOC seems not to be NE tag.\n",
|
|
" warnings.warn('{} seems not to be NE tag.'.format(chunk))\n",
|
|
"/mnt/disc1/code/thesis_quiz_project/lstm-quiz/myenv/lib64/python3.10/site-packages/seqeval/metrics/sequence_labeling.py:171: UserWarning: ARG2 seems not to be NE tag.\n",
|
|
" warnings.warn('{} seems not to be NE tag.'.format(chunk))\n",
|
|
"/mnt/disc1/code/thesis_quiz_project/lstm-quiz/myenv/lib64/python3.10/site-packages/seqeval/metrics/sequence_labeling.py:171: UserWarning: ARGM-MNR seems not to be NE tag.\n",
|
|
" warnings.warn('{} seems not to be NE tag.'.format(chunk))\n",
|
|
"/mnt/disc1/code/thesis_quiz_project/lstm-quiz/myenv/lib64/python3.10/site-packages/seqeval/metrics/sequence_labeling.py:171: UserWarning: ARGM-FRQ seems not to be NE tag.\n",
|
|
" warnings.warn('{} seems not to be NE tag.'.format(chunk))\n",
|
|
"/mnt/disc1/code/thesis_quiz_project/lstm-quiz/myenv/lib64/python3.10/site-packages/seqeval/metrics/sequence_labeling.py:171: UserWarning: ARG3 seems not to be NE tag.\n",
|
|
" warnings.warn('{} seems not to be NE tag.'.format(chunk))\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"print(\"SRL Classification Resport:\")\n",
|
|
"print(classification_report(true_srl, pred_srl, digits=2))"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "myenv",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.10.16"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|