
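BERT in PyTorch (PT): multi-class news text classification on the UCI news dataset with Transformer-BERT, built on the torch framework (feature encoding + BERT as the text encoder + classifier, with real-time model saving)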




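Multi-class news text classification on the UCI news dataset with Transformer-BERT, built on the torch framework (feature encoding + BERT as the text encoder + classifier, with real-time model saving)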

# 1. Define the dataset

| ID | TITLE | URL | PUBLISHER | CATEGORY | STORY | HOSTNAME | TIMESTAMP |
|---|---|---|---|---|---|---|---|
| 1 | Fed official says weak data caused by weather, should not slow taper | http://www.latimes.com/business/money/la-fi-mo-federal-reserve-plosser-stimulus-economy-20140310,0,1312750.story\?track=rss | Los Angeles Times | b | ddUyU0VZz0BRneMioxUPQVP6sIxvM | www.latimes.com | 1.39447E+12 |
| 2 | Fed's Charles Plosser sees high bar for change in pace of tapering | http://www.livemint.com/Politics/H2EvwJSK2VE6OF7iK1g3PP/Feds-Charles-Plosser-sees-high-bar-for-change-in-pace-of-ta.html | Livemint | b | ddUyU0VZz0BRneMioxUPQVP6sIxvM | www.livemint.com | 1.39447E+12 |
| 3 | US open: Stocks fall after Fed official hints at accelerated tapering | http://www.ifamagazine.com/news/us-open-stocks-fall-after-fed-official-hints-at-accelerated-tapering-294436 | IFA Magazine | b | ddUyU0VZz0BRneMioxUPQVP6sIxvM | www.ifamagazine.com | 1.39447E+12 |
| 4 | Fed risks falling 'behind the curve', Charles Plosser says | http://www.ifamagazine.com/news/fed-risks-falling-behind-the-curve-charles-plosser-says-294430 | IFA Magazine | b | ddUyU0VZz0BRneMioxUPQVP6sIxvM | www.ifamagazine.com | 1.39447E+12 |
| 5 | Fed's Plosser: Nasty Weather Has Curbed Job Growth | http://www.moneynews.com/Economy/federal-reserve-charles-plosser-weather-job-growth/2014/03/10/id/557011 | Moneynews | b | ddUyU0VZz0BRneMioxUPQVP6sIxvM | www.moneynews.com | 1.39447E+12 |
| 6 | Plosser: Fed May Have to Accelerate Tapering Pace | http://www.nasdaq.com/article/plosser-fed-may-have-to-accelerate-tapering-pace-20140310-00371 | NASDAQ | b | ddUyU0VZz0BRneMioxUPQVP6sIxvM | www.nasdaq.com | 1.39447E+12 |
| 7 | Fed's Plosser: Taper pace may be too slow | http://www.marketwatch.com/story/feds-plosser-taper-pace-may-be-too-slow-2014-03-10\?reflink=MW_news_stmp | MarketWatch | b | ddUyU0VZz0BRneMioxUPQVP6sIxvM | www.marketwatch.com | 1.39447E+12 |
| 8 | Fed's Plosser expects US unemployment to fall to 6.2% by the end of 2014 | http://www.fxstreet.com/news/forex-news/article.aspx\?storyid=23285020-b1b5-47ed-a8c4-96124bb91a39 | FXstreet.com | b | ddUyU0VZz0BRneMioxUPQVP6sIxvM | www.fxstreet.com | 1.39447E+12 |
| 9 | US jobs growth last month hit by weather:Fed President Charles Plosser | http://economictimes.indiatimes.com/news/international/business/us-jobs-growth-last-month-hit-by-weatherfed-president-charles-plosser/articleshow/31788000.cms | Economic Times | b | ddUyU0VZz0BRneMioxUPQVP6sIxvM | economictimes.indiatimes.com | 1.39447E+12 |
| 10 | ECB unlikely to end sterilisation of SMP purchases - traders | http://www.iii.co.uk/news-opinion/reuters/news/152615 | Interactive Investor | b | dPhGU51DcrolUIMxbRm0InaHGA2XM | www.iii.co.uk | 1.39447E+12 |

    <class 'pandas.core.frame.DataFrame'>
    RangeIndex: 422419 entries, 0 to 422418
    Data columns (total 8 columns):
     #   Column     Non-Null Count   Dtype
    ---  ------     --------------   -----
     0   ID         422419 non-null  int64
     1   TITLE      422419 non-null  object
     2   URL        422419 non-null  object
     3   PUBLISHER  422417 non-null  object
     4   CATEGORY   422419 non-null  object
     5   STORY      422419 non-null  object
     6   HOSTNAME   422419 non-null  object
     7   TIMESTAMP  422419 non-null  int64
    dtypes: int64(2), object(6)
    memory usage: 25.8+ MB
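
As a minimal sketch of this step, the snippet below reads data in the same eight-column layout with pandas and prints the same kind of summary. The two rows are copied from the table above and held in an in-memory buffer so the example runs on its own; the file name mentioned in the comment is an assumption, not something stated in the original.

```python
import io

import pandas as pd

# Two rows copied from the sample table above. In a real run the full CSV
# would be read instead, e.g. pd.read_csv("uci-news-aggregator.csv")
# (that file name is an assumption).
SAMPLE_CSV = """ID,TITLE,URL,PUBLISHER,CATEGORY,STORY,HOSTNAME,TIMESTAMP
6,Plosser: Fed May Have to Accelerate Tapering Pace,http://www.nasdaq.com/article/plosser-fed-may-have-to-accelerate-tapering-pace-20140310-00371,NASDAQ,b,ddUyU0VZz0BRneMioxUPQVP6sIxvM,www.nasdaq.com,1.39447E+12
10,ECB unlikely to end sterilisation of SMP purchases - traders,http://www.iii.co.uk/news-opinion/reuters/news/152615,Interactive Investor,b,dPhGU51DcrolUIMxbRm0InaHGA2XM,www.iii.co.uk,1.39447E+12
"""

df = pd.read_csv(io.StringIO(SAMPLE_CSV))

# Column names, non-null counts and dtypes, as in the summary above.
df.info()
```

Note that the timestamp appears in rounded scientific notation in the sample, so pandas parses it as a float here, whereas the full dataset summary reports `int64`.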

# 2. Data preprocessing

# 2.1 Select features: the dataset is reduced to two columns, title (TITLE) and category (CATEGORY)

            TITLE                                               CATEGORY
    0       Fed official says weak data caused by weather,...   b
    1       Fed's Charles Plosser sees high bar for change...   b
    2       US open: Stocks fall after Fed official hints ...   b
    3       Fed risks falling 'behind the curve', Charles ...   b
    4       Fed's Plosser: Nasty Weather Has Curbed Job Gr...   b
    ...     ...                                                 ...
    422414  Surgeons to remove 4-year-old's rib to rebuild...   m
    422415  Boy to have surgery on esophagus after battery...   m
    422416  Child who swallowed battery to have reconstruc...   m
    422417  Phoenix boy undergoes surgery to repair throat...   m
    422418  Phoenix boy undergoes surgery to repair throat...   m
    [422419 rows x 2 columns]

# 2.2 Drop null values

            TITLE                                               CATEGORY
    0       Fed official says weak data caused by weather,...   b
    1       Fed's Charles Plosser sees high bar for change...   b
    2       US open: Stocks fall after Fed official hints ...   b
    3       Fed risks falling 'behind the curve', Charles ...   b
    4       Fed's Plosser: Nasty Weather Has Curbed Job Gr...   b
    ...     ...                                                 ...
    422414  Surgeons to remove 4-year-old's rib to rebuild...   m
    422415  Boy to have surgery on esophagus after battery...   m
    422416  Child who swallowed battery to have reconstruc...   m
    422417  Phoenix boy undergoes surgery to repair throat...   m
    422418  Phoenix boy undergoes surgery to repair throat...   m
    [422419 rows x 2 columns]
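
Steps 2.1 and 2.2 can be sketched with two pandas calls. The small frame below is a stand-in: the titles come from the sample table, while the missing values are hypothetical and only there to show the behavior. In the dataset summary, `PUBLISHER` is the only column with nulls, which is why the row count stays at 422419 after `dropna()` on the two selected columns.

```python
import pandas as pd

# Stand-in frame; the None entries are hypothetical, added for illustration.
df = pd.DataFrame({
    "TITLE": [
        "Plosser: Fed May Have to Accelerate Tapering Pace",
        "Fed's Plosser: Taper pace may be too slow",
        None,  # a missing title: this row is removed by dropna()
    ],
    "PUBLISHER": ["NASDAQ", None, "MarketWatch"],  # nulls here are irrelevant
    "CATEGORY": ["b", "b", "b"],
})

# 2.1: keep only the text column and the label column.
data = df[["TITLE", "CATEGORY"]]

# 2.2: drop rows with a null in either remaining column, then renumber rows.
data = data.dropna().reset_index(drop=True)

print(data)
```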

# 2.3 Encode the categorical feature: convert the category values into numeric labels

            TITLE                                               CATEGORY
    0       Fed official says weak data caused by weather,...   0
    1       Fed's Charles Plosser sees high bar for change...   0
    2       US open: Stocks fall after Fed official hints ...   0
    3       Fed risks falling 'behind the curve', Charles ...   0
    4       Fed's Plosser: Nasty Weather Has Curbed Job Gr...   0
    ...     ...                                                 ...
    422414  Surgeons to remove 4-year-old's rib to rebuild...   2
    422415  Boy to have surgery on esophagus after battery...   2
    422416  Child who swallowed battery to have reconstruc...   2
    422417  Phoenix boy undergoes surgery to repair throat...   2
    422418  Phoenix boy undergoes surgery to repair throat...   2
    [422419 rows x 2 columns]
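
The output above maps `b` to 0 and `m` to 2, which is what an alphabetical encoding of the four category codes `b`, `e`, `m`, `t` would produce. The sketch below assumes that ordering and builds the mapping by hand; the original may well have used a library encoder instead, and the titles here are placeholders.

```python
import pandas as pd

# Placeholder titles; only the CATEGORY column matters for this step.
data = pd.DataFrame({
    "TITLE": ["title 0", "title 1", "title 2", "title 3", "title 4"],
    "CATEGORY": ["b", "m", "e", "t", "b"],
})

# Sort the distinct category codes and number them in that order
# (assumed to match the encoding used in the output above).
classes = sorted(data["CATEGORY"].unique())
label_map = {category: index for index, category in enumerate(classes)}

# Replace the string codes with their integer labels.
data["CATEGORY"] = data["CATEGORY"].map(label_map)

print(label_map)
print(data)
```

Keeping `label_map` around is useful at inference time, since inverting it turns a predicted class index back into a category code.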

# 2.4 Normalize the dataset: convert it into torch tensors that the model accepts, for training or inference

    input_ids tensor([[  101,  7349,  2880,  ...,     0,     0,     0],
            [  101,  7349,  1005,  ...,     0,     0,     0],
            [  101,  2149,  2330,  ...,     0,     0,     0],
            ...,
            [  101,  2878,  1011,  ...,     0,     0,     0],
            [  101,  2878,  1011,  ...,     0,     0,     0],
            [  101, 20077,  1996,  ...,     0,     0,     0]])
    attention_masks tensor([[1, 1, 1,  ..., 0, 0, 0],
            [1, 1, 1,  ..., 0, 0, 0],
            [1, 1, 1,  ..., 0, 0, 0],
            ...,
            [1, 1, 1,  ..., 0, 0, 0],
            [1, 1, 1,  ..., 0, 0, 0],
            [1, 1, 1,  ..., 0, 0, 0]])
    labels tensor([0, 0, 0,  ..., 2, 2, 2])
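
One way to produce tensors of this shape is to call a Hugging Face tokenizer with fixed-length padding. The helper below is a sketch under that assumption; `encode_titles` is a hypothetical name, and the checkpoint in the usage comment is a guess (the original does not say which one it loads).

```python
import torch


def encode_titles(titles, labels, tokenizer, max_length):
    """Tokenize titles into fixed-length input_ids / attention_masks tensors."""
    encoded = tokenizer(
        list(titles),
        padding="max_length",  # pad every title up to max_length
        truncation=True,       # cut titles that are longer than max_length
        max_length=max_length,
        return_tensors="pt",   # return PyTorch tensors
    )
    label_tensor = torch.tensor(list(labels))
    return encoded["input_ids"], encoded["attention_mask"], label_tensor


# Usage (downloads the vocabulary; the checkpoint name is an assumption):
# from transformers import BertTokenizer
# tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
# input_ids, attention_masks, labels = encode_titles(
#     data["TITLE"], data["CATEGORY"], tokenizer, max_length=32)
```

In the output above, each row starts with the `[CLS]` token id, and padded positions carry id 0 together with a 0 in the attention mask, so the encoder ignores them.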

# 3. Model training and inference

# 3.1 Split the dataset
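
The original shows no output for this step. A minimal sketch, assuming a random 80/20 split with a fixed seed (both the ratio and the seed are assumptions), is to shuffle the row indices once and slice them; `split_indices` is a hypothetical helper name.

```python
import torch


def split_indices(n_samples, test_ratio=0.2, seed=42):
    """Return (train_idx, test_idx) from one seeded random permutation."""
    generator = torch.Generator().manual_seed(seed)
    permutation = torch.randperm(n_samples, generator=generator)
    n_test = int(n_samples * test_ratio)
    return permutation[n_test:], permutation[:n_test]


train_idx, test_idx = split_indices(10)
print(train_idx, test_idx)

# The same index tensors can then slice all three tensors consistently, e.g.
# input_ids[train_idx], attention_masks[train_idx], labels[train_idx].
```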

# 3.2 Wrap the dataset for Torch and feed it into data loaders: the batch size and maximum sequence length must be set

    train_dataset
    <__main__.NewsDataset object at 0x0000021A1EE9ECA0>
    (tensor([  101, 20228, 15094,  2121,  1024,  7349,  2089,  2031,  2000, 23306,
              6823,  4892,  6393,   102,     0,     0,     0,     0,     0]), tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0]), tensor(0))
    (tensor([  101,  2149,  2330,  1024, 15768,  2991,  2044,  7349,  2880, 20385,
              2012, 14613,  6823,  4892,   102,     0,     0,     0,     0]), tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]), tensor(0))
    (tensor([  101,  7349,  1005,  1055, 20228, 15094,  2121,  1024, 11808,  4633,
              2038, 13730,  2098,  3105,  3930,   102,     0,     0,     0]), tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0]), tensor(0))
    (tensor([  101,  7349, 10831,  4634,  1005,  2369,  1996,  7774,  1005,  1010,
              2798, 20228, 15094,  2121,  2758,   102,     0,     0,     0]), tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0]), tensor(0))
    test_dataloader
    <torch.utils.data.dataloader.DataLoader object at 0x0000021A1EF422E0>
    [tensor([[  101,  7349,  2880,  2758,  5410,  2951,  3303,  2011,  4633,  1010,
               2323,  2025,  4030,  6823,  2099,   102,     0,     0,     0],
             [  101,  7349,  1005,  1055,  2798, 20228, 15094,  2121,  5927,  2152,
               3347,  2005,  2689,  1999,  6393,  1997,  6823,  4892,   102]]), tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
             [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0]]), tensor([0, 0])]
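
The printed items are 3-tuples of `(input_ids, attention_mask, label)`, and the class name `NewsDataset` appears in the output itself. A minimal sketch of a dataset with that shape follows; the constructor arguments, the synthetic tensors, and the batch size of 2 are assumptions made so the example runs on its own.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class NewsDataset(Dataset):
    """Serves one (input_ids, attention_mask, label) tuple per news title."""

    def __init__(self, input_ids, attention_masks, labels):
        self.input_ids = input_ids
        self.attention_masks = attention_masks
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        return (
            self.input_ids[index],
            self.attention_masks[index],
            self.labels[index],
        )


# Synthetic stand-ins for the tensors produced in step 2.4.
input_ids = torch.randint(0, 100, (4, 19))
attention_masks = torch.ones(4, 19, dtype=torch.long)
labels = torch.tensor([0, 0, 2, 2])

train_dataset = NewsDataset(input_ids, attention_masks, labels)
train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True)

# Each batch is a list of three stacked tensors, as in the output above.
batch = next(iter(train_dataloader))
print([tensor.shape for tensor in batch])
```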

# 3.3 Define the model, loss, and optimizer

# Use the BERT model as the encoder and add a fully connected layer for classification
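
The architecture described here, a BERT encoder followed by one fully connected layer, can be sketched as below. `BertClassifier` is a hypothetical name, and feeding the pooled `[CLS]` output into the linear layer is one common choice rather than something the original confirms. To keep the example runnable offline it builds a tiny randomly initialized encoder; the loss, optimizer, learning rate, and the four-class output are likewise assumptions.

```python
import torch
from torch import nn
from transformers import BertConfig, BertModel


class BertClassifier(nn.Module):
    """BERT encoder plus a single linear layer over the pooled output."""

    def __init__(self, encoder, num_classes):
        super().__init__()
        self.encoder = encoder
        self.fc = nn.Linear(encoder.config.hidden_size, num_classes)

    def forward(self, input_ids, attention_mask):
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        return self.fc(outputs.pooler_output)  # logits, one per class


# Tiny random encoder so the sketch runs without downloading weights.
# For real training one would load pretrained weights instead, e.g.
# BertModel.from_pretrained("bert-base-uncased") (checkpoint name assumed).
config = BertConfig(
    vocab_size=100,
    hidden_size=32,
    num_hidden_layers=1,
    num_attention_heads=2,
    intermediate_size=64,
)
model = BertClassifier(BertModel(config), num_classes=4)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

# Forward pass on a dummy batch of 2 sequences of length 19.
logits = model(torch.randint(0, 100, (2, 19)), torch.ones(2, 19, dtype=torch.long))
print(logits.shape)
```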

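# 3.4 Train the model and save it in real time

    Epoch: 01
    Train Loss: 0.7342, Train Acc: 0.7198
    Eval Loss: 0.2669, Eval Acc: 46.0000
    Epoch: 02
    Train Loss: 0.1879, Train Acc: 0.9464
    Eval Loss: 0.1194, Eval Acc: 48.2812
    Epoch: 03
    Train Loss: 0.0991, Train Acc: 0.9731
    Eval Loss: 0.1043, Eval Acc: 48.2500
    Epoch: 04
    Train Loss: 0.0630, Train Acc: 0.9811
    Eval Loss: 0.1025, Eval Acc: 48.5312
    Epoch: 05
    Train Loss: 0.0439, Train Acc: 0.9866
    Eval Loss: 0.1078, Eval Acc: 48.5938

Note that the logged Eval Acc values are greater than 1, so they are not on the same 0 to 1 scale as Train Acc. This suggests the evaluation accuracy was accumulated without the same normalization (for example, summed across batches rather than divided by the number of samples). The numbers are reproduced here as logged.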
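
A training loop with checkpointing along these lines might look like the sketch below. It is not the original code: a tiny stand-in classifier and synthetic data replace BERT and the news titles so that it runs in seconds, and saving whenever the evaluation loss improves is one assumed reading of "real-time saving". Accuracy is computed as correct predictions divided by sample count in both phases, which keeps it between 0 and 1.

```python
import os
import tempfile

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def run_epoch(model, loader, criterion, optimizer=None):
    """One pass over loader; trains if an optimizer is given, else evaluates."""
    training = optimizer is not None
    model.train(training)
    total_loss, correct, total = 0.0, 0, 0
    with torch.set_grad_enabled(training):
        for input_ids, attention_mask, labels in loader:
            logits = model(input_ids, attention_mask)
            loss = criterion(logits, labels)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            total_loss += loss.item() * labels.size(0)
            correct += (logits.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)
    # Both metrics are averaged over samples, so accuracy stays in [0, 1].
    return total_loss / total, correct / total


class TinyClassifier(nn.Module):
    """Stand-in for the BERT classifier (hypothetical, for a fast demo)."""

    def __init__(self, vocab_size=100, hidden=16, num_classes=4):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, hidden)
        self.fc = nn.Linear(hidden, num_classes)

    def forward(self, input_ids, attention_mask):
        mask = attention_mask.unsqueeze(-1).float()
        pooled = (self.emb(input_ids) * mask).sum(1) / mask.sum(1).clamp(min=1)
        return self.fc(pooled)


torch.manual_seed(0)
ids = torch.randint(0, 100, (32, 19))
masks = torch.ones(32, 19, dtype=torch.long)
labels = torch.randint(0, 4, (32,))
train_loader = DataLoader(
    TensorDataset(ids[:24], masks[:24], labels[:24]), batch_size=8, shuffle=True
)
eval_loader = DataLoader(TensorDataset(ids[24:], masks[24:], labels[24:]), batch_size=8)

model = TinyClassifier()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

ckpt_path = os.path.join(tempfile.mkdtemp(), "best_model.pt")  # path is arbitrary
best_eval_loss = float("inf")
history = []
for epoch in range(1, 4):
    train_loss, train_acc = run_epoch(model, train_loader, criterion, optimizer)
    eval_loss, eval_acc = run_epoch(model, eval_loader, criterion)
    history.append((train_loss, train_acc, eval_loss, eval_acc))
    if eval_loss < best_eval_loss:  # save as soon as the model improves
        best_eval_loss = eval_loss
        torch.save(model.state_dict(), ckpt_path)
    print(f"Epoch: {epoch:02d}")
    print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
    print(f"Eval Loss: {eval_loss:.4f}, Eval Acc: {eval_acc:.4f}")
```

Saving only the `state_dict` keeps the checkpoint small; restoring it requires building the same model class first and then calling `load_state_dict`.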

# 4. Model inference

    This is a breaking news about politics
    Predicted class: 0
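
Inference on a single sentence can be sketched as: tokenize the text the same way as the training data, run the model without gradients, and take the index of the largest logit. The `predict` helper below is hypothetical; it assumes a Hugging Face style tokenizer and a model that takes `(input_ids, attention_mask)` and returns logits, and the default `max_length` is an assumed value.

```python
import torch


@torch.no_grad()
def predict(text, model, tokenizer, max_length=32):
    """Return the predicted class index for one piece of text."""
    model.eval()
    encoded = tokenizer(
        text,
        padding="max_length",
        truncation=True,
        max_length=max_length,
        return_tensors="pt",  # shape (1, max_length) for a single string
    )
    logits = model(encoded["input_ids"], encoded["attention_mask"])
    return int(logits.argmax(dim=1).item())


# Usage, assuming a trained model and the matching tokenizer are in scope:
# text = "This is a breaking news about politics"
# print(text)
# print("Predicted class:", predict(text, model, tokenizer))
```

With the alphabetical label encoding suggested by step 2.3, class 0 corresponds to category `b`.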
