当前位置:   article > 正文




  1. # coding=utf-8
  2. # Implements API for ChatGLM2-6B in OpenAI's format. (https://platform.openai.com/docs/api-reference/chat)
  3. # Usage: python openai_api.py
  4. # Visit http://localhost:8000/docs for documents.
  5. import time
  6. import torch
  7. import uvicorn
  8. from pydantic import BaseModel, Field
  9. from fastapi import FastAPI, HTTPException
  10. from fastapi.middleware.cors import CORSMiddleware
  11. from contextlib import asynccontextmanager
  12. from typing import Any, Dict, List, Literal, Optional, Union
  13. from transformers import AutoTokenizer, AutoModel
  14. from sse_starlette.sse import ServerSentEvent, EventSourceResponse
  @app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
  async def create_chat_completion(request: ChatCompletionRequest):
      """Serve one OpenAI-style chat completion request.

      The last message must have role "user"; its content becomes the query.
      If the first of the earlier messages has role "system", it is removed
      and its content is prepended (with no separator) to the query. The
      remaining earlier messages are turned into ``[user, assistant]``
      history pairs, but only when their count is even; any pair whose roles
      are not exactly user-then-assistant is skipped.

      Returns:
          An ``EventSourceResponse`` driven by ``predict`` when
          ``request.stream`` is truthy, otherwise a ``ChatCompletionResponse``
          holding a single assistant choice with ``finish_reason="stop"``.

      Raises:
          HTTPException: status 400 when the last message is not from the user.
      """
      latest = request.messages[-1]
      if latest.role != "user":
          raise HTTPException(status_code=400, detail="Invalid request")

      query = latest.content
      earlier = request.messages[:-1]
      if earlier and earlier[0].role == "system":
          system_message = earlier.pop(0)
          query = system_message.content + query

      # An odd number of earlier messages cannot be paired, so history stays empty.
      history = []
      if len(earlier) % 2 == 0:
          history = [
              [asked.content, answered.content]
              for asked, answered in zip(earlier[0::2], earlier[1::2])
              if asked.role == "user" and answered.role == "assistant"
          ]

      if request.stream:
          return EventSourceResponse(
              predict(query, history, request.model),
              media_type="text/event-stream",
          )

      response, _ = model.chat(tokenizer, query, history=history)
      assistant_message = ChatMessage(role="assistant", content=response)
      choice_data = ChatCompletionResponseChoice(
          index=0,
          message=assistant_message,
          finish_reason="stop",
      )
      return ChatCompletionResponse(
          model=request.model,
          choices=[choice_data],
          object="chat.completion",
      )
  def _chunk_to_json(chunk) -> str:
      """Serialize a response chunk to a JSON string.

      Fields that were never set are omitted and non-ASCII characters are
      written as-is rather than escaped. Calling
      ``chunk.json(..., ensure_ascii=False)`` directly raises ``TypeError``
      on pydantic v2 (extra ``json.dumps`` keyword arguments are rejected),
      so the model is first converted to a plain dict and encoded with the
      standard library instead.
      """
      import json  # function-scope import keeps this block self-contained

      # pydantic v2 provides model_dump(); pydantic v1 only has dict().
      to_dict = getattr(chunk, "model_dump", None) or chunk.dict
      return json.dumps(to_dict(exclude_unset=True), ensure_ascii=False)


  async def predict(query: str, history: List[List[str]], model_id: str):
      """Stream a chat reply as OpenAI-style ``chat.completion.chunk`` payloads.

      Yields, in order:
        1. a chunk whose delta only carries ``role="assistant"``;
        2. one chunk per increment of text produced by ``model.stream_chat``
           (the first element of each streamed item is treated as the
           cumulative reply so far, and only the newly added suffix is sent);
        3. a chunk with an empty delta and ``finish_reason="stop"``;
        4. the literal string ``'[DONE]'``.

      Args:
          query: The prompt text passed to ``model.stream_chat``.
          history: Prior ``[user, assistant]`` content pairs.
          model_id: Value placed in the ``model`` field of every chunk.
      """
      global model, tokenizer

      def make_chunk(delta, finish_reason):
          # Every payload has the same envelope; only delta/finish_reason vary.
          choice_data = ChatCompletionResponseStreamChoice(
              index=0,
              delta=delta,
              finish_reason=finish_reason,
          )
          return ChatCompletionResponse(
              model=model_id,
              choices=[choice_data],
              object="chat.completion.chunk",
          )

      yield _chunk_to_json(make_chunk(DeltaMessage(role="assistant"), None))

      current_length = 0
      for new_response, _ in model.stream_chat(tokenizer, query, history):
          if len(new_response) == current_length:
              # Nothing new since the previous item; do not emit an empty delta.
              continue

          new_text = new_response[current_length:]
          current_length = len(new_response)
          yield _chunk_to_json(make_chunk(DeltaMessage(content=new_text), None))

      yield _chunk_to_json(make_chunk(DeltaMessage(), "stop"))
      yield '[DONE]'
  if __name__ == "__main__":
      # Tokenizer and weights come from the same checkpoint; the model is
      # moved to CUDA before being switched to eval mode.
      checkpoint = "THUDM/chatglm2-6b"
      tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
      model = AutoModel.from_pretrained(checkpoint, trust_remote_code=True).cuda()
      # Multi-GPU support: replace the model-loading line above with the two
      # lines below and set num_gpus to the number of GPUs you actually have.
      # from utils import load_model_on_gpus
      # model = load_model_on_gpus("THUDM/chatglm2-6b", num_gpus=2)
      model.eval()
      # NOTE(review): host is an empty string here; the bind address may have
      # been stripped from this listing -- confirm the intended value.
      uvicorn.run(app, host='', port=8000, workers=1)



  import os
  # Optional SOCKS5 proxy setup, left disabled:
  # import socket, socks
  #
  # socks.set_default_proxy(socks.SOCKS5, "", 1080)
  # socket.socket = socks.socksocket
  import openai

  # Send requests to the server at this base URL; the key is a placeholder string.
  openai.api_base = "http://localhost:8000/v1"
  openai.api_key = "none"

  stream = openai.ChatCompletion.create(
      model="chatglm2-6b",
      messages=[{"role": "user", "content": "你好"}],
      stream=True,
  )

  # Print each piece of text as it arrives; deltas without a "content"
  # attribute are skipped.
  for event in stream:
      delta = event.choices[0].delta
      if hasattr(delta, "content"):
          print(delta.content, end="", flush=True)




chunk.json(exclude_unset=True, ensure_ascii=False)




