[YOLO / Object Detection / Keras] Code Review - [1]
저번 포스팅에 이어서 계속해서 Train.py를 리뷰하도록 하겠습니다.
저번주에 해당 코드까지 리뷰하였습니다.
def _main_(args): config_path = args.conf with open(config_path) as config_buffer: config = json.loads(config_buffer.read()) ############################### # Parse the annotations ############################### # parse annotations of the training set train_imgs, train_labels = parse_annotation(config['train']['train_annot_folder'], config['train']['train_image_folder'], config['model']['labels']) . . . (생략)
오늘은 그 이후에 대한 코드를 계속해서 리뷰하도록 하겠습니다.
오늘 리뷰할 코드는 아래와 같습니다.
# parse annotations of the validation set, if any, otherwise split the training set if os.path.exists(config['valid']['valid_annot_folder']): valid_imgs, valid_labels = parse_annotation(config['valid']['valid_annot_folder'], config['valid']['valid_image_folder'], config['model']['labels']) else: train_valid_split = int(0.8*len(train_imgs)) np.random.shuffle(train_imgs) valid_imgs = train_imgs[train_valid_split:] train_imgs = train_imgs[:train_valid_split] overlap_labels = set(config['model']['labels']).intersection(set(train_labels.keys())) print('Seen labels:\t', train_labels) print('Given labels:\t', config['model']['labels']) print('Overlap labels:\t', overlap_labels) if len(overlap_labels) < len(config['model']['labels']): print('Some labels have no images! Please revise the list of labels in the config.json file!') return ############################### # Construct the model ############################### yolo = YOLO(architecture = config['model']['architecture'], input_size = config['model']['input_size'], labels = config['model']['labels'], max_box_per_image = config['model']['max_box_per_image'], anchors = config['model']['anchors']) ############################### # Load the pretrained weights (if any) ############################### if os.path.exists(config['train']['pretrained_weights']): print("Loading pre-trained weights in", config['train']['pretrained_weights']) yolo.load_weights(config['train']['pretrained_weights']) ############################### # Start the training process ############################### yolo.train(train_imgs = train_imgs, valid_imgs = valid_imgs, train_times = config['train']['train_times'], valid_times = config['valid']['valid_times'], nb_epoch = config['train']['nb_epoch'], learning_rate = config['train']['learning_rate'], batch_size = config['train']['batch_size'], warmup_epochs = config['train']['warmup_epochs'], object_scale = config['train']['object_scale'], no_object_scale = config['train']['no_object_scale'], coord_scale = config['train']['coord_scale'], class_scale = config['train']['class_scale'], saved_weights_name = config['train']['saved_weights_name'], debug = config['train']['debug'])
1. config.json파일에 valid_annot_folder의 경로가 있으면, 전 포스팅과 같은 방법으로 validation dataset을 parsing합니다.
2. 만약에 config.json파일에 valid_annot_folder의 경로가 없으면, trainSet에서 20%의 비율로 shuffle해서 추출합니다.
3. trainSet data를 parsing할 때, 모아뒀던 label의 key값들과, config.json파일에 있는 label 리스트의 교집합만 추출하여 overlap_labels에 대입합니다.
4. overlap_labels의 길이와 config.json에 있는 labels의 리스트 길이를 확인해서 label이 모두 있는지 validation check를 합니다.
5. YOLO 메소드에 파라미터들을 넣어주고 yolo에 리턴값을 대입합니다.
6. config.json에 pretrained_weights가 있다면 pretrained weights를 로드합니다.
7. yolo 오브젝트에서 train메소드를 config.json의 파라미터를 인수로 호출합니다.
다음 포스팅에서는 keras 코드의 yolo 모델을 천천히 리뷰하도록 하겠습니다.