当前位置:   article > 正文

Baichuan2 源码解析之 web_demo.py

if prompt := st.chat_input("shift + enter 换行, enter 发送"):

 GitHub 源码解析：https://github.com/ArtificialZeng/Baichuan2-Explained

 

  1. import json
  2. import torch
  3. import streamlit as st
  4. from transformers import AutoModelForCausalLM, AutoTokenizer
  5. from transformers.generation.utils import GenerationConfig
  6. st.set_page_config(page_title="Baichuan 2")
  7. st.title("Baichuan 2")
  8. @st.cache_resource
  9. def init_model():
  10. model = AutoModelForCausalLM.from_pretrained(
  11. "baichuan-inc/Baichuan2-13B-Chat",
  12. torch_dtype=torch.float16,
  13. device_map="auto",
  14. trust_remote_code=True
  15. )
  16. model.generation_config = GenerationConfig.from_pretrained(
  17. "baichuan-inc/Baichuan2-13B-Chat"
  18. )
  19. tokenizer = AutoTokenizer.from_pretrained(
  20. "baichuan-inc/Baichuan2-13B-Chat",
  21. use_fast=False,
  22. trust_remote_code=True
  23. )
  24. return model, tokenizer
  25. def clear_chat_history():
  26. del st.session_state.messages
  27. def init_chat_history():