[YOLO / Object Detection / Keras] Code Review - [1]
저번 포스팅에 이어서 계속해서 Train.py를 리뷰하도록 하겠습니다.
저번주에 해당 코드까지 리뷰하였습니다.
def _main_(args):
config_path = args.conf
with open(config_path) as config_buffer:
config = json.loads(config_buffer.read())
###############################
# Parse the annotations
###############################
# parse annotations of the training set
train_imgs, train_labels = parse_annotation(config['train']['train_annot_folder'],
config['train']['train_image_folder'],
config['model']['labels'])
.
.
.
(생략)
오늘은 그 이후에 대한 코드를 계속해서 리뷰하도록 하겠습니다.
오늘 리뷰할 코드는 아래와 같습니다.
# parse annotations of the validation set, if any, otherwise split the training set
if os.path.exists(config['valid']['valid_annot_folder']):
valid_imgs, valid_labels = parse_annotation(config['valid']['valid_annot_folder'],
config['valid']['valid_image_folder'],
config['model']['labels'])
else:
train_valid_split = int(0.8*len(train_imgs))
np.random.shuffle(train_imgs)
valid_imgs = train_imgs[train_valid_split:]
train_imgs = train_imgs[:train_valid_split]
overlap_labels = set(config['model']['labels']).intersection(set(train_labels.keys()))
print('Seen labels:\t', train_labels)
print('Given labels:\t', config['model']['labels'])
print('Overlap labels:\t', overlap_labels)
if len(overlap_labels) < len(config['model']['labels']):
print('Some labels have no images! Please revise the list of labels in the config.json file!')
return
###############################
# Construct the model
###############################
yolo = YOLO(architecture = config['model']['architecture'],
input_size = config['model']['input_size'],
labels = config['model']['labels'],
max_box_per_image = config['model']['max_box_per_image'],
anchors = config['model']['anchors'])
###############################
# Load the pretrained weights (if any)
###############################
if os.path.exists(config['train']['pretrained_weights']):
print("Loading pre-trained weights in", config['train']['pretrained_weights'])
yolo.load_weights(config['train']['pretrained_weights'])
###############################
# Start the training process
###############################
yolo.train(train_imgs = train_imgs,
valid_imgs = valid_imgs,
train_times = config['train']['train_times'],
valid_times = config['valid']['valid_times'],
nb_epoch = config['train']['nb_epoch'],
learning_rate = config['train']['learning_rate'],
batch_size = config['train']['batch_size'],
warmup_epochs = config['train']['warmup_epochs'],
object_scale = config['train']['object_scale'],
no_object_scale = config['train']['no_object_scale'],
coord_scale = config['train']['coord_scale'],
class_scale = config['train']['class_scale'],
saved_weights_name = config['train']['saved_weights_name'],
debug = config['train']['debug'])
1. config.json파일에 valid_annot_folder의 경로가 있으면, 전 포스팅과 같은 방법으로 validation dataset을 parsing합니다.
2. 만약에 config.json파일에 valid_annot_folder의 경로가 없으면, trainSet에서 20%의 비율로 shuffle해서 추출합니다.
3. trainSet data를 parsing할 때, 모아뒀던 label의 key값들과, config.json파일에 있는 label 리스트의 교집합만 추출하여 overlap_labels에 대입합니다.
4. overlap_labels의 길이와 config.json에 있는 labels의 리스트 길이를 확인해서 label이 모두 있는지 validation check를 합니다.
5. YOLO 메소드에 파라미터들을 넣어주고 yolo에 리턴값을 대입합니다.
6. config.json에 pretrained_weights가 있다면 pretrained weights를 로드합니다.
7. yolo 오브젝트에서 train메소드를 config.json의 파라미터를 인수로 호출합니다.
다음 포스팅에서는 keras 코드의 yolo 모델을 천천히 리뷰하도록 하겠습니다.