
Tianchi Competition: Industrial Steam Prediction (improved version, score 0.1235)

Imports

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns
from sklearn.linear_model import LinearRegression,Lasso,Ridge,ElasticNet,RidgeCV
from sklearn.neighbors import KNeighborsRegressor
from sklearn.ensemble import GradientBoostingRegressor,RandomForestRegressor,AdaBoostRegressor,ExtraTreesRegressor
from xgboost import XGBRegressor
from lightgbm import LGBMRegressor
# Support vector machine
from sklearn.svm import SVR
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler,StandardScaler,PolynomialFeatures
import warnings
warnings.filterwarnings('ignore')

Data Aggregation

train_data = pd.read_csv('./zhengqi_train.txt',sep='\t')
test_data = pd.read_csv('./zhengqi_test.txt',sep='\t')
# Combine the training and test data
train_data["origin"]="train"
test_data["origin"]="test"
data_all=pd.concat([train_data,test_data],axis=0,ignore_index=True)
#View data
data_all
[DataFrame preview: features V0 through V37 plus 'origin' and 'target'; rows from the test set have NaN targets]

4813 rows × 40 columns

Feature Exploration

# Numeric feature columns (all but the last two: 'origin' and 'target')
data_all.columns[:-2]
Index(['V0', 'V1', 'V10', 'V11', 'V12', 'V13', 'V14', 'V15', 'V16', 'V17',
       'V18', 'V19', 'V2', 'V20', 'V21', 'V22', 'V23', 'V24', 'V25', 'V26',
       'V27', 'V28', 'V29', 'V3', 'V30', 'V31', 'V32', 'V33', 'V34', 'V35',
       'V36', 'V37', 'V4', 'V5', 'V6', 'V7', 'V8', 'V9'],
      dtype='object')
# 38 features in total; drop the unimportant ones
# Compare each feature's distribution: if train and test distributions differ, drop it
for column in data_all.columns[0:-2]:
    g = sns.kdeplot(data_all[column][(data_all["origin"] == "train")], color="Red", shade = True)
    g = sns.kdeplot(data_all[column][(data_all["origin"] == "test")], ax =g, color="Blue", shade= True)
    g.set_xlabel(column)
    g.set_ylabel("Frequency")
    g = g.legend(["train","test"])
    plt.show()

[Figure: train vs. test kernel density estimate for each feature]

# FacetGrid creates its own figure for each feature, so no plt.figure call is needed
for i in range(len(data_all.columns)-2):
    g = sns.FacetGrid(data_all, col='origin')
    g = g.map(sns.distplot, data_all.columns[i])

[Figure: per-feature distributions, faceted by origin]

# The plots show that V11, V17, V22, and V5 fluctuate too much between train and test; drop them
drop_labels = ['V11','V17','V22','V5']
data_all.drop(drop_labels,axis=1,inplace=True)
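Rather than judging the KDE overlays purely by eye, the train/test shift can also be quantified with a two-sample Kolmogorov-Smirnov test. A minimal sketch (the 0.3 cutoff on the KS statistic is an illustrative assumption, not a rule):

from scipy.stats import ks_2samp

# Flag remaining features whose train and test distributions differ strongly
for column in data_all.columns[:-2]:
    stat, p = ks_2samp(data_all[column][data_all['origin'] == 'train'],
                       data_all[column][data_all['origin'] == 'test'])
    if stat > 0.3:
        print(column, round(stat, 3))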

Correlation Coefficients (corr)

# Examine pairwise correlations
plt.figure(figsize=(20, 16))  # set the figure width and height
mcorr = train_data.corr(numeric_only=True)  # correlation matrix of every pair of numeric variables (the 'origin' column is a string)
mask = np.zeros_like(mcorr, dtype=bool)  # boolean matrix with the same shape as mcorr
mask[np.triu_indices_from(mask)] = True  # mask the upper triangle (above the diagonal)
cmap = sns.diverging_palette(220, 10, as_cmap=True)  # a matplotlib colormap
g = sns.heatmap(mcorr, mask=mask, cmap=cmap, square=True, annot=True, fmt='0.2f')  # heatmap of pairwise correlations
plt.show()

[Figure: lower-triangle correlation heatmap of the training features]

# Use the correlation coefficients to find 7 features weakly correlated with the target
cond = mcorr.loc['target'].abs()<0.1
drop_labels = mcorr.loc['target'][cond].index
#['V14', 'V21', 'V25', 'V26', 'V32', 'V33', 'V34']

# After inspecting their distributions, drop only the poorly distributed ones (V14, V21)
drop_labels = ['V14', 'V21']
data_all.drop(drop_labels,axis=1,inplace=True)
# 6 columns dropped in total (V11, V17, V22, V5, V14, V21)
data_all.shape
(4813, 34)

Normalize the Data

data = data_all.iloc[:,:-2]
minmaxscale = MinMaxScaler()
data = minmaxscale.fit_transform(data)
data
array([[0.77577505, 0.723449  , 0.22174265, ..., 0.43285165, 0.66410771,
        0.73528007],
       [0.83374189, 0.77878549, 0.37388724, ..., 0.43285165, 0.7548128 ,
        0.73528007],
       [0.84023071, 0.79600421, 0.46641489, ..., 0.43285165, 0.76237156,
        0.73528007],
       ...,
       [0.31708724, 0.25289169, 0.0074184 , ..., 0.17367095, 0.10192512,
        0.64706284],
       [0.31045422, 0.24211356, 0.00323712, ..., 0.24075302, 0.1563718 ,
        0.67646858],
       [0.35948089, 0.32216088, 0.35608309, ..., 0.24897256, 0.19971655,
        0.67646858]])
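For reference, MinMaxScaler rescales each feature column independently to the [0, 1] interval:

    x' = (x - x_min) / (x_max - x_min)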
# Put the normalized data back into a DataFrame
data_all_norm = pd.DataFrame(data,columns=data_all.columns[:-2])
data_all_norm
[DataFrame preview: the 32 normalized feature columns]

4813 rows × 32 columns

# Merge the 'origin' and 'target' columns back on
data_all_norm = pd.merge(data_all_norm,data_all.iloc[:,-2:],left_index=True,right_index=True)
data_all_norm.describe()
[describe() output: every normalized feature spans [0.0, 1.0]; 'target' has 2888 non-null values (the training rows), mean 0.126, min -3.044, max 2.538]
# Min-max scale a Series/array to [0, 1]
def scale_minmax(data):
    return (data - data.min())/(data.max() - data.min())
#Use Box-Cox to make continuous features smoother (mainly for roughly normal distributions)
from scipy import stats
fcols = 6
frows = len(data_all_norm.columns[:10])
plt.figure(figsize=(4*fcols,4*frows))
i = 0

for col in data_all_norm.columns[:10]:
    dat = data_all_norm[[col, 'target']].dropna()

#     First plot: the feature's distribution, with a fitted normal curve (dist = distribution)
    i+=1
    plt.subplot(frows,fcols,i)
    sns.distplot(dat[col],fit = stats.norm);
    plt.title(col+' Original')
    plt.xlabel('')

#     Second plot: probability plot against the normal distribution
#     skewness measures how far the data departs from normality
    i+=1
    plt.subplot(frows,fcols,i)
    _=stats.probplot(dat[col], plot=plt)  # probability plot; reveals skewness
    plt.title('skew='+'{:.4f}'.format(stats.skew(dat[col])))
    plt.xlabel('')
    plt.ylabel('')

#     Scatter plot against the target
    i+=1
    plt.subplot(frows,fcols,i)
#     plt.plot(dat[var], dat['target'],'.',alpha=0.5)
    plt.scatter(dat[col],dat['target'],alpha=0.5)
    plt.title('corr='+'{:.2f}'.format(np.corrcoef(dat[col], dat['target'])[0][1]))

#     !!! Now transform the data !!!
#     Distribution after the Box-Cox transform
    i+=1
    plt.subplot(frows,fcols,i)
    trans_var, lambda_var = stats.boxcox(dat[col].dropna()+1)
    trans_var = scale_minmax(trans_var)      
    sns.distplot(trans_var , fit=stats.norm);
    plt.title(col+' Transformed')
    plt.xlabel('')

#     Skewness after the transform
    i+=1
    plt.subplot(frows,fcols,i)
    _=stats.probplot(trans_var, plot=plt)
    plt.title('skew='+'{:.4f}'.format(stats.skew(trans_var)))
    plt.xlabel('')
    plt.ylabel('')

#     Scatter plot after the transform
    i+=1
    plt.subplot(frows,fcols,i)
    plt.plot(trans_var, dat['target'],'.',alpha=0.5)
    plt.title('corr='+'{:.2f}'.format(np.corrcoef(trans_var,dat['target'])[0][1]))

[Figure: for each of the first 10 features: distribution, probability plot, and target scatter, before and after Box-Cox]
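For reference, the sample skewness reported by stats.skew is

    g1 = m3 / m2^(3/2),  with mk = (1/n) * sum_i (x_i - mean)^k

so values near 0 indicate a roughly symmetric, normal-like distribution.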

# Apply the Box-Cox transform to all the data
# A common data transformation in statistical modeling
# Makes the data more normal and better standardized
for col in data_all_norm.columns[:-2]:
    boxcox,maxlog = stats.boxcox(data_all_norm[col] + 1)
    data_all_norm[col] = scale_minmax(boxcox)
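With the +1 shift (Box-Cox requires strictly positive inputs), the transform computed above is

    y = ((x + 1)^lambda - 1) / lambda   if lambda != 0
    y = ln(x + 1)                       if lambda == 0

where stats.boxcox chooses lambda by maximum likelihood and returns it as the second value (maxlog above).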
data_all_norm
[DataFrame preview: the Box-Cox transformed features plus 'origin' and 'target']

4813 rows × 34 columns

Filter Outliers

ridge = RidgeCV(alphas=[0.0001,0.001,0.01,0.1,0.2,0.5,1,2,3,4,5,10,20,30,50])

cond = data_all_norm['origin'] == 'train'

X_train = data_all_norm[cond].iloc[:,:-2]
# Ground-truth targets
y_train = data_all_norm[cond]['target']
# No model fits the data and targets 100% perfectly
ridge.fit(X_train,y_train)
# Predictions always deviate somewhat from the truth; treat unusually large deviations as outliers
y_ = ridge.predict(X_train)

cond = abs(y_ - y_train) > y_train.std()
print(cond.sum())
# Plot
plt.figure(figsize=(12,6))
axes = plt.subplot(1,3,1)
axes.scatter(y_train,y_)
axes.scatter(y_train[cond],y_[cond],c = 'red',s = 20)

axes = plt.subplot(1,3,2)
axes.scatter(y_train,y_train - y_)
axes.scatter(y_train[cond],(y_train - y_)[cond],c = 'red')

axes = plt.subplot(1,3,3)
# _ = axes.hist(y_train,bins = 50)
(y_train - y_).plot.hist(bins = 50,ax = axes)
(y_train - y_).loc[cond].plot.hist(bins = 50,ax = axes,color = 'r')
40

<matplotlib.axes._subplots.AxesSubplot at 0x2403c0836a0>

[Figure: predicted vs. true targets, residuals vs. targets, and residual histogram, with flagged outliers in red]

# Drop the flagged outlier rows
index = cond[cond].index

data_all_norm.drop(index,axis = 0,inplace=True)
# Rebuild the train/test split from the cleaned data
cond = data_all_norm['origin'] == 'train'
X_train = data_all_norm[cond].iloc[:,:-2]
y_train = data_all_norm[cond]['target']

cond = data_all_norm['origin'] == 'test'
X_test = data_all_norm[cond].iloc[:,:-2]
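As a quick sanity check, a minimal sketch (reusing the ridge, X_train, and y_train defined above, plus the mean_squared_error already imported) refits on the cleaned data and reports the training error, which should now be lower:

ridge.fit(X_train, y_train)
print('train MSE after outlier removal:',
      mean_squared_error(y_train, ridge.predict(X_train)))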

Train Several Models and Average Their Predictions

estimators = {}
estimators['forest'] = RandomForestRegressor(n_estimators=300)
estimators['gbdt'] = GradientBoostingRegressor(n_estimators=300)
estimators['ada'] = AdaBoostRegressor(n_estimators=300)
estimators['extreme'] = ExtraTreesRegressor(n_estimators=300)
estimators['svm_rbf'] = SVR(kernel='rbf')
estimators['light'] = LGBMRegressor(n_estimators=300)
estimators['xgb'] = XGBRegressor(n_estimators=300)
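Before settling on a plain average, the individual models can be compared. A sketch, assuming the estimators dict above, that ranks them by 5-fold cross-validated MSE:

from sklearn.model_selection import cross_val_score

for key, model in estimators.items():
    scores = cross_val_score(model, X_train, y_train,
                             scoring='neg_mean_squared_error', cv=5)
    print(key, 'MSE =', round(-scores.mean(), 4))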
#Store each model's predictions in a list and average them for the final answer
result = []
for key,model in estimators.items():
    model.fit(X_train,y_train)
    y_ = model.predict(X_test)
    result.append(y_)
y_ = np.mean(result,axis = 0)

pd.Series(y_).to_csv('./norm.txt',index = False,header = False)  # write the values only, no index or header
[19:51:26] WARNING: C:/Jenkins/workspace/xgboost-win64_release_0.90/src/objective/regression_obj.cu:152: reg:linear is now deprecated in favor of reg:squarederror.
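The warning above is harmless: xgboost 0.90 still defaulted to the deprecated reg:linear objective. A one-line tweak (assuming xgboost 0.90 or later) silences it by setting the objective explicitly:

estimators['xgb'] = XGBRegressor(n_estimators=300, objective='reg:squarederror')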