当前位置:   article > 正文

深度学习框架【CNTK】的安装

cntk


前言

鉴于之前介绍的深度学习框架都是基于python写的,本次学习CNTK的时候决定使用C#来尝试,而且CNTK因各种原因,暂时无法在高于3.6版本进行安装,因此决定尝试使用C#来做CNTK的开发。


一、CNTK的前世今生

Computational Network Toolkit (CNTK) 是微软出品的开源深度学习工具包,改名后叫The Microsoft Cognitive Toolkit.。
CNTK只是一个框架或者说是一套简单的工具帮助我们实现我们所涉及的深度学习或者是神经网络。其中已经集成好很多经典的算法。当然大家也可以根据实际情况去自己定义具体的算法或者输入输出的方式。

特点:

  • 支持各种神经网络模型;
  • 一个简单的配置文件配置特定网络;
  • CNTK 可以用GPU,支持CUDA编程;
  • 自动计算所需的导数;
  • 可扩展;

二、CNTK的安装

本次使用的CNTK由于是使用C# 版的,因此我这边只需要对一些CNTK的标准库进行安装即可。
本次安装的一些工具如下:

1.CNTK预训练模型

2.CNTK各版本下载【GPU/CPU】

3.Ubuntu安装python3.6版的CNTK

安装openmpi

sudo apt install libevent-dev libhwloc-dev libibverbs-dev flex gfortran
sudo apt-get install openmpi-bin openmpi-common openmpi-doc libopenmpi-dev
  • 1
  • 2

下载文件:添加链接描述

tar -zxvf openmpi-4.0.5.tar.gz
cd openmpi-4.0.5
./configure --prefix="/usr/local/openmpi"
sudo make
sudo make install
sudo gedit ~/.bashrc
     添加:
         export PATH="$PATH:/usr/local/openmpi/bin"
         export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/usr/local/openmpi/lib/"
sudo ldconfig  或者 source ~/.bashrc
mpirun   # 测试
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

pip安装

文件下载:whl

三、CNTK环境测试【逻辑回归为例】

本次使用VisualStudio 2022 CNTK2.7版本弄一个测试的工程

1.准备上述dll依赖库

直接放入到工程上

2.本次工程目录结构如下

按照如下引入依赖库
在这里插入图片描述

3.程序引入命名空间

using CNTK;
  • 1

4.程序主入口

static void Main(string[] args)
{
    //逻辑回归输入3个,输出2个
    int inputDim = 3;
    int numOutputClasses = 2;
    //使用GPU
    var device = DeviceDescriptor.GPUDevice(0);
    //设置输入变量及输出变量
    Variable featureVariable = Variable.InputVariable(new int[] { inputDim }, DataType.Float);
    Variable labelVariable = Variable.InputVariable(new int[] { numOutputClasses }, DataType.Float);
    //创建简单的全连接层
    var classifierOutput = CreateLinearModel(featureVariable, numOutputClasses, device);
    //使用API获得softmax的结果及误差
    var loss = CNTKLib.CrossEntropyWithSoftmax(classifierOutput, labelVariable);
    var evalError = CNTKLib.ClassificationError(classifierOutput, labelVariable);
    //学习率的设置
    TrainingParameterScheduleDouble learningRatePerSample = new TrainingParameterScheduleDouble(0.02, 1);
    IList<Learner> parameterLearners = new List<Learner>() { Learner.SGDLearner(classifierOutput.Parameters(), learningRatePerSample) };
    var trainer = Trainer.CreateTrainer(classifierOutput, loss, evalError, parameterLearners);
    //批量大小
    int minibatchSize = 64;
    //训练次数
    int numMinibatchesToTrain = 10000;
    //最小更新大小
    int updatePerMinibatches = 10;

    // 循环训练数据
    for (int minibatchCount = 0; minibatchCount < numMinibatchesToTrain; minibatchCount++)
    {
        Value features, labels;
        GenerateValueData(minibatchSize, inputDim, numOutputClasses, out features, out labels, device);
        
#pragma warning disable 618
        trainer.TrainMinibatch(new Dictionary<Variable, Value>() { { featureVariable, features }, { labelVariable, labels } }, device);
#pragma warning restore 618
        PrintTrainingProgress(trainer, minibatchCount, updatePerMinibatches);
    }

    // 测试数据
    int testSize = 100;
    Value testFeatureValue, expectedLabelValue;
    GenerateValueData(testSize, inputDim, numOutputClasses, out testFeatureValue, out expectedLabelValue, device);

    IList<IList<float>> expectedOneHot = expectedLabelValue.GetDenseData<float>(labelVariable);
    IList<int> expectedLabels = expectedOneHot.Select(l => l.IndexOf(1.0F)).ToList();

    var inputDataMap = new Dictionary<Variable, Value>() { { featureVariable, testFeatureValue } };
    var outputDataMap = new Dictionary<Variable, Value>() { { classifierOutput.Output, null } };
    classifierOutput.Evaluate(inputDataMap, outputDataMap, device);
    var outputValue = outputDataMap[classifierOutput.Output];
    IList<IList<float>> actualLabelSoftMax = outputValue.GetDenseData<float>(classifierOutput.Output);
    var actualLabels = actualLabelSoftMax.Select((IList<float> l) => l.IndexOf(l.Max())).ToList();
    int misMatches = actualLabels.Zip(expectedLabels, (a, b) => a.Equals(b) ? 0 : 1).Sum();

    Console.WriteLine($"Validating Model: Total Samples = {testSize}, Misclassify Count = {misMatches}");
    Console.ReadLine();
}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57

5.程序测试效果

在这里插入图片描述

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/繁依Fanyi0/article/detail/655567
推荐阅读
相关标签
  

闽ICP备14008679号