当前位置:   article > 正文

Python-JupyterGPU机器学习代码_jupyter notebook gpu加速

jupyter notebook gpu加速

        使用 Jupyter Notebook 进行 GPU 加速的机器学习代码开发,通常涉及到利用 GPU 运行深度学习模型,特别是基于 TensorFlow 或 PyTorch 这样的深度学习框架。GPU 的并行计算能力可以显著加快模型训练的速度,尤其对于大规模数据集和复杂模型来说效果更为明显。

根据需要训练的数据进行机器学习建模

  1. import time
  2. start = time.time()
  3. import numpy as np # linear algebra
  4. import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
  5. from keras.models import Sequential
  6. from keras.layers import Conv2D, MaxPooling2D, Activation, Dropout, Flatten, Dense
  7. from keras.preprocessing.image import ImageDataGenerator, img_to_array, load_img
  8. import matplotlib.pyplot as plt
  9. from PIL import Image
  10. from glob import glob
  11. import os
  12. import os
  13. print(os.system('ls /dataset/stress_test_data_2'))
  14. train_path = '/dataset/stress_test_data_2/Training/'
  15. test_path = '/dataset/stress_test_data_2/Test/'
  16. img = load_img(train_path + "Apple Braeburn/0_100.jpg", target_size=(100,100))
  17. plt.imshow(img)
  18. plt.axis("off")
  19. plt.show()
  20. images = ['Orange', 'Banana', 'Cauliflower', 'Cactus fruit', 'Eggplant', 'Avocado', 'Blueberry','Lemon', 'Kiwi']
  21. import matplotlib.pyplot as plt
  22. import numpy as np
  23. fig = plt.figure(figsize =(15,5))
  24. for i in range(9):
  25. ax = fig.add_subplot(3,3,i+1,xticks=[],yticks=[])
  26. #fig.patch.set_facecolor('#E53090')
  27. #Above code adds a background color for subplots you can change the hex color code as you wish
  28. plt.title(images[i])
  29. plt.axis("off")
  30. ax.imshow(load_img(train_path + images[i] +"/0_100.jpg", target_size=(100,100)))
  31. x = img_to_array(img)
  32. print(x.shape)
  33. className = glob(train_path + '/*')
  34. number_of_class = len(className)
  35. print(number_of_class)
  36. model = Sequential()
  37. model.add(Conv2D(32, (3,3), input_shape= x.shape))
  38. model.add(Activation("relu"))
  39. model.add(MaxPooling2D())
  40. model.add(Conv2D(32, (3,3),))
  41. model.add(Activation("relu"))
  42. model.add(MaxPooling2D())
  43. model.add(Conv2D(64, (3,3),))
  44. model.add(Activation("relu"))
  45. model.add(MaxPooling2D())
  46. model.add(Flatten())
  47. model.add(Dense(1024))
  48. model.add(Activation("relu"))
  49. model.add(Dropout(0.5))
  50. model.add(Dense(number_of_class))#output
  51. model.add(Activation("softmax"))
  52. model.compile(loss = "categorical_crossentropy",
  53. optimizer = "rmsprop",
  54. metrics = ["accuracy"])
  55. model.summary()
  56. batch_size = 32
  57. train_datagen = ImageDataGenerator(rescale = 1./255,
  58. shear_range = 0.3,
  59. horizontal_flip=True,
  60. vertical_flip=False,
  61. zoom_range = 0.3
  62. )
  63. test_datagen = ImageDataGenerator(rescale = 1./255)
  64. train_generator = train_datagen.flow_from_directory(train_path,
  65. target_size=x.shape[:2],
  66. batch_size = batch_size,
  67. color_mode= "rgb",
  68. class_mode = "categorical")
  69. test_generator = test_datagen.flow_from_directory(test_path,
  70. target_size=x.shape[:2],
  71. batch_size = batch_size,
  72. color_mode= "rgb",
  73. class_mode = "categorical")
  74. hist = model.fit_generator(generator = train_generator,
  75. steps_per_epoch = 1600 // batch_size,
  76. epochs = 50,
  77. validation_data = test_generator,
  78. validation_steps = 800 // batch_size)
  79. print(hist.history.keys())
  80. plt.plot(hist.history["loss"], label = "Train Loss")
  81. plt.plot(hist.history["val_loss"], label = "Validation Loss")
  82. plt.legend()
  83. plt.show()
  84. plt.plot(hist.history["accuracy"], label = "Train Accuracy")
  85. plt.plot(hist.history["val_accuracy"], label = "Validation Accuracy")
  86. plt.legend()
  87. plt.show()
  88. end = time.time()
  89. print("time cost",end - start)

        在 Jupyter Notebook 中进行 GPU 加速的机器学习代码开发,可以带来训练速度的显著提升,尤其适用于大规模数据和复杂模型的场景。同时,结合 Jupyter Notebook 的交互式编程和展示优势,可以更方便地进行实验、调试和结果展示。如果你需要进一步的帮助或有其他问题,请随时告诉我。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/凡人多烦事01/article/detail/711301
推荐阅读
相关标签
  

闽ICP备14008679号