# NER_SRL/convert_dts.py
# Convert a JSON dataset (fields: context, tokens, ner, srl) into a
# CoNLL-style TSV file for NER/SRL sequence-labeling training.
import json
import csv
from pathlib import Path
import json
import csv
from pathlib import Path
# Daftar label NER yang valid (bisa disesuaikan)
VALID_NER_LABELS = {
"O",
"LOC",
"LOC",
"PER",
"PER",
"ORG",
"ORG",
"DATE",
"DATE",
"TIME",
"TIME",
"EVENT",
"EVENT",
"MISC",
}
# Daftar label NER yang valid (bisa disesuaikan)
VALID_NER_LABELS = {"O", "LOC", "PER", "ORG", "DATE", "TIME", "EVENT", "MISC"}
# Daftar label SRL yang valid
VALID_SRL_LABELS = {
"ARG0",
"ARG1",
"ARG2",
"ARG3",
"ARGM-TMP",
"ARGM-LOC",
"ARGM-CAU",
"ARGM-MNR",
"ARGM-MOD",
"ARGM-NEG",
"V",
"O",
}
# def json_to_tsv(json_path: str | Path, tsv_path: str | Path) -> None:
# with open(json_path, encoding="utf-8") as f:
# records = json.load(f)
# seen_sentences: set[tuple[str, ...]] = set()
# with open(tsv_path, "w", encoding="utf-8", newline="") as f_out:
# writer = csv.writer(f_out, delimiter="\t", lineterminator="\n")
# for idx, rec in enumerate(records):
# contexxt = rec.get("context")
# tokens = rec.get("tokens")
# ner_tags = rec.get("ner")
# srl_tags = rec.get("srl")
# if not (len(tokens) == len(ner_tags) == len(srl_tags)):
# raise ValueError(
# f"❌ Panjang tidak sama di record index {idx}:\n"
# f" context ({len(contexxt)}): {contexxt}\n"
# f" tokens ({len(tokens)}): {tokens}\n"
# f" ner ({len(ner_tags)}): {ner_tags}\n"
# f" srl ({len(srl_tags)}): {srl_tags}\n"
# )
# # Validasi label NER
# for i, ner_label in enumerate(ner_tags):
# if ner_label not in VALID_NER_LABELS:
# raise ValueError(
# f"❌ Label NER tidak valid di record index {idx}, token ke-{i} ('{tokens[i]}'):\n"
# f" ner_label: {ner_label}\n"
# f" value: {tokens}"
# )
# # Validasi label SRL
# for i, srl_label in enumerate(srl_tags):
# if srl_label not in VALID_SRL_LABELS:
# raise ValueError(
# f"❌ Label SRL tidak valid di record index {idx}, token ke-{i} ('{tokens[i]}'):\n"
# f" srl_label: {srl_label}\n"
# f" value: {tokens}"
# )
# key = tuple(tokens)
# if key in seen_sentences:
# continue
# seen_sentences.add(key)
# for tok, ner, srl in zip(tokens, ner_tags, srl_tags):
# writer.writerow([tok, ner, srl])
# writer.writerow([])
# print(f"✔️ TSV selesai, simpan di: {tsv_path}")
def json_to_tsv(json_path: str | Path, tsv_path: str | Path) -> None:
with open(json_path, encoding="utf-8") as f:
records = json.load(f)
seen_sentences: set[tuple[str, ...]] = set()
with open(tsv_path, "w", encoding="utf-8", newline="") as f_out:
writer = csv.writer(f_out, delimiter="\t", lineterminator="\n")
for idx, rec in enumerate(records):
context = rec.get("context")
tokens = rec.get("tokens")
ner_tags = rec.get("ner")
srl_tags = rec.get("srl")
if not (len(tokens) == len(ner_tags) == len(srl_tags)):
print(
f"❌ Panjang tidak sama di record index {idx}:\n"
f" context: {context}\n"
f" tokens ({len(tokens)}): {tokens}\n"
f" ner ({len(ner_tags)}): {ner_tags}\n"
f" srl ({len(srl_tags)}): {srl_tags}\n"
)
continue
invalid_ner = False
for i, ner_label in enumerate(ner_tags):
if ner_label not in VALID_NER_LABELS:
print(
f"❌ Label NER tidak valid di record index {idx}, token ke-{i} ('{tokens[i]}'):\n"
f" ner_label: {ner_label}\n"
f" value: {tokens}"
)
invalid_ner = True
break
if invalid_ner:
continue
invalid_srl = False
for i, srl_label in enumerate(srl_tags):
if srl_label not in VALID_SRL_LABELS:
print(
f"❌ Label SRL tidak valid di record index {idx}, token ke-{i} ('{tokens[i]}'):\n"
f" srl_label: {srl_label}\n"
f" value: {tokens}"
)
invalid_srl = True
break
if invalid_srl:
continue
key = tuple(tokens)
if key in seen_sentences:
continue
seen_sentences.add(key)
for tok, ner, srl in zip(tokens, ner_tags, srl_tags):
writer.writerow([tok, ner, srl])
writer.writerow([])
print(f"✔️ TSV selesai, simpan di: {tsv_path}")
# ---------------------------------------------------------------------------
# CONTOH PEMAKAIAN (example usage)
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Convert the training dataset into a CoNLL-style TSV in the current dir.
    json_to_tsv("../dataset/stable_qg_qa_train_dataset.json", "new_LNS_2.tsv")