当前位置:   article > 正文

【阅读笔记】联邦学习实战——联邦学习视觉案例_联邦学习实例

联邦学习实例

前言

FATE是微众银行开发的联邦学习平台,是全球首个工业级的联邦学习开源框架,在github上拥有近4000stars,可谓是相当有名气的,该平台为联邦学习提供了完整的生态和社区支持,为联邦学习初学者提供了很好的环境,否则利用python从零开发,那将会是一件非常痛苦的事情。本篇博客内容涉及《联邦学习实战》第十章内容,使用的fate版本为1.6.0,fate的安装已经在这篇博客中介绍,有需要的朋友可以点击查阅。下面就让我们开始吧。


1. 概述

随着算法的提升,大数据和硬件算力的发展,人工智能在视觉领域出现爆发性的增长,以目标检测为例,主要步骤如下:

  1. 收集数据集存放到中心数据库中。
  2. 进行集中的数据预处理,包括图片清洗、标注等。
  3. 利用预处理的数据进行中心化的模型训练。
  4. 将训练的模型部署到客户端。

但是传统的深度学习容易受到以下因素影响:

  • 数据隐私:在特殊领域(银行、医疗),每个客户采集的数据都具有高度隐私性,无法有效共享。另外,机器学习模型效果非常依赖数据的数量和质量,单点建模会降低模型效果。
  • 模型更新:各个数据源之间由于网络和设备的差异,导致数据同步不一致,对于实时响应的场景,中心化的训练模式无法满足。
  • 数据不均匀:每个数据源的数据分布、质量、大小各不相同。

2. 案例描述

本案例对分散在各地的摄像头数据,通过联邦学习,构建一个联邦分布式训练网络,摄像头数据无需上传,便可以协同训练目标检测模型,这样一方面用户的隐私数据不会被泄露,另一方面,充分利用参与方的训练数据,提升机器学习视觉模型的识别效果。

3. 目标检测算法概述

当前常见的计算机视觉任务可以归纳为图像分类、语义分割、目标检测、实例分割,区别如下图所示。
在这里插入图片描述

本案例场景为典型的目标检测任务。本节简单回顾目标检测任务的算法步骤。

3.1 边框线与锚框

边界线: 描述目标位置,是一个矩形框,由左上角坐标 ( x 1 , y 1 ) (x_1,y_1) (x1,y1)和右下角坐标 ( x 2 , y 2 ) (x_2,y_2) (x2,y2)共同决定。
锚框: YOLO系列算法定义锚框来提取候选区域,锚框以每个像素为中心,生成多个大小宽高比不同的边界框集合。如下图所示

在这里插入图片描述

3.2 交并比

交并比: 当多个边界框覆盖了图像中物体,如果该物体的真实边界框已知,那么需要一个衡量预测边界框好坏的指标,在目标检测领域,使用交互比(IOU)衡量。
假设有两个边界框A和B,则A和B的IOU为二者的相交面积和相并面积的比值。

I O U ( A , B ) = A ∩ B A ∪ B IOU(A,B)=\frac{A\cap B}{A\cup B} IOU(A,B)=ABAB

3.3 基于候选区域的目标检测算法

基于候选区域的目标检测算法包括R-CNN、Fast R-CNN、Faster R-CNN等,这类算法在求解目标检测任务时,分为两个阶段:第一阶段先产生所有可能的目标候选框,第二阶段再对所有候选框做分类与回归。因此这类算法也被称为二阶段算法。

  • R-CNN:先对图像提取大约2000个候选区域,然后将候选区域输入到CNN网络中,提取每个候选框的特征数据,每个候选框的特征数据与其类别一起构成一个样本,训练多个支持向量机对目标分类,其中每个支持向量机用来判断样本是否属于同一个类别,利用每个候选框的特征数据与其边界框一起构成一个样本,用来训练线性回归模型,并预测真实的边界框。在这里插入图片描述

  • Fast R-CNN:R-CNN的瓶颈在于,候选区域大量重叠,导致单独提取特征出现大量重复计算,所以Fast R-CNN先将图片输入CNN中,得到特征图,在特征图上进行候选区选取工作,并用softmax代替支持向量机,加快训练速度。由于每个候选区域大小不同,得到的特征向量长度不一,所以使用ROI池化将不同大小的输入转变为固定的大小长度。在这里插入图片描述

  • Faster R-CNN:虽然Fast R-CNN相比R-CNN有了很大的提升,但是候选区域的提取与目标检测仍然是两个独立过程,因此,Faster R-CNN在此基础上,提出了候选区域网络(RPN),将候选区域的提取与目标检测作为同一个网络进行端到端的训练。在这里插入图片描述

3.4 单阶段目标检测

仅仅使用一个卷积神经网络直接预测不同目标的分类与位置,不需要预先选取候选区域,因此在效果上,基于区域的算法要比单阶段算法准确度高,但速度慢,相反,单阶段算法速度快,但准确性低,典型的单阶段算法包括SSD,YOLO系列。
以YOLO为例,不需要先找出所有的候选框,而是直接将图片输入到模型中,最后直接得到边界框的位置及物体的标签信息,并且它将边界框定位与目标分类都看成回归问题。这样做到端到端的处理,以Pascal VOC数据集为例,处理步骤如下:

  1. 将图片裁剪为448×448×3大小作为输入,并且将图片分割得到7×7的网格,模型的输出是一个7×7×30维的输出,即每个网格都对应一个30维向量。首先一个网格负责预测一个物体,当一个物体的中心点在网格内时,我们就说这个网格负责预测这个物体。每个网格会生成两个边界框来预测这个物体,每个边界框由一个5元组确定 ( x , y , w , h , c ) (x,y,w,h,c) (x,y,w,h,c),其中 ( x , y ) (x,y) (x,y)代表边界框的中心坐标, w w w代表边界框的宽, h h h代表边界框的高, c c c代表边界框的物体属于哪个类别。
  2. 对标签进行转化。Pascal VOC数据集共有20种不同类别输出的概率,为此每个网格需要一个20维大小的额外向量来存放网格预测不同类别输出的概率。所以7×7×(2×5+20)=7×7×30。
  3. 构建损失函数,利用梯度下降求解网络。包括类别预测损失、边界框坐标损失、置信度分数的预测损失。

4. 基于联邦学习的目标检测网络

4.1 动机

对模型提供方和数据提供方来说,安全威胁是当前最为头疼和亟待解决的问题。安全威胁主要来自数据层面:

  • 数据离开本地后,数据提供方无法追踪数据的用途。
  • 数据上传过程中面临重重泄露风险。

因此,急需一种新的模型训练方法:数据保证不离开本地,并且模型性能不能受到影响。这两点都非常适合联邦学习。

4.2 FedVision-联邦视觉产品

对于一个横向联邦学习实现的目标检测模型的工作流程,以本案为例,基本设置如下:

  • 参与方设置为三方:A,B,C。
  • 设置三个参与方数据分布均衡。
  • 每个参与方在本地,对数据进行预处理,发起联邦学习任务,参与任务,模型本地预测和推断。
  • 服务端实时监控连接情况,对上传数据聚合,挑选客户端参与本地训练,上传全局模型。
  • 训练好的模型,可以分发给参与方,也可以以商业形式售卖。

在这里插入图片描述

基于联邦学习的目标检测视觉模型对集中式模型的优势:

  • 隐私性:数据隐私安全大为提高。
  • 效率:多方训练,速度提高。
  • 费用:上传模型参数相对于传输图像视频来说有效节省网络带宽。

5. 方法实现

书中实现方法有基于Flask-SocketIO的python实现,也有基于FATE实现,这里主要介绍python实现过程。

5.1 Flask-SocketIO基础

Flask-SocketIO作为服务端和客户端之间的通信框架,可以轻松实现服务端和客户端的双向通信。
首先安装SocketIO库,只需在命令行中输入:

$ pip install flask-socketio
  • 1
  • 服务端:首先初始化服务端。
from flask import Flask, render_template
from flask_socketio import SocketIO

app = Flask(__name__)
app.config['SECRET_KEY'] = 'secret!'
socketio = SocketIO(app)

if __name__=='__main__':
    socketio.run(app)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9

socketio.run()是服务器启动的接口,通过封装app.run()实现。这段代码没有任何功能,为了能够相应用户请求,需要定义必要的函数。如下创建一个“my event”事件,代码如下:

from flask import Flask, render_template
from flask_socketio import SocketIO

app = Flask(__name__)
app.config['SECRET_KEY'] = 'secret!'
socketio = SocketIO(app)


@socketio.on('my event')

def test_message(message):
    emit('my response', {'data':message['data']})

if __name__=='__main__':
    socketio.run(app)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15

事件创建后,服务端等待客户发送“my event”请求,此外,socketIO是双向通信,所以服务端还能向客户端发送请求,用emit和send(命名事件用前者,未命名用后者)。

  • 客户端:更为灵活,使用多种语言的socketIO官方客户端库或者兼容的客户端,与上面的服务端建立连接。
from socketIO_client import SocketIO

def test_response(data):
    print(data)
    
sio = SocketIO('localhost', 5000, None)
sio.on("my_response", test_response)
sio.emit("my event")
sio.wait()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9

先用socketIO创建一个客户端,构造函数需要提供端口号和服务器IP,然后利用on连接事件“my_response”,以及处理函数“test_response”,发送“my event”事件,等待服务端事件响应。

联邦客户端与服务端之间的详细通信过程如下:
在这里插入图片描述

5.2 服务端设计

服务端主体如下:

  • 模型的聚合。
  • 客户端选取和模型分发。
  • 网络监听。

构建一个服务端类,在类结构的构造函数中,定义部分变量如下:

class FLServer(object):
    def __init__(self, task_config_filename, host, port):
        self.task_config = load_json(task_config_filename)
        self.ready_client_sids = set()

        self.app = Flask(__name__)
        self.socketio = SocketIO(self.app, ping_timeout=3600000,
                                 ping_interval=3600000,
                                 max_http_buffer_size=int(1e32))
        self.host = host
        self.port = port
        self.model_id = str(uuid.uuid4())
        self.aggregator = Aggregator(self.task_config, self.logger)
        ...
        self.register_handles()
         
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16

相对于第3章的服务端设计,本章的服务端更为复杂,主要增加了socket通信的信息,一些字段解析如下:

  • task_config:保存配置信息。
  • ready_client_sids:记录每轮客户端ID集合。
  • socket_io:利用Flask-SocketIO创建的服务端I/O。
  • host和port:服务端当前host信息和端口信息。
  • aggregator:模型聚合,当前联邦学习聚合策略。

构造函数之后是register_handles函数,用于事件注册,即响应客户端的请求。

def register_handles(self):
    # single-threaded async, no need to lock

    @self.socketio.on('connect')
    def handle_connect():
        print(request.sid, "connected")

    @self.socketio.on('reconnect')
    def handle_reconnect():
        print(request.sid, "reconnected")

    @self.socketio.on('disconnect')
    def handle_disconnect():
        print(request.sid, "disconnected")
        if request.sid in self.ready_client_sids:
            self.ready_client_sids.remove(request.sid)

    @self.socketio.on('client_wake_up')
    def handle_wake_up():
        print("client wake_up: ", request.sid)
        emit('init')

    @self.socketio.on('client_ready')
    def handle_client_ready():
        print("client ready for training", request.sid)
        self.ready_client_sids.add(request.sid)
        if len(self.ready_client_sids) >= self.MIN_NUM_WORKERS and self.current_round == -1:
            print("start to federated learning.....")
            self.check_client_resource()
        elif len(self.ready_client_sids) < self.MIN_NUM_WORKERS:
            print("not enough client worker running.....")
        else:
            print("current_round is not equal to -1, please restart server.")
    ...
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34

服务端创建完毕等待客户端发送信号,接收到客户端信号后,将它们全放置在候选列表ready_client_sids中,每一轮训练会随机挑选部分客户端参与下一轮的迭代。

client_sids_selected = random.sample(list(self.ready_client_sids), self.NUM_CLIENTS_CONTACTED_PER_ROUND)
  • 1

服务端另一个主要功能是进行模型聚合,如下是FedAvg的实现,我们将每轮上传的客户端模型参数放置到model_weights中,选择本地样本数量占全体样本数量的比例作为模型参数的权重,求取新的全局模型参数值。

def update_weights(self, client_weights, client_sizes):
    total_size = np.sum(client_sizes)
    new_weights = [np.zeros(param.shape) for param in client_weights[0]]
    for c in range(len(client_weights)):
        for i in range(len(new_weights)):
            new_weights[i] += (client_weights[c][i] * client_sizes[c]
                               / total_size)
    self.current_weights = new_weights
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

5.3 客户端设计

构造函数主体如下:

class FederatedClient(object):
    MAX_DATASET_SIZE_KEPT = 6000

    def __init__(self, server_host, server_port, task_config_filename,
                 gpu, ignore_load):
        os.environ['CUDA_VISIBLE_DEVICES'] = '%d' % gpu
        self.task_config = load_json(task_config_filename)
        # self.data_path = self.task_config['data_path']
        print(self.task_config)
        self.ignore_load = ignore_load

        self.local_model = None
        self.dataset = None
        ...
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14

在联邦学习中,客户端与服务端是双向通信的,因此需要客户端注册相应的事件函数,用于响应服务端发送事件请求处理。

def register_handles(self):
    ########## Socket IO messaging ##########
    def on_connect():
        print('connect')

    def on_disconnect():
        print('disconnect')

    def on_reconnect():
        print('reconnect')

    def on_request_update(*args):
    ...


    self.sio.on('connect', on_connect)
    self.sio.on('disconnect', on_disconnect)
    self.sio.on('reconnect', on_reconnect)
    self.sio.on('init', self.on_init)
    self.sio.on('request_update', on_request_update)
    self.sio.on('stop_and_eval', on_stop_and_eval)
    self.sio.on('check_client_resource', on_check_client_resource)

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23

on是一个接口函数,参数是事件名称和对应的响应函数。
客户端创建完毕后,等待服务端下发初始化命令,服务端会下发初始的全局模型和配置信息给客户端,客户端初始化主要是将本地模型替换全局模型,同时利用配置信息读取本地训练数据集。

def on_init(self, request):
    print('on init')
    self.local_model = LocalModel(self.task_config)
    print("local model initialized done.")
    # ready to be dispatched for training
    self.sio.emit('client_ready')
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

客户端另一个重要环节是本地训练,通常情况和本地训练没有太大区别,这里不再赘述,感兴趣的朋友参考官方代码。

6. 性能分析

本章最后部分对两个模型在联邦学习中的性能进行了测试,分别测试了它们在不同数量客户参与方(C)以及不同本地训练迭代次数(E)配置下的性能对比,可以看到,参与方越多,其迭代收敛也越快(这是书中原话,但笔者认为并不绝对)。
在这里插入图片描述下图是两个模型在损失值上的对比,可以得出:

  • 随着客户端增多,刚开始迭代的效果会低于集中式训练的效果。主要受数据不平衡的影响。
  • 迭代到一定轮次,全局模型效果逼近集中式训练效果。
    在这里插入图片描述

阅读总结

本章内容涉及CV领域的目标检测内容,还是比较好理解的,只不过在运行代码的过程中,由于官方代码不全,导致运行不起来,实属遗憾,有时间一定斟酌一下,找到遗漏的的文件。然后FATE的实现文中并没有介绍,但是给了github链接,感兴趣的朋友可以复现一下,我也尽量能够出期FATE进行联邦目标检测实例的博客。接下来的第11章,FL在物联网的应用,应该还是理论居多,就让我们继续吧!

参考链接

https://blog.csdn.net/tinyzhao/article/details/53729006
https://blog.csdn.net/tinyzhao/article/details/53742626
https://github.com/FederatedAI/Practicing-Federated-Learning/tree/main/chapter10_Computer_Vision

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小小林熬夜学编程/article/detail/221356
推荐阅读
相关标签
  

闽ICP备14008679号