파이썬/텐서플로우

Tensorflow Callback class를 이용한 오버피팅 방지

공부짱짱열심히하기 2022. 12. 29. 11:51
class myCallback(tf.keras.callbacks.Callback) :
  def on_epoch_end(self, epoch, logs={}) :
    if logs['val_accuracy'] > 0.88:
      print('\n내가 정한 정확도에 도달했으니, 학습을 멈춘다.')
      self.model.stop_training = True
my_cb = myCallback()
model = build_model()
epoch_history = model.fit(X_train, y_train, epochs=30, validation_split=0.2,
                          callbacks = [my_cb])