파이썬/텐서플로우

Tensorflow CNN 이미지분류 모델링하기:

공부짱짱열심히하기 2022. 12. 30. 13:10

트레이닝과 밸리데이션에 사용할 사진 자료 압축 풀기

https://seonggongstory.tistory.com/156

 

파이썬 압축파일 푸는 방법

import zipfile zipfile.ZipFile('/경로/'풀어줘야할파일이름') file.extractall('/압축풀경로/'압축풀고저장할폴더이름만들기')

seonggongstory.tistory.com

 

train_horse_dir = '/tmp/horse-or-human/horses'
train_human_dir = '/tmp/horse-or-human/humans'
validation_horse_dir = '/tmp/validation-horse-or-human/horses'
validation_human_dir = '/tmp/validation-horse-or-human/humans'

경로 저장후

# 파일명이나 파일 갯수 확인
import os
os.listdir(train_horse_dir)
len(  os.listdir(train_horse_dir) )

 


모델링 하기

import tensorflow as tf
from keras.layers import Dense, Flatten, Conv2D, MaxPooling2D
from keras.models import Sequential
def build_model():
  model = Sequential()
  model.add( Conv2D(16, (3,3) , activation='relu' , input_shape=(300,300,3)  ) )
  model.add( MaxPooling2D( (2,2) , 2 ) )
  model.add( Conv2D(32, (3,3), activation='relu'))
  model.add( MaxPooling2D( (2,2) , 2 ) )
  model.add( Conv2D(64, (3,3) , activation='relu'))
  model.add( MaxPooling2D( (2,2) , 2 ) )

  model.add(Flatten())
  model.add(Dense(512, 'relu'))
  model.add(Dense(1, 'sigmoid'))
  
  model.compile('rmsprop', 'binary_crossentropy', metrics=['accuracy'])
  return model
model = build_model()

Data Preprocessing

 

 

아직 학습할 준비가 다 되지 않았다.왜냐하면, fit 함수에 들어가는 데이터는 넘파이 어레이가 들어가야 한다. 하지만 우리가 가지고 있는 데이터는, 이미지 파일(png) 이다.따라서, 현재 상태로는 fit 함수 이용한 학습 불가능 하다.

 

https://seonggongstory.tistory.com/158

 

이미지 파일을 학습데이터로 바꿔주는 ImageDataGenerator

CNN모델링이 끝난후 사진을 머신러닝에 넣어 분류를 하려고 할때 사진파일은(PNG,JPG등) 넘파이 어레이가 아니기 때문에 그대로 사용이 불가능 하다. 이때 이미지파일을 학습데이터로 변환시켜 주

seonggongstory.tistory.com

 

라이브러리를 변수로 만들었으면, 그다음 할일은, 이미지가 들어있는 디렉토리의 정보와 이미지 사이즈정보와 몇개로 분류할지 정보를 알려준다.

넘파이의 target_size 와  모델의 input_shape 은, 가로 세로가 같아야 한다.

클래스모드는, 2개로 분류할땐 binary, 3개 이상일땐 categorical 사용.

 

 

train_generator = train_datagen.flow_from_directory('/tmp/horse-or-human' , target_size=(300,300) , class_mode= 'binary')
validation_generator = validation_datagen.flow_from_directory('/tmp/validation-horse-or-human' , target_size=(300,300) , class_mode= 'binary')

Training

epoch_history = model.fit(train_generator, epochs=15 , validation_data=(validation_generator) )


평가하기

 

model.evaluate(validation_generator)

import matplotlib.pyplot as plt
plt.plot(epoch_history.history['accuracy'])
plt.plot(epoch_history.history['val_accuracy'])
plt.legend(['train' , 'validation'])
plt.show()