본문 바로가기
python/tensorflow,keras

keras ImageDataGenerator

by wjwkddyd221001 2022. 10. 1.

1. ImageDataGenerator 선언

import tensorflow as tf
from tensorflow import keras
from keras.preprocessing.image import ImageDataGenerator

batch_size = 16
img_height = 96
img_width = 200

train_datagen = ImageDataGenerator(
    rescale = 1./ 255,  # 이미지 데이터 정규화
    validation_split = 0.2, # train, validation 데이터 분할 (8:2)
)
train_generator = train_datagen.flow_from_directory(
    TRAIN_PATH,
    batch_size=batch_size,
    shuffle=True,
    target_size=(img_height, img_width),    
    class_mode='categorical',
    subset='training',
)
validation_generator = train_datagen.flow_from_directory(
    TRAIN_PATH,
    target_size=(img_height, img_width),
    batch_size=batch_size,
    class_mode='categorical',
    subset='validation',
)

## 결과
# Found 971 images belonging to 4 classes.
# Found 240 images belonging to 4 classes.

200 x 96 크기의 이미지를 multi class classification 하기 위해 train set, validation set을 생성하는 코드

 

TRAIN_PATH는 train data가 있는 경로이고 다음과 같은 구조를 가져야 한다.

TRAIN_PATH
ㄴ class 1
     ㄴ 그림파일들...
ㄴ class 2
    ㄴ 그림파일들...
ㄴ ...

 

 

2. 분류된 클래스 이름은 숫자와 매칭된다.

print(train_generator.class_indices)

## 결과
# {'enter': 0, 'marrymeboss': 1, 'psychorus': 2, 'samang': 3}

'python > tensorflow,keras' 카테고리의 다른 글

keras one-hot encoding  (0) 2022.10.05
tensorflow.keras.callbacks  (0) 2022.10.01
학습한 모델 저장  (0) 2022.10.01