当前位置:   article > 正文

Copilot 体验(python)

Copilot 体验(python)

基本使用方法

  1. # 提示
  2. def xxx():"""提示"""
  3. 也可以写到一行中间的时候等待 # plot confusion matrix def xxx(): sns.

注意事项

实现复杂功能时,需要人工引导,写到一半,来回按 Tab Backspace 触发提示
直接一个注释搞定不太可能,除非是一些简单和常用的功能,如文件加载保存,对象转换,日期处理,常用函数等等

风险

如果开启 Allow public code ,可能有版权问题
模型可能生成有 bug 的代码


生成代码

import numpy as np
import pandas as pd

# get max 5 numbers from a list
def get_max_5(list):
    return sorted(list, reverse=True)[:5]

# train keras model with data

def train_model(data):
    # split data into train and test
    train_data = data[:int(len(data) * 0.8)]
    test_data = data[int(len(data) * 0.8):]

    # get features and labels
    train_features = train_data.iloc[:, :-1].values
    train_labels = train_data.iloc[:, -1].values
    test_features = test_data.iloc[:, :-1].values
    test_labels = test_data.iloc[:, -1].values

    # reshape features
    train_features = train_features.reshape(train_features.shape[0], 1, train_features.shape[1])
    test_features = test_features.reshape(test_features.shape[0], 1, test_features.shape[1])

    # import keras
    

# get VGG-19 model

def get_model():
    # import keras
    from keras.applications import VGG19
    

# confusion matrix
def get_confusion_matrix(pred, true):
    # import sklearn
    from sklearn.metrics import confusion_matrix
    return confusion_matrix(pred, true)

# get now time
def get_now_time():
    import datetime
    return datetime.datetime.now()

# get now time in string format
def get_now_time_str():
    import datetime
    return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")

# get BERT model keras
def get_bert_model():
    # import keras
    from keras.applications import BERT

# plot heatmap with folium
def plot_heatmap(data, title, filename):
    # import folium
    from folium import plugins
    from folium.plugins import HeatMap
    import folium
    import pandas as pd
    import numpy as np
    import matplotlib.pyplot as plt
    import seaborn as sns

# convert list to dataframe
def list_to_df(list):
    return pd.DataFrame(list)

# # add 1 to 1000
# def add_1_to_1000():
#     for i in range(1000):
#         print(i + 1)
    
# hash function
def hash_function(str):
    import hashlib
    return hashlib.sha256(str.encode()).hexdigest()

# TSP
def tsp(data):
    import numpy as np
    import pandas as pd
    import matplotlib.pyplot as plt
    import seaborn as sns
    import random
    import math
    import time
    import itertools
    import copy
    import matplotlib.pyplot as plt

# crawl a web page
def crawl_web(url):
    import requests
    return requests.get(url)

# load keras model from pickle
def load_model(filename):
    import pickle
    with open(filename, 'rb') as f:
        return pickle.load(f)

# save keras model to pickle
def save_model(model, filename):
    import pickle
    with open(filename, 'wb') as f:
        pickle.dump(model, f)

# load keras model 
def load_keras_model(filename):
    import keras
    return keras.models.load_model(filename)

# display audio
def display_audio(filename):
    # load file
    file = open(filename, 'rb')
    # create an audio object
    audio = AudioSegment.from_file(file)
    # play the audio
    play(audio)

# plot heatmap
def plot_heatmap(data):
    import seaborn as sns
    import matplotlib.pyplot as plt
    plt.figure()
    sns.heatmap(data, cmap='Blues', annot=True, fmt='.2f')
    plt.show()

# create a random numpy array
def create_random_numpy_array(shape):
    return np.random.rand(*shape)

# baidu url
url = 'https://www.baidu.com'


# main function
if __name__ == "__main__":
    data = create_random_numpy_array(shape=(2, 3))
    plot_heatmap(data)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144

在 iris 上训练一个 XGBoost

全程只导入基本的包、打注释和改一些小瑕疵

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import sklearn.model_selection as skl_ms
import xgboost as xgb
  • 1
  • 2
  • 3
  • 4
  • 5


# load iris data from sklearn
from sklearn.datasets import load_iris
iris = load_iris()
X = iris.data
y = iris.target


# split train test data
X_train, X_test, y_train, y_test = skl_ms.train_test_split(X, y, test_size=0.2, random_state=0)


# train a xgb model CV
xgb_model = xgb.XGBClassifier()
xgb_model.fit(X_train, y_train)

# train a xgb model with a gridsearch
parameters = {'max_depth': [3, 4, 5],
                'learning_rate': [0.1, 0.05, 0.02],
                'n_estimators': [50, 100, 200]}
xgb_gridsearch = skl_ms.GridSearchCV(xgb_model, parameters, cv=5)
xgb_gridsearch.fit(X_train, y_train)
print(xgb_gridsearch.best_params_)
print(xgb_gridsearch.best_score_)

# plot a confusion matrix
from sklearn.metrics import confusion_matrix
y_pred = xgb_gridsearch.predict(X_test)
cm = confusion_matrix(y_test, y_pred)    
plt.figure(figsize=(10,7))
plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
plt.title("Confusion Matrix")
plt.colorbar()
plt.ylabel('True label')
plt.xlabel('Predicted label')
tick_marks = np.arange(3)
plt.xticks(tick_marks, ['setosa', 'versicolor', 'virginica'], rotation=45)
plt.yticks(tick_marks, ['setosa', 'versicolor', 'virginica'])
plt.show()


  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42

混淆矩阵

# plot xgb feature importance
importance = xgb_gridsearch.best_estimator_.feature_importances_
feature_names = iris.feature_names
feature_importance = pd.DataFrame({'feature': feature_names, 'importance': importance})
feature_importance = feature_importance.sort_values('importance', ascending=False)
plt.figure(figsize=(10,7))
plt.barh(range(len(feature_importance)), feature_importance.importance, align='center')
plt.yticks(range(len(feature_importance)), feature_importance.feature)
plt.xlabel('Importance')
plt.ylabel('Feature')
plt.title('XGBoost Feature Importance')
plt.show()

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13

特征重要性

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/凡人多烦事01/article/detail/622156
推荐阅读
相关标签
  

闽ICP备14008679号