본문 바로가기

IT/Deeplearning

[YOLO / Object Detection / Keras] Code Review - [2]

[YOLO / Object Detection / Keras] Code Review - [1]


저번 포스팅에 이어서 계속해서 Train.py를 리뷰하도록 하겠습니다.

저번주에 해당 코드까지 리뷰하였습니다.


def _main_(args):

    config_path = args.conf

    with open(config_path) as config_buffer:    
        config = json.loads(config_buffer.read())

    ###############################
    #   Parse the annotations 
    ###############################

    # parse annotations of the training set
    train_imgs, train_labels = parse_annotation(config['train']['train_annot_folder'], 
                                                config['train']['train_image_folder'], 
                                                config['model']['labels'])
.
.
.
(생략)


오늘은 그 이후에 대한 코드를 계속해서 리뷰하도록 하겠습니다.


오늘 리뷰할 코드는 아래와 같습니다.


# parse annotations of the validation set, if any, otherwise split the training set
    if os.path.exists(config['valid']['valid_annot_folder']):
        valid_imgs, valid_labels = parse_annotation(config['valid']['valid_annot_folder'], 
                                                    config['valid']['valid_image_folder'], 
                                                    config['model']['labels'])
    else:
        train_valid_split = int(0.8*len(train_imgs))
        np.random.shuffle(train_imgs)

        valid_imgs = train_imgs[train_valid_split:]
        train_imgs = train_imgs[:train_valid_split]

    
    overlap_labels = set(config['model']['labels']).intersection(set(train_labels.keys()))

    print('Seen labels:\t', train_labels)
    print('Given labels:\t', config['model']['labels'])
    print('Overlap labels:\t', overlap_labels)

    if len(overlap_labels) < len(config['model']['labels']):
        print('Some labels have no images! Please revise the list of labels in the config.json file!')
        return
        
    ###############################
    #   Construct the model 
    ###############################

    yolo = YOLO(architecture        = config['model']['architecture'],
                input_size          = config['model']['input_size'], 
                labels              = config['model']['labels'], 
                max_box_per_image   = config['model']['max_box_per_image'],
                anchors             = config['model']['anchors'])

    ###############################
    #   Load the pretrained weights (if any) 
    ###############################    

    if os.path.exists(config['train']['pretrained_weights']):
        print("Loading pre-trained weights in", config['train']['pretrained_weights'])
        yolo.load_weights(config['train']['pretrained_weights'])

    ###############################
    #   Start the training process 
    ###############################

    yolo.train(train_imgs         = train_imgs,
               valid_imgs         = valid_imgs,
               train_times        = config['train']['train_times'],
               valid_times        = config['valid']['valid_times'],
               nb_epoch           = config['train']['nb_epoch'], 
               learning_rate      = config['train']['learning_rate'], 
               batch_size         = config['train']['batch_size'],
               warmup_epochs      = config['train']['warmup_epochs'],
               object_scale       = config['train']['object_scale'],
               no_object_scale    = config['train']['no_object_scale'],
               coord_scale        = config['train']['coord_scale'],
               class_scale        = config['train']['class_scale'],
               saved_weights_name = config['train']['saved_weights_name'],
               debug              = config['train']['debug'])


1. config.json파일에 valid_annot_folder의 경로가 있으면, 전 포스팅과 같은 방법으로 validation dataset을 parsing합니다.

2. 만약에 config.json파일에 valid_annot_folder의 경로가 없으면, trainSet에서 20%의 비율로 shuffle해서 추출합니다.

3. trainSet data를 parsing할 때, 모아뒀던 label의 key값들과, config.json파일에 있는 label 리스트의 교집합만 추출하여 overlap_labels에 대입합니다.

4. overlap_labels의 길이와 config.json에 있는 labels의 리스트 길이를 확인해서 label이 모두 있는지 validation check를 합니다.

5. YOLO 메소드에 파라미터들을 넣어주고 yolo에 리턴값을 대입합니다.

6. config.json에 pretrained_weights가 있다면 pretrained weights를 로드합니다.

7. yolo 오브젝트에서 train메소드를 config.json의 파라미터를 인수로 호출합니다.




다음 포스팅에서는 keras 코드의 yolo 모델을 천천히 리뷰하도록 하겠습니다.