185 lines
4.1 KiB
Python
185 lines
4.1 KiB
Python
import tensorflow as tf
|
|
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
import seaborn as sns
|
|
import os
|
|
|
|
from tensorflow.keras import layers, models
|
|
from tensorflow.keras.applications import MobileNetV2
|
|
from tensorflow.keras.preprocessing.image import ImageDataGenerator
|
|
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
|
|
|
|
from sklearn.metrics import confusion_matrix, classification_report
|
|
|
|
# Side length (in pixels) every image is resized to, and images per batch.
IMG_SIZE = 224
BATCH_SIZE = 32

# Root folders of the training and validation splits.
train_dir = "dataset_split_v2/train"
val_dir = "dataset_split_v2/val"

# AUGMENTATION
# Training images are rescaled to [0, 1] and randomly rotated, zoomed,
# sheared, shifted and mirrored; pixels exposed by a transform are filled
# with the nearest valid pixel.
# NOTE(review): MobileNetV2's own preprocess_input scales pixels to [-1, 1];
# confirm that plain 1/255 rescaling is intended before changing it, because
# any consumer of the saved model must apply the same preprocessing.
_augmentation = dict(
    rescale=1.0 / 255,
    rotation_range=20,
    zoom_range=0.2,
    shear_range=0.15,
    width_shift_range=0.1,
    height_shift_range=0.1,
    horizontal_flip=True,
    fill_mode='nearest',
)
train_gen = ImageDataGenerator(**_augmentation)

# Validation images are only rescaled, never augmented.
val_gen = ImageDataGenerator(rescale=1.0 / 255)

# Options shared by both directory iterators.
_flow_options = dict(
    target_size=(IMG_SIZE, IMG_SIZE),
    batch_size=BATCH_SIZE,
    class_mode='categorical',
)

train_data = train_gen.flow_from_directory(train_dir, **_flow_options)

# shuffle=False keeps the validation order fixed so that predictions can be
# compared against val_data.classes later in the script.
val_data = val_gen.flow_from_directory(val_dir, shuffle=False, **_flow_options)

# Class bookkeeping is taken from the training split.
num_classes = train_data.num_classes
class_labels = list(train_data.class_indices)

print("Class indices:", train_data.class_indices)
|
|
|
|
|
|
# MODEL
# ImageNet-pretrained MobileNetV2 without its classification top; it is
# frozen for the first training phase.
base_model = MobileNetV2(
    weights='imagenet',
    include_top=False,
    input_shape=(IMG_SIZE, IMG_SIZE, 3),
)
base_model.trainable = False

# Classification head stacked on the backbone output: global pooling,
# a 128-unit ReLU layer, dropout, then one softmax unit per class.
head_layers = [
    layers.GlobalAveragePooling2D(),
    layers.Dense(128, activation='relu'),
    layers.Dropout(0.5),
    layers.Dense(num_classes, activation='softmax'),
]
output = base_model.output
for head_layer in head_layers:
    output = head_layer(output)

model = models.Model(inputs=base_model.input, outputs=output)

model.compile(
    loss='categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)

# Stop once validation loss has not improved for 5 epochs and roll back to
# the best weights seen.
early_stop = EarlyStopping(
    restore_best_weights=True,
    patience=5,
    monitor='val_loss',
)

# Write the model to disk whenever validation accuracy reaches a new maximum.
checkpoint = ModelCheckpoint(
    "best_model_cnn.h5",
    mode='max',
    save_best_only=True,
    monitor='val_accuracy',
)
|
|
|
|
# INITIAL TRAINING
# First phase: fit for up to 30 epochs with early stopping and
# best-model checkpointing. The resulting History object is kept so the
# curves can be plotted together with the fine-tuning phase.
history = model.fit(
    x=train_data,
    validation_data=val_data,
    callbacks=[early_stop, checkpoint],
    epochs=30,
)
|
|
|
|
|
|
# FINE TUNING
# Second phase: unfreeze the backbone, then re-freeze everything except its
# last 30 layers so only the top of the network is adapted.
base_model.trainable = True

for layer in base_model.layers[:-30]:
    layer.trainable = False

# Keep BatchNormalization layers frozen even inside the unfrozen tail.
# Updating their moving statistics on a small dataset with a tiny learning
# rate tends to wreck the pretrained features; a frozen BatchNormalization
# layer runs in inference mode instead.
for layer in base_model.layers[-30:]:
    if isinstance(layer, layers.BatchNormalization):
        layer.trainable = False

# Changes to `trainable` only take effect after re-compiling. A very small
# learning rate keeps the pretrained weights from being overwritten.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
    loss='categorical_crossentropy',
    metrics=['accuracy']
)

# Reuse the callbacks from the first phase so fine-tuning is also protected
# by early stopping and so "best_model_cnn.h5" can be updated when a
# fine-tuned epoch improves validation accuracy.
history_fine = model.fit(
    train_data,
    validation_data=val_data,
    epochs=10,
    callbacks=[early_stop, checkpoint]
)
|
|
|
|
|
|
# SAVE MODEL
# Persist the model as it stands after fine-tuning.
model.save("model_telur_cnn_new.h5")
print("Training selesai!")

# EVALUATION
# evaluate() yields the loss followed by the compiled metric (accuracy),
# both measured on the validation split.
final_metrics = model.evaluate(val_data)
loss, acc = final_metrics
print(f"Akurasi akhir: {acc*100:.2f}%")
|
|
|
|
# CONFUSION MATRIX
# Predicted class = index of the highest softmax score. The validation
# iterator was created with shuffle=False, so the prediction order lines up
# with val_data.classes.
pred = model.predict(val_data)
y_pred = np.argmax(pred, axis=1)
y_true = val_data.classes

# Pin the label set explicitly so the matrix always has one row/column per
# entry of class_labels, even when a class has no validation samples or is
# never predicted.
# NOTE(review): class_labels comes from the training split; this assumes the
# validation folder contains the same class sub-folders - confirm.
label_ids = np.arange(len(class_labels))

cm = confusion_matrix(y_true, y_pred, labels=label_ids)

# Save the confusion-matrix images.
os.makedirs("output_cm", exist_ok=True)

plt.figure(figsize=(6, 5))
sns.heatmap(cm, annot=True, fmt='d',
            xticklabels=class_labels,
            yticklabels=class_labels,
            cmap='Blues')

plt.title('Confusion Matrix')
plt.savefig("output_cm/confusion_matrix.png")
plt.show()

# NORMALIZED CM
# Divide each row by its total so cells show the fraction of each true class.
# Rows with no samples are left at 0 instead of producing NaN from 0/0.
row_totals = cm.sum(axis=1, keepdims=True)
cm_norm = np.divide(
    cm.astype('float'),
    row_totals,
    out=np.zeros(cm.shape, dtype=float),
    where=row_totals != 0,
)

plt.figure(figsize=(6, 5))
sns.heatmap(cm_norm, annot=True, fmt='.2f',
            xticklabels=class_labels,
            yticklabels=class_labels,
            cmap='Blues')

plt.title("Normalized Confusion Matrix")
plt.savefig("output_cm/confusion_matrix_normalized.png")
plt.show()

# Passing `labels` keeps the report aligned with target_names; without it a
# class missing from y_true/y_pred makes the counts disagree and raises.
print("\nClassification Report:")
print(classification_report(y_true, y_pred, labels=label_ids,
                            target_names=class_labels))
|
|
|
|
|
|
# COMBINED CURVES
# Concatenate the per-epoch metrics of the initial phase and the fine-tuning
# phase so each plot shows the whole training run.
acc_total = history.history['accuracy'] + history_fine.history['accuracy']
val_acc_total = history.history['val_accuracy'] + history_fine.history['val_accuracy']

loss_total = history.history['loss'] + history_fine.history['loss']
val_loss_total = history.history['val_loss'] + history_fine.history['val_loss']

# Start every plot on a fresh figure. On backends where plt.show() does not
# close the current figure (e.g. Agg on a headless machine), plotting without
# this would draw the curves on top of the previously created figure.
plt.figure()
plt.plot(acc_total, label='train_acc')
plt.plot(val_acc_total, label='val_acc')
plt.legend()
plt.title("Accuracy Gabungan")
plt.savefig("output_cm/accuracy.png")
plt.show()

plt.figure()
plt.plot(loss_total, label='train_loss')
plt.plot(val_loss_total, label='val_loss')
plt.legend()
plt.title("Loss Gabungan")
plt.savefig("output_cm/loss.png")
plt.show()