
Saving a trained model

by wjwkddyd221001 2022. 10. 1.

1. Saving the model + weights >> h5

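import os
import tensorflow as tf

MODEL_PATH = "/model"                   # directory to save into
os.makedirs(MODEL_PATH, exist_ok=True)

# Saves architecture + weights + optimizer state in one HDF5 file
model.save(os.path.join(MODEL_PATH, "model.h5"))

new_model = tf.keras.models.load_model(os.path.join(MODEL_PATH, "model.h5"))

# x, y: evaluation data; accuracy is returned because the model was compiled with metrics=['accuracy']
test_loss, test_acc = new_model.evaluate(x, y, verbose=2)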
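A minimal self-contained sketch of the save/load round trip above, assuming a small toy Sequential model and random data in place of the post's `model`, `x`, and `y` (all hypothetical). It writes to a temporary directory and checks that the reloaded model gives the same predictions as the original.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Hypothetical toy model standing in for the post's trained "model"
model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(4,)),
    tf.keras.layers.Dense(8, activation="relu"),
    tf.keras.layers.Dense(3, activation="softmax"),
])
model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])

# Hypothetical random data: 32 samples, 4 features, 3 one-hot classes
rng = np.random.default_rng(0)
x = rng.normal(size=(32, 4)).astype("float32")
y = tf.keras.utils.to_categorical(rng.integers(0, 3, size=32), num_classes=3)
model.fit(x, y, epochs=2, verbose=0)

# One HDF5 file holds the architecture, the weights, and the optimizer state
h5_path = os.path.join(tempfile.mkdtemp(), "model.h5")
model.save(h5_path)

# Reloading does not need the Python code that defined the architecture
new_model = tf.keras.models.load_model(h5_path)

# Same weights, so the predictions should match
same_predictions = np.allclose(model.predict(x, verbose=0), new_model.predict(x, verbose=0))
print(same_predictions)
```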

2. Saving the weights and the architecture separately

2-1. save_weights() >> h5 (weights only)

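model.save_weights("model.weights.h5")   # a filename ending in .h5 selects the HDF5 format

# load_weights() needs a model with exactly the same architecture
new_model = tf.keras.models.Sequential([
    tf.keras.layers.Input(shape=...),    # same input shape as the original model
    ...                                  # same layers as the original model
])

new_model.load_weights("model.weights.h5")

# The weights file has no compile settings; compile again (assumed same settings as the original)
new_model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
test_loss, test_acc = new_model.evaluate(x, y, verbose=2)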
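A self-contained sketch of the weights-only workflow, assuming a hypothetical `build_model()` helper so that the exact same architecture can be created twice. The `.weights.h5` suffix is used because, as far as I know, it is read as HDF5 by tf.keras 2.x and is required by `save_weights()` in Keras 3.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


def build_model():
    # Hypothetical architecture; it must be identical when loading weights
    return tf.keras.Sequential([
        tf.keras.layers.Input(shape=(4,)),
        tf.keras.layers.Dense(8, activation="relu"),
        tf.keras.layers.Dense(3, activation="softmax"),
    ])


model = build_model()
x = np.random.default_rng(0).normal(size=(16, 4)).astype("float32")

weights_path = os.path.join(tempfile.mkdtemp(), "model.weights.h5")
model.save_weights(weights_path)

new_model = build_model()            # fresh model with random initial weights
new_model.load_weights(weights_path) # overwrite them with the saved weights

same_predictions = np.allclose(model.predict(x, verbose=0), new_model.predict(x, verbose=0))
print(same_predictions)
```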

2-2. to_json() >> json (architecture only, no weights)

# to_json() serializes only the model architecture, not the weights
model_json = model.to_json()
with open('model_json.json', 'w') as f:
    f.write(model_json)

Image source: https://hidden-loca.tistory.com/20
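Since the JSON file holds only the architecture, restoring a usable model means pairing it with a separate weights file. A minimal sketch, assuming a hypothetical toy model and temporary file paths: rebuild with `tf.keras.models.model_from_json()`, then call `load_weights()`.

```python
import json
import os
import tempfile

import numpy as np
import tensorflow as tf

# Hypothetical toy model
model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(4,)),
    tf.keras.layers.Dense(3, activation="softmax"),
])

save_dir = tempfile.mkdtemp()
json_path = os.path.join(save_dir, "model_json.json")
weights_path = os.path.join(save_dir, "model.weights.h5")

# Architecture goes to JSON, weights go to their own file
with open(json_path, "w") as f:
    f.write(model.to_json())
model.save_weights(weights_path)

# Rebuild the architecture from JSON, then load the weights into it
with open(json_path) as f:
    restored = tf.keras.models.model_from_json(f.read())
restored.load_weights(weights_path)

x = np.ones((2, 4), dtype="float32")
same_predictions = np.allclose(model.predict(x, verbose=0), restored.predict(x, verbose=0))

# The JSON is plain data: the top level names the model class
config = json.loads(model.to_json())
print(config["class_name"], same_predictions)
```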

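2-3. to_yaml() >> yaml (architecture only)

to_yaml() was removed in TensorFlow 2.6 (it now raises a RuntimeError, because loading YAML could execute arbitrary code); use to_json() instead.

model_yaml = model.to_yaml()   # works only on TensorFlow < 2.6
with open('model_yaml.yaml', 'w') as f:
    f.write(model_yaml)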

Image source: https://hidden-loca.tistory.com/20

3. ModelCheckpoint() >> model_ckpt.h5

from keras.callbacks import ModelCheckpoint, EarlyStopping

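# EPOCH and BATCH_SIZE are assumed to be defined earlier
filename = 'checkpoint-epoch-{}-batch-{}-trial-001.h5'.format(EPOCH, BATCH_SIZE)
checkpoint = ModelCheckpoint(filename,
                             monitor='val_loss',   # metric to watch
                             verbose=1,            # print a message whenever a checkpoint is saved
                             save_best_only=True,  # overwrite the file only when val_loss improves
                             mode='auto')          # 'auto' infers min or max from the metric name ('min' for val_loss)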

early_stopping = EarlyStopping(monitor = 'val_loss',
                               min_delta = 0,
                               patience = 5,
                               verbose = 1,
                               restore_best_weights = True)

model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

# Assumes train_generator and validation_generator were created with ImageDataGenerator
history = model.fit(train_generator, epochs=EPOCH,
                    validation_data=validation_generator,
                    callbacks=[checkpoint, early_stopping])    # checkpoint and early-stopping callbacks
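A runnable sketch of the same two callbacks on synthetic NumPy data instead of ImageDataGenerator (the data, the toy model, and the `ckpt_dir` layout are hypothetical). It uses `save_weights_only=True` with a `.weights.h5` name, which I believe sidesteps the file-format differences between tf.keras 2.x and Keras 3, and a `{epoch:02d}` placeholder that ModelCheckpoint fills in itself, so each improvement gets its own file instead of overwriting one.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Hypothetical data: the label depends only on the sign of the first feature
rng = np.random.default_rng(0)
x = rng.normal(size=(200, 4)).astype("float32")
y = tf.keras.utils.to_categorical((x[:, 0] > 0).astype(int), num_classes=2)

model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(4,)),
    tf.keras.layers.Dense(8, activation="relu"),
    tf.keras.layers.Dense(2, activation="softmax"),
])
model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])

ckpt_dir = tempfile.mkdtemp()
# ModelCheckpoint formats {epoch:02d} itself at save time (1-based epoch number)
ckpt_template = os.path.join(ckpt_dir, "ckpt-{epoch:02d}.weights.h5")

checkpoint = tf.keras.callbacks.ModelCheckpoint(
    ckpt_template, monitor="val_loss", mode="min",
    save_best_only=True, save_weights_only=True)
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss", patience=3, restore_best_weights=True)

history = model.fit(x, y, validation_split=0.2, epochs=20, batch_size=32,
                    callbacks=[checkpoint, early_stopping], verbose=0)

epochs_run = len(history.history["loss"])  # fewer than 20 if early stopping fired
saved_files = sorted(os.listdir(ckpt_dir))
print(epochs_run, saved_files)
```

The first epoch always improves on the initial "best" value, so `ckpt-01.weights.h5` is always written; later files appear only for epochs where `val_loss` improved.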
