train_datagen = ImageDataGenerator(
                                    rotation_range=10,# randomly rotate images in the range (degrees, 0 to 180)
                                    width_shift_range=0.2,  # randomly shift images horizontally (fraction of total width)
                                    height_shift_range=0.2,  # randomly shift images vertically (fraction of total height)
                                    horizontal_flip=True,  # randomly flip images
                                    vertical_flip=False,
                                    rescale=1./255,
                                    shear_range=0.2,
                                    zoom_range=0.5)

validation_datagen = ImageDataGenerator(rescale=1./255)

train_generator = train_datagen.flow_from_directory(
                                    path_train,
                                    target_size=(width, height),
                                    batch_size=batch_size,
                                    color_mode='grayscale',
                                    class_mode='categorical')

validation_generator = validation_datagen.flow_from_directory(
                                    path_validation,
                                    target_size=(width, height),
                                    batch_size=batch_size,
                                    color_mode='grayscale',
                                    class_mode='categorical')

 

save_dir = os.path.join(r'../root', 'saved_models')
model_name = 'keras_Where_am_I.h5'
# Use ModelCheckpoint to save model and weights
if not os.path.isdir(save_dir):
    os.makedirs(save_dir)
model_path = os.path.join(save_dir, model_name)
checkpoint = ModelCheckpoint(model_path, monitor='val_loss', save_best_only=True, verbose=1)

# earlystop
earlystop = EarlyStopping(monitor='val_loss', patience=5, verbose=1)


model_history=model.fit_generator(  train_generator,
                                    steps_per_epoch=2000,
                                    epochs=epochs,
                                    workers=16,
                                    validation_data=validation_generator,
                                    validation_steps=800,
                                    callbacks=[earlystop, checkpoint])
 

# loading our save model
print("Loading trained model")
model = load_model(model_path)

# Score trained model.
score = model.evaluate_generator(validation_generator, workers=12)
print('Test loss:', score[0])
print('Test accuracy:', score[1])

 

創作者介紹
創作者 CrownTail的部落格 的頭像
CrownTail

CrownTail的部落格

CrownTail 發表在 痞客邦 留言(0) 人氣( 18 )