
Gradio Web UI Demos for Interactive Machine Learning Models: Global State and Session State (Part 4)

Global state and session state are familiar concepts to most programmers and come up constantly in development. Here we look at how they are used in Gradio, along with a brief introduction to GPT-2.

1. Global State

If a function needs access to data defined outside of it, you can make that data a global variable so it is visible inside the function. For example, you can load a large model outside the function and use it inside, so the model does not have to be reloaded on every call.
Here is an example that tracks the top three scores submitted so far:

```python
import gradio as gr

scores = []  # global variable, shared across all users of the demo

def track_score(score):
    scores.append(score)
    top_scores = sorted(scores, reverse=True)[:3]
    return top_scores

demo = gr.Interface(
    track_score,
    gr.Number(label="Score"),
    gr.JSON(label="Top 3 scores"),
)
demo.launch()
```

The scores list here is shared among all users. If several users visit this demo, every score they submit is appended to the same list, and the top three returned are drawn from that shared list.
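You can see this sharing without launching the interface: every call to track_score, no matter which user it comes from, appends to the same module-level list. A minimal sketch of the same logic in plain Python:

```python
scores = []  # module-level list: every caller mutates the same object

def track_score(score):
    scores.append(score)
    return sorted(scores, reverse=True)[:3]

# "User A" submits two scores, then "user B" submits one:
print(track_score(10))  # [10]
print(track_score(30))  # [30, 10]
print(track_score(20))  # [30, 20, 10]  <- B's result includes A's scores
```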

2. Session State

Another kind of data persistence is session state, where data persists across multiple submits within a page session but is not shared between different users of the model. To store data in session state, you need to do three things:

1. Pass an extra parameter into your function, representing the state of the interface.
2. At the end of the function, return the updated value of the state as an extra return value.
3. Add 'state' input and 'state' output components when creating the Interface.
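The three steps boil down to a function of the shape fn(input, state) -> (output, new_state). Below is a minimal sketch of the pattern with a running total, shown without the Gradio wrapper so the state flow is explicit (the name add_number is illustrative, not part of any API):

```python
def add_number(n, total):
    """Step 1: the session state arrives as an extra argument."""
    total = (total or 0) + n
    # Step 2: return the updated state as an extra return value.
    return f"Running total: {total}", total

# Step 3 would wire it up with 'state' components, roughly:
#   gr.Interface(add_number,
#                inputs=[gr.Number(label="Number"), "state"],
#                outputs=[gr.Textbox(label="Result"), "state"]).launch()

# Simulating two submits within one session:
state = None                        # the state starts out empty
out, state = add_number(3, state)   # -> ("Running total: 3", 3)
out, state = add_number(4, state)   # -> ("Running total: 7", 7)
print(out)
```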

A chatbot is a classic example that needs session state: you want access to what the user submitted earlier, but you cannot keep the chat history in a global variable, because then the histories of different users would get mixed together. This is the same familiar idea as login handling, where each user must be kept separate.
Here is a simple chatbot example (based on Microsoft's DialoGPT), slightly modified from the official example:

```python
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")

def predict(input, history=[]):
    new_user_input_ids = tokenizer.encode(input + tokenizer.eos_token, return_tensors='pt')
    # append to the chat history
    bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)
    # generate a response
    history = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id).tolist()
    # convert the tokens to text, then split the conversation into turns
    response = tokenizer.decode(history[0]).split("<|endoftext|>")
    response = [(response[i], response[i + 1]) for i in range(0, len(response) - 1, 2)]
    return response, history

gr.Interface(fn=predict,
             inputs=[gr.Textbox(label="Input"), "state"],
             outputs=[gr.Textbox(label="Chat history", lines=8), "state"]).launch()
```

If the transformers package is not installed, the script fails with a missing-module error; installing from a mirror is still recommended:

ModuleNotFoundError: No module named 'transformers'

pip install transformers -i http://pypi.douban.com/simple/ --trusted-host pypi.douban.com

The overall flow is straightforward: the input is encoded into token IDs, the GPT-2-based model generates a response, and the output token IDs are decoded back into text.
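The last step, splitting the decoded text into (user, bot) pairs, deserves a closer look: generate() returns the whole conversation (all previous turns plus the new reply) as one token sequence, and every turn ends with the <|endoftext|> separator, so splitting on it yields alternating user and bot utterances. A sketch with a hand-written decoded string (the example dialogue is made up):

```python
# A decoded conversation string of the shape DialoGPT produces: user and bot
# turns alternate, each terminated by the <|endoftext|> token.
decoded = ("Hi there<|endoftext|>Hello! How can I help?<|endoftext|>"
           "Tell me a joke<|endoftext|>Why did the chicken cross the road?<|endoftext|>")

response = decoded.split("<|endoftext|>")
# -> ['Hi there', 'Hello! How can I help?', 'Tell me a joke',
#     'Why did the chicken cross the road?', '']

# Pair them up as (user, bot) tuples; the range stops before the trailing
# empty string left behind by the final separator.
pairs = [(response[i], response[i + 1]) for i in range(0, len(response) - 1, 2)]
print(pairs)
```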

Then we can type a message and chat with the bot through the interface.

DialoGPT is a basic chatbot built on GPT-2. On first run it downloads the relevant configuration files and model weights. After that, usage follows the standard input/output flow: type what you want to say on the left, and the chat history with the bot appears on the right.

3. About GPT-2

Let's look at the tokenizer configuration of this DialoGPT-medium model. The tokenizer is an instance of transformers.models.gpt2.tokenization_gpt2_fast.GPT2TokenizerFast.

Printing it with print(tokenizer) gives:

```
GPT2TokenizerFast(name_or_path='microsoft/DialoGPT-medium', vocab_size=50257, model_max_length=1024, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'eos_token': AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'unk_token': AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True)}, clean_up_tokenization_spaces=True)
```

And the model structure with print(model). The 24 GPT2Block layers are identical, so they are shown collapsed here:

```
GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 1024)
    (wpe): Embedding(1024, 1024)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-23): 24 x GPT2Block(
        (ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=1024, out_features=50257, bias=False)
)
```

If you are interested, see the other chapters in this series:

Gradio Web UI Demos for Interactive Machine Learning Models: Installation and Usage (Part 1)

Gradio Web UI Demos for Interactive Machine Learning Models: Key Features (Part 2)

Gradio Web UI Demos for Interactive Machine Learning Models: Sharing Your App (Part 3)
