Platform: Alibaba Cloud PAI DSW (free A10 compute)
First, create a DSW instance and select an environment image for it, then open the environment.
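Once the notebook is open, a quick sanity check confirms the GPU is visible. This is a minimal sketch of my own; it assumes the selected image ships with PyTorch, as the DSW PyTorch images do:
import torch

# Should print True and a device name containing "A10" on the free-tier instance
print(torch.__version__)
print(torch.cuda.is_available())
print(torch.cuda.get_device_name(0))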
Next, download the source code:
git clone https://github.com/THUDM/ChatGLM-6B
The clone takes a while to download. Then install the dependencies:
cd ChatGLM-6B
# transformers 4.30.0 is used here (the pinned 4.27.1 was unavailable)
pip install -r requirements.txt
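After installation, a quick import check (my own addition; the exact versions depend on the image and mirror) verifies that the key libraries resolved:
import torch
import transformers

# Versions will vary with the image; transformers must be new enough
# to load ChatGLM's trust_remote_code model files
print("torch:", torch.__version__)
print("transformers:", transformers.__version__)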
Then download the model. Because Hugging Face cannot be reached from the instance, use the copy that Alibaba Cloud hosts on OSS:
import os

# Pick the OSS download link that matches this DSW instance's region
dsw_region = os.environ.get("dsw_region")
url_link = {
"cn-shanghai": "https://atp-modelzoo-sh.oss-cn-shanghai-internal.aliyuncs.com/release/tutorials/chatGLM/ChatGLM-6B-main.tar.gz",
"cn-hangzhou": "https://atp-modelzoo.oss-cn-hangzhou-internal.aliyuncs.com/release/tutorials/chatGLM/ChatGLM-6B-main.tar.gz",
"cn-shenzhen": "https://atp-modelzoo-sz.oss-cn-shenzhen-internal.aliyuncs.com/release/tutorials/chatGLM/ChatGLM-6B-main.tar.gz",
"cn-beijing": "https://atp-modelzoo-bj.oss-cn-beijing-internal.aliyuncs.com/release/tutorials/chatGLM/ChatGLM-6B-main.tar.gz",
}
path = url_link[dsw_region]
os.environ['LINK_CHAT'] = path

# Download and unpack the model archive; the leading "!" runs a shell command in the notebook
!wget $LINK_CHAT
!tar -xvf ChatGLM-6B-main.tar.gz
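Note that the lookup above raises a KeyError in any region other than the four listed, and the "-internal" OSS endpoints only resolve from inside Alibaba Cloud. A hedged fallback sketch, not part of the PAI tutorial itself, drops the "-internal" suffix to fall back to the public Hangzhou endpoint:
# Fall back to the public OSS endpoint when the region is not in the map;
# public endpoints work anywhere, but without the free intranet bandwidth
path = url_link.get(dsw_region)
if path is None:
    path = url_link["cn-hangzhou"].replace("-internal", "")
os.environ['LINK_CHAT'] = path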
After the download completes, the model weights are in ChatGLM-6B-main/ptuning/chatglm-6b, but the directory is missing a test_modeling_chatglm.py file.
Create a Python file named test_modeling_chatglm.py with the following code:
import datetime
import math
import random
import unittest

import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel
from transformers.testing_utils import require_torch, slow, torch_device


def set_random_seed(seed):
    # Seed the Python, PyTorch, and NumPy RNGs so generation is reproducible
    random.seed(seed)
    # pytorch RNGs
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    # numpy RNG
    np.random.seed(seed)


def ids_tensor(shape, vocab_size):
    # Creates a random torch.long tensor of the given shape, with token ids in [0, vocab_size)
    total_dims = 1
    for dim in shape:
        total_dims *= dim
    values = [random.randint(0, vocab_size - 1) for _ in range(total_dims)]
    return torch.tensor(data=values, dtype=torch.long, device=torch_device).view(shape).contiguous()


def get_model_and_tokenizer():
    # Point this at the weights unpacked above (the upstream test hard-codes the
    # author's own path, /mnt/vepfs/workspace/zxdu/chatglm_6b, which does not exist here)
    model_path = "ChatGLM-6B-main/ptuning/chatglm-6b"
    model = AutoModel.from_pretrained(model_path, trust_remote_code=True).half()
    model.to(torch_device)
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    return model, tokenizer


@require_torch
class ChatGLMGenerationTest(unittest.TestCase):
    def get_generation_kwargs(self):
        pass

    def test_chat(self):
        model, tokenizer = get_model_and_tokenizer()
        prompts = ["你好", "介绍一下清华大学", "它创建于哪一年"]
        history = []
        set_random_seed(42)
        expected_responses = [
            # The reference strings were cut off in the original post (the first one
            # starts with "你好"); record the outputs you observe and paste them here.
            '你好...',
        ]
        # Each turn feeds the running history back in, so the three prompts form one conversation
        for prompt, expected_response in zip(prompts, expected_responses):
            response, history = model.chat(tokenizer, prompt, history=history)
            self.assertEqual(expected_response, response)
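With the file in place, you can run the test with python -m unittest, or exercise the same API interactively. A minimal sketch of the latter (model.chat and its history argument come from the remote code shipped with ChatGLM-6B):
from test_modeling_chatglm import get_model_and_tokenizer, set_random_seed

set_random_seed(42)
model, tokenizer = get_model_and_tokenizer()

# Single-turn query; pass the returned history back in for multi-turn chat
response, history = model.chat(tokenizer, "你好", history=[])
print(response)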