当前位置:   article > 正文

调用ChatGPT的Embedding接口完成文本分类任务_chat gpt做文本分类

chat gpt做文本分类

文本分类任务在现实世界中的应用

文本分类是一种机器学习任务,旨在将一段给定的文本分配到预定义的类别或标签中。它通常涉及到对大量文本数据的处理和分类,以便进行有效的信息管理、资源分配或者决策制定。文本分类任务可以应用于多种场景,例如:

  1. 垃圾邮件分类:将电子邮件分类为垃圾邮件或非垃圾邮件。
  2. 情感分析:将文本分类为积极、消极或中性。
  3. 新闻分类:将新闻文章分类为政治、体育、娱乐等。
  4. 商品评论分类:将商品评论分类为好评、中评或差评。
  5. 聊天机器人:将用户输入分类为问题、建议或其他类型的文本。
  6. 智能客服:将用户提出的问题分类为不同的主题,以便向用户提供更好的解决方案。
  7. 社交媒体分析:对社交媒体上的内容进行分类,以了解用户的态度、情感和需求等。
  8. 舆情监测:将社交媒体上的文本分类为正面、负面或中性,以便了解公众对特定事件、话题或品牌的看法。

文本分类任务的途径

对于文本分类任务,可使用传统的算法,例如贝叶斯算法、SVM模型等来完成。使用这些算法的过程分为三部分:

1:获取带标签的训练数据

2:传入训练数据训练模型

3:传入测试数据检查模型正确率、召回率等

采用传统算法虽然可以完成这些任务,但一般情况下正确率不会太高,且为了准备充足的训练数据需要耗费大量的时间,在ChatGPT发布后,实际可以调用ChatGPT的接口获取文字的向量值,传入文本的向量值调用一些传统算法,如随机深林、逻辑回归等,即可快速完成文本分类任务。

调用ChatGPT Embedding接口完成文本分类任务

下面的代码来源于openai-cookbook,作用是从kaggle中先下载一份Amazon的fine food reviews的数据,生成每一条文本的向量值。执行代码的时候,为了节省ChatGPT消耗的tokens数量,只挑选了top10条数据来获取向量值。

  1. # imports
  2. import pandas as pd
  3. import tiktoken
  4. from openai.embeddings_utils import get_embedding
  5. import openai
  6. import os
  7. openai.api_key = os.environ.get("OPENAI_API_KEY")
  8. # embedding model parameters
  9. embedding_model = "text-embedding-ada-002"
  10. embedding_encoding = "cl100k_base" # this the encoding for text-embedding-ada-002
  11. max_tokens = 8000 # the maximum for text-embedding-ada-002 is 8191
  12. # load & inspect dataset
  13. # to save space, we provide a pre-filtered dataset
  14. input_datapath = "embeding/food_view.csv"
  15. df = pd.read_csv(input_datapath, index_col=0)
  16. df = df[["Time", "ProductId", "UserId", "Score", "Summary", "Text"]]
  17. df = df.dropna()
  18. df["combined"] = (
  19. "Title: " + df.Summary.str.strip() + "; Content: " + df.Text.str.strip()
  20. )
  21. df.head(2)
  22. # subsample to 1k most recent reviews and remove samples that are too long
  23. # top_n = 1000
  24. # df = df.sort_values("Time").tail(top_n * 2) # first cut to first 2k entries, assuming less than half will be filtered out
  25. # df.drop("Time", axis=1, inplace=True)
  26. encoding = tiktoken.get_encoding(embedding_encoding)
  27. # omit reviews that are too long to embed
  28. df["n_tokens"] = df.combined.apply(lambda x: len(encoding.encode(x)))
  29. # df = df[df.n_tokens <= max_tokens].tail(top_n)
  30. # len(df)
  31. # This may take a few minutes
  32. df["embedding"] = df.combined.apply(
  33. lambda x: get_embedding(x, engine=embedding_model))
  34. df.to_csv("embeding/food_view_embeddings.csv")

原始的Amazon find food views数据如下所示,每一条数据都包括Id,ProductId,UserId...Score,Time,Summary,Text等字段内容。 

执行上面的脚本处理后得到的数据如下所示,每一段文字后面都是一个很长的数组序列值,即文本的向量值来代替文本的含义。因为增加了文本的向量信息,新的文本会增加很多。以上面的10条数据的csv文件为例,原始文件大小是3KB,增加向量信息后的文件大小是350KB,文件大小增加了100多倍。

 有了带向量信息的文本数据后,即可挑选一些已有的算法,来完成文本分类任务,如下面的代码所示,可以使用随机深林算法等。下面的算法使用sklearn库中已有的函数,将数据拆分成训练数据和测试数据,调用RandomForestClassifier()函数,完成模型训练,训练完成后传入测试数据,对测试数据中的文本进行预测,预测完成后,调用classification_report()函数得到预测情况报告。

  1. import pandas as pd
  2. import numpy as np
  3. from sklearn.ensemble import RandomForestClassifier
  4. from sklearn.model_selection import train_test_split
  5. from sklearn.metrics import classification_report, accuracy_score
  6. # load data
  7. datafile_path = "embeding/food_view_embeddings.csv"
  8. df = pd.read_csv(datafile_path)
  9. df["embedding"] = df.embedding.apply(eval).apply(
  10. np.array) # convert string to array
  11. # split data into train and test
  12. X_train, X_test, y_train, y_test = train_test_split(
  13. list(df.embedding.values), df.Score, test_size=0.2, random_state=42
  14. )
  15. # train random forest classifier
  16. clf = RandomForestClassifier(n_estimators=100)
  17. clf.fit(X_train, y_train)
  18. preds = clf.predict(X_test)
  19. probas = clf.predict_proba(X_test)
  20. report = classification_report(y_test, preds)
  21. print(report)

下图是两份结果报告,左图是用上面的10条带向量的文本数据调用随机深林算法得到的结果。因为总共只有10条数据,可以看到准确率并不是很高,比如1分的准确率是0.00,5分的准确率是0.50.右图是从openai-cookbook中截取的结果报告,openai-cookbook中是挑选了1000条数据来获取向量值,再调用随机深林算法得到的结果,可以看到总体准确率达到0.92,即准确率是92%,召回率是0.43.也算比较好的结果了。

 为了更好的理解算法对文本分类任务的处理效果,需要理解下上图中各个数据的含义。

precision:准确率,代表模型判定属于这个分类的标题里面判断正确的有多少,有多少真的是属于这个分类的。

recall:召回率,代表模型判定属于这个分类的标题占实际这个分类下所有标题的比例,也就是没有漏掉的比例。

f1-score:所以模型效果的好坏,既要考虑准确率,又要考虑召回率,综合考虑这两项得出的结果,就是 F1 分数(F1 Score)。F1 分数,是准确率和召回率的调和平均数,也就是 F1 Score = 2/ (1/Precision + 1/Recall)。

macro avg:中文名叫做宏平均,宏平均的三个指标,就是把上面每一个分类算出来的指标加在一起平均一下。它主要是在数据分类不太平衡的时候,帮助我们衡量模型效果怎么样

weighted avg:加权平均,也就是我们把每一个指标,按照分类里面支持的样本量加权,算出来的一个值。无论是 Precision、Recall 还是 F1 Score 都要这么按照各个分类加权平均一下。

用已有的带文本向量的测试数据观察正确率 

如果完全从原始训练数据开始,调用ChatGPT获取文本的向量信息,会消耗很多tokens,下面的代码是下载直接带向量信息的训练数据。该训练数据来源于极客时间徐老师的“AI大模型之美”,用这份训练数据,再调用随机深林或者逻辑回归等算法可更快的检验文本分类的正确率等信息。

  1. from sklearn.ensemble import RandomForestClassifier
  2. from sklearn.model_selection import train_test_split
  3. from sklearn.metrics import classification_report, accuracy_score
  4. import pandas as pd
  5. from sklearn.linear_model import LogisticRegression
  6. training_data = pd.read_parquet(
  7. "embeding/toutiao_cat_data_all_with_embeddings.parquet")
  8. df = training_data.sample(50000, random_state=42)
  9. X_train, X_test, y_train, y_test = train_test_split(
  10. list(df.embedding.values), df.category, test_size=0.2, random_state=42
  11. )
  12. #使用随机森林算法完成文本分类任务
  13. clf = RandomForestClassifier(n_estimators=300)
  14. clf.fit(X_train, y_train)
  15. preds = clf.predict(X_test)
  16. probas = clf.predict_proba(X_test)
  17. report = classification_report(y_test, preds)
  18. print(report)
  19. # 使用逻辑回归算法完成文本分类
  20. # clf = LogisticRegression()
  21. # clf.fit(X_train, y_train)
  22. # preds = clf.predict(X_test)
  23. # probas = clf.predict_proba(X_test)
  24. # report = classification_report(y_test, preds)
  25. # print(report)

下面是生成的算法报告,因为训练数据是一份不同类型新闻的数据,不同的新闻被分配到不同的category中,例如农业新闻,汽车新闻等。可以看到因为挑选了5w条数据,其中4w条用来训练,1w条用于测试,所以生成的report中总体的正确率达到0.79,召回率也叨叨了0.76,算是比较好的结果了。

 总结

总结而言,利用ChatGPT的embedding接口,无需对机器学习算法有深入了解的情况下,也能快速完成文本分类任务,且正确率还能达到近90%的结果。而文本分类任务在现实世界中有很多应用场景,例如:社交媒体分析,舆情监测等,相较于传统的机器学习算法,调用ChatGPT的embedding接口可以更快的完成文本分类的任务。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/笔触狂放9/article/detail/350314
推荐阅读
相关标签
  

闽ICP备14008679号