ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • flair 모델로 fine-tuning하기
    프로그래밍/자연어처리 2024. 9. 5. 18:21
    반응형

     

     

     

    flair ner 모델을 사용하다보니 분류 모델로도 사용이 가능하다고 해서 확인해보았다. 

     

    flair 모델 훈련 샘플 코드 

     

     

    아래가 훈련 샘플 코드인데 간단해 보인다. 

     

    from flair.data import Corpus
    from flair.datasets import TREC_6
    from flair.embeddings import TransformerDocumentEmbeddings
    from flair.models import TextClassifier
    from flair.trainers import ModelTrainer
    
    # 1. get the corpus
    corpus: Corpus = TREC_6()
    
    # 2. what label do we want to predict?
    label_type = 'question_class'
    
    # 3. create the label dictionary
    label_dict = corpus.make_label_dictionary(label_type=label_type)
    
    # 4. initialize transformer document embeddings (many models are available)
    document_embeddings = TransformerDocumentEmbeddings('distilbert-base-uncased', fine_tune=True)
    
    # 5. create the text classifier
    classifier = TextClassifier(document_embeddings, label_dictionary=label_dict, label_type=label_type)
    
    # 6. initialize trainer
    trainer = ModelTrainer(classifier, corpus)
    
    # 7. run training with fine-tuning
    trainer.fine_tune('resources/taggers/question-classification-with-transformer',
                      learning_rate=5.0e-5,
                      mini_batch_size=4,
                      max_epochs=10,
                      )

     

     

    샘플 코드에서는 TREC_6 코퍼스를 사용해서 모델 튜닝을 하였다. 

     

    TREC_6은 자연어 처리 분야에서 질문 분류(question classification)를 위해 널리 사용되는 데이터셋입니다. 이 데이터셋은 다양한 질문들을 6개의 상위 카테고리로 분류하여 제공합니다.

    해당 카테고리는 다음과 같습니다.
    ABBR (약어): 약어 또는 축약어에 대한 질문
    DESC (설명): 정의나 설명을 요구하는 질문
    ENTY (엔티티): 사물, 개념, 사건 등 특정 개체에 대한 질문
    HUM (인물): 사람이나 그룹에 대한 질문
    LOC (장소): 장소나 위치에 대한 질문
    NUM (숫자): 숫자, 수량, 날짜 등 수치 정보를 요구하는 질문

    TREC_6 데이터셋은 각 질문이 어떤 유형에 속하는지를 라벨링하여 제공하므로, 머신 러닝 모델이 질문의 유형을 학습하고 예측하는 데 활용됩니다. 이 데이터셋은 자연어 이해, 정보 검색, 챗봇 등 다양한 응용 분야에서 모델의 성능을 평가하는 표준 벤치마크로 사용됩니다. 예를 들어, 질문 "What is the capital of France?"는 LOC 카테고리에 속하며, "Who wrote 'Hamlet'?"은 HUM 카테고리에 속합니다.

     

    훈련/테스트 데이터 

     

    TREC4에서 사용하는 훈련/테스트 데이터 형식을 확인해보았다. 

     

    훈련 데이터(train.txt)

     

    테스트 데이터 (text.txt)

     

     

    __label__ 다음에 클래스 이름과 해당되는 문장을 넣으면 되는 것 같다. 

     

    이렇게 생성된 모델은 trainer에서 지정한 경로에 저장이된다. 

    resources/taggers/question-classification-with-transformer

     

     

    생성된 모델로 예측하기 

     

     

    이곳에 저장된 파인튜닝된 모델을 불러와서 테스트 하는 방법도 간단하다. 

     

    classifier = TextClassifier.load('resources/taggers/question-classification-with-transformer/final-model.pt')
    
    # create example sentence
    sentence = Sentence('Who built the Eiffel Tower ?')
    
    # predict class and print
    classifier.predict(sentence)
    
    print(sentence.labels)

     

     

     

    자체 데이터셋으로 분류 모델 튜닝하기 

     

     

     

    샘플 코드가 아니라 내가 생성한 데이터셋으로 분류를 하고싶다면 아래 훈련 모델과 레이블 타입 설정해주는 부분을 본인이 생성한 데이터 셋으로 변경하면 된다. 

     

    # 1. get the corpus
    corpus: Corpus = TREC_6()
    
    # 2. what label do we want to predict?
    label_type = 'question_class'

     

     

     

    훈련,dev, 테스트 파일 다 포맷은 동일하다. 

    # 1. 데이터 경로 설정
    data_folder = './data'  # 데이터가 위치한 폴더 경로
    
    # 2. 커스텀 코퍼스 생성
    corpus: Corpus = ClassificationCorpus(data_folder,
                                          train_file='train_article.txt',
                                          dev_file='dev_article.txt',
                                          test_file='test_article.txt')
    
    
    
    
    # 2. what label do we want to predict?
    label_type = 'class'

     

     

    레이블을 처음에 'question_class'로 했더니 아래와 같은 에러가 나왔었다. 

    2024-09-05 17:52:21,350 ERROR: You specified label_type='question_class' which is not in this dataset! 2024-09-05 17:52:21,351 ERROR: The corpus contains the following label types: 'class' (in 25 sentences)

     

    확인해보니 그냥 question_type만 class로 변경하면 된다고 한다. 


    TREC 데이터셋처럼 모든 라벨이 동일한 라벨 타입('class')으로 처리되도록 하기 위해, 위에서 label_type='class'로 설정했습니다. 이렇게 하면 Flair가 모든 라벨을 'class'로 인식하고, 동일한 방식으로 분류 작업을 수행합니다.

     

     

    만약 아래와 같은 데이터 에러가 나온다면 파일이 제대로 위치하는지 확인한다.

    파일이 있다면 텍스트 파일에서 __label__로 시작하는 라벨이 있어야 하며, 그 뒤에 해당 텍스트가 있는지 확인한다. 

     

     

    corpus: Corpus = ClassificationCorpus(data_folder, ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "D:\Lib\site-packages\flair\datasets\document_classification.py", line 117, in __init__ super().__init__(train, dev, test, name=str(data_folder), sample_missing_splits=sample_missing_splits)
    File "D:\Lib\site-packages\flair\data.py", line 1361, in __init__ raise RuntimeError("No data provided when initializing corpus object.")

    RuntimeError: No data provided when initializing corpus object.

     

     

    샘플 데이터를 돌려서 돌아가는 것은 확인해봤으니 이제 데이터를 많이 생성해서 결과를 봐야겠다. 

     

     

    사용이 간단해서 성능만 괜찮으면 앞으로도 flair 종종 사용할 것 같다. 

     

     

    오늘의 개발일기 끝 ㅎㅎ

     

     

     

     

     

    참고 

     

    https://flairnlp.github.io/docs/tutorial-training/how-to-train-text-classifier

    728x90
    반응형
Designed by Tistory.