쉬엄쉬엄블로그

(NLP 기초대회) Prediction Service 본문

부스트캠프 AI Tech 4기

(NLP 기초대회) Prediction Service

쉬엄쉬엄블로그 2023. 7. 29. 13:14
728x90

이 색깔은 주석이라 무시하셔도 됩니다.

Prediction Service 개발

Fastapi

  • RestAPI 백엔드 서버 설정

Python 기반 웹 프레임워크

  • 웹 개발을 도와주는 도구

    • 사용자(웹, 모바일 등)의 요청을 수행함
    • DB와 연결하여 데이터 작업을 수행함
    • 속도와 안정성이 중요함
    • FastAPI가 다른 프레임워크에 비해서 간단한 머신러닝 데모 시스템을 만들 때 유리함
  • FastAPI : 빠르고 배우기 쉽다

코드 설명

  • 전체 코드

      from fastapi import FastAPI
      import uvicorn
      from starlette.responses import JSONResponse
      from temp.date import Model
    
      app = FastAPI()
    
      @app.get("/translation")
      async def root(date):
          inputs = model.tokenizer(date, padding='max_length', truncation=True, max_length=16, return_tensors='pt')
          pred_ids = model.encoder_decoder.generate(inputs['input_ids'], num_beams=3, min_length=0, max_length=16, num_return_sequences=3)
          pred = model.tokenizer.batch_decode(pred_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
          return JSONResponse(content={"preds": pred})
    
      if __name__ == '__main__':
          # 아래 경로에 저장된 모델 경로를 입력한다.
          model = Model.load_from_checkpoint('/workspace/temp/lightning_logs/version_16/checkpoints/epoch=0-step=188.ckpt')
    
          uvicorn.run(app, host="0.0.0.0", port=8000)
  • 날짜 정규화 수행 함수

      async def root(date):
          inputs = model.tokenizer(date, padding='max_length', truncation=True, max_length=16, return_tensors='pt')
          pred_ids = model.encoder_decoder.generate(inputs['input_ids'], num_beams=3, min_length=0, max_length=16, num_return_sequences=3)
          pred = model.tokenizer.batch_decode(pred_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
    
          return JSONResponse(content={"preds": pred})
    • 데코레이터와 입출력 부분으로 API 형식 정의
    • 함수 내부는 Python과 동일하게 동작함
  • 메인(실행) 파트

      if __name__ == '__main__':
          # 아래 경로에 저장된 모델 경로를 입력한다.
          model = Model.load_from_checkpoint('/workspace/temp/lightning_logs/version_16/checkpoints/epoch=0-step=188.ckpt')
    
          uvicorn.run(app, host="0.0.0.0", port=8000)
    • 학습해서 저장한 모델을 불러옴
    • 외부에서 접속을 허용할 ip와 port정보를 입력함
      • host=”0.0.0.0” : 외부 인터넷에서 접속 가능
      • host=”127.0.0.1” : 내부 인터넷에서만 접속 가능

Streamlit

  • 데이터 프로토타이핑 도구

Python 기반 웹 어플리케이션

  • Python 만으로 웹 뷰를 구현함
  • 배우기 쉽고, 기본 디자인이 제공됨
  • 데이터 및 결과에 대한 전달력이 우수함
    • pandas와 연결하면 데이터프레임 형태를 화면에 보일 수 있음
  • 딥러닝 모델 배포 및 테스트가 가능함

코드 설명

  • N2M 데모

      def streamlit():
              st.title('N2M 데모')
          st.markdown('**Transformers Encoder-Decoder 모델을 활용한 Sequence to Sequence 문제 실습**')
          st.markdown("#### 데이터셋([Link](https://github.com/htw5295/Neural_date_translation_dataset))", unsafe_allow_html=True)
          data_info = """
      - Faker 라이브러리로 생성한 날짜 표기 데이터
      - 입력 : 다양한 형태의 날짜 표기 데이터
      - 출력 : yyyy-mm-dd 형태의 날짜 표기 데이터
      - 학습 데이터 : 24,000개
      - 검증 데이터 : 3,000개
      - 평가 데이터 : 3,000개
          """
          st.markdown(data_info)
          st.markdown("#### 모델")
          model_info = """
      - [facebook/bart-base](https://huggingface.co/facebook/bart-base)의 Tokenizer와 Config 활용
      - Huggingface의 [AutoModelForSeq2SeqLM](https://huggingface.co/docs/transformers/model_doc/auto#transformers.AutoModelForSeq2SeqLM) 모델 활용
          """
          st.markdown(model_info)
          st.markdown("#### 평가")
          eval_info = """
      - 예측 데이터를 yyyy-mm-dd 형식으로 디코딩한 뒤 정답 데이터와 비교하여 일치, 불일치를 판단함
          """
          st.markdown(eval_info)

              today = date.today().strftime('%b %d %Y')
              st.subheader('Input human readable date text')
              input_date = st.text_input('Input date', today)
    
              if st.button('Translation'): # 위 그림에 Translation 버튼이 눌리면 실행되는 코드
                  result = requests.get('http://127.0.0.1:8000/translation', params={'date': input_date}).json()
    
                  data = []
                  index_data = []
                  for i, pred in enumerate(result['preds']):
                      data.append([input_date, pred])
                      index_data.append(f"Beam {i}")
    
                  df = pd.DataFrame(data, index=index_data, columns=['input date', 'generate date'])
    
                  st.subheader('Generated yyyy-mm-dd format date text')
                  st.table(df)

출처: 부스트캠프 AI Tech 4기(NAVER Connect Foundation)

Comments