当前位置:   article > 正文

Golang实现一个批量自动化执行树莓派指令的软件(3)下载

Golang实现一个批量自动化执行树莓派指令的软件(3)下载

简介

接上篇《Golang实现一个批量自动化执行树莓派指令的软件(2)指令》, 这次实现文件的下载。

环境描述

运行环境: Windows, 基于 Golang; 暂时没有使用不可跨平台的接口, 理论上也支持 Linux/macOS
目标终端: 树莓派 Debian OS(主要用它做测试)

实现

接口定义

// IDownloader copies a remote file or directory tree to the local machine.
type IDownloader interface {
	// Download is the synchronous download entry point; it blocks until the
	// transfer has finished or failed.
	//
	//	from : remote path to download
	//	to   : local path to save into
	Download(from, to string) error

	// DownloadWithCallback is the synchronous/asynchronous download entry
	// point with progress reporting.
	//
	//	from             : remote path to download
	//	to               : local path to save into
	//	processCallback  : progress callback, invoked each time a file is
	//	                   downloaded, receiving:
	//	                     from           : remote path of the current file
	//	                     to             : local path the file is saved to
	//	                     downloadNumber : total number of files to download
	//	                     downloaded     : number of files downloaded so far
	//	finishedCallback : invoked when the download completes
	//	background       : false runs synchronously, true runs asynchronously
	DownloadWithCallback(from, to string,
		processCallback func(from, to string, downloadNumber, downloaded uint),
		finishedCallback func(err error), background bool) error
}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23

接口实现

package sshutil

import (
	"fmt"
	"github.com/pkg/sftp"
	"os"
	"path"
	"time"
)

// downloader implements IDownloader on top of an SFTP client.
type downloader struct {
	// client is the SFTP session every remote operation goes through.
	client *sftp.Client

	// Progress bookkeeping.
	downNumber   uint // number of files the current download expects to fetch
	downloaded   uint // incremented each time a file has been written locally
	downloadSize uint // summed byte size of the files to fetch

	// Cancellation state.
	started  bool          // true while download is running
	canceled chan struct{} // Cancel sends on it; Destroy closes it
}

// newDownloader wraps an established SFTP client in a downloader.
// A nil client is rejected up front: every download operation dereferences
// it, so accepting it would only defer the failure to a nil-pointer panic.
func newDownloader(client *sftp.Client) (*downloader, error) {
	if client == nil {
		return nil, fmt.Errorf("sftp client is nil")
	}
	return &downloader{client: client, canceled: make(chan struct{})}, nil
}

// Download fetches from (a remote file or directory) into the local path to
// and blocks until the transfer is done. It is the callback-free, foreground
// form of DownloadWithCallback.
func (d *downloader) Download(from, to string) error {
	// No progress/finish callbacks; false = run in the foreground.
	return d.DownloadWithCallback(from, to, nil, nil, false)
}

// DownloadWithCallback fetches from into the local path to, reporting progress
// through processCallback and completion through finishedCallback (either may
// be nil). With background set, the transfer runs on a new goroutine and nil
// is returned immediately, so any error is then only visible through
// finishedCallback. Otherwise the call blocks and returns the transfer error.
func (d *downloader) DownloadWithCallback(from, to string,
	processCallback func(from, to string, downloadNumber, downloaded uint),
	finishedCallback func(err error), background bool) error {
	if background {
		go d.download(from, to, processCallback, finishedCallback)
		return nil
	}
	return d.download(from, to, processCallback, finishedCallback)
}

// Cancel asks a running download to stop. It does nothing when no download is
// running. The request is a blocking send on d.canceled, which the downloader
// only polls between file transfers. If nobody picks the request up within two
// seconds, Cancel gives up and returns an error.
// Sending on the channel after Destroy has closed it would panic.
func (d *downloader) Cancel() error {
	if !d.started {
		return nil
	}

	deadline := time.NewTimer(2 * time.Second)
	defer deadline.Stop()

	select {
	case d.canceled <- struct{}{}:
		return nil
	case <-deadline.C:
		// The download did not reach a cancellation check in time.
		return fmt.Errorf("time out waiting for cancel")
	}
}

// Destroy cancels any running download and then closes the cancellation
// channel. After that, any download that reaches a cancellation check aborts
// with "user canceled". Calling Destroy again on an idle, already-destroyed
// downloader is a no-op instead of a close-of-closed-channel panic.
func (d *downloader) Destroy() error {
	err := d.Cancel()

	select {
	case _, open := <-d.canceled:
		if !open {
			// Already destroyed; closing again would panic.
			return err
		}
		// We consumed a concurrent Cancel's signal. Closing the channel below
		// still makes the running download observe the cancellation.
	default:
	}

	close(d.canceled)
	return err
}

// downloadFolderCount walks the remote directory remotePath recursively. It
// returns the number of non-directory entries found and the sum of their
// sizes in bytes. A directory that cannot be listed is reported as an error
// rather than silently counted as empty; downloadFolder fails on the same
// listing later anyway.
func (d *downloader) downloadFolderCount(remotePath string) (needDownload, size uint, err error) {
	infos, err := d.client.ReadDir(remotePath)
	if err != nil {
		return 0, 0, fmt.Errorf("listing remote directory %q: %w", remotePath, err)
	}

	for _, info := range infos {
		if !info.IsDir() {
			needDownload++
			size += uint(info.Size())
			continue
		}
		c, s, subErr := d.downloadFolderCount(path.Join(remotePath, info.Name()))
		if subErr != nil {
			return 0, 0, subErr
		}
		needDownload += c
		size += s
	}
	return needDownload, size, nil
}

// downloadFileCount reports how many files a download of remotePath will fetch
// and their combined size in bytes. For a plain file that is 1 and the file's
// size; for a directory it is the recursive totals.
func (d *downloader) downloadFileCount(remotePath string) (needDownload, size uint, err error) {
	stat, statErr := d.client.Stat(remotePath)
	switch {
	case statErr != nil:
		return 0, 0, statErr
	case stat.IsDir():
		return d.downloadFolderCount(remotePath)
	default:
		return 1, uint(stat.Size()), nil
	}
}

// download is the shared implementation behind Download and
// DownloadWithCallback. It works in three steps:
//   - count the files to fetch under remotePath;
//   - make sure the local directory localPath exists;
//   - copy either the single file or the whole directory tree into it.
//
// Failures detected here are passed to finishedCallback (if non-nil) on a
// separate goroutine and also returned. The file/folder helpers handle the
// callbacks for the copy itself.
func (d *downloader) download(remotePath, localPath string,
	processCallback func(from, to string, downloadNumber, downloaded uint),
	finishedCallback func(err error)) (err error) {

	whenErrorCall := func(e error) error {
		if finishedCallback != nil {
			go finishedCallback(e)
		}
		return e
	}

	// NOTE(review): started is also read by Cancel, possibly from another
	// goroutine, without synchronization.
	d.started = true
	defer func() {
		d.started = false
	}()

	// Reset progress so a reused downloader reports counts for this transfer
	// only, not a running total across calls.
	d.downloaded = 0
	d.downNumber, d.downloadSize, err = d.downloadFileCount(remotePath)
	if err != nil {
		return whenErrorCall(err)
	}

	// MkdirAll already returns nil when localPath is an existing directory.
	// Any remaining error is genuine, e.g. localPath names a regular file,
	// and must not be ignored.
	if err = os.MkdirAll(localPath, 0777); err != nil {
		return whenErrorCall(err)
	}

	info, err := d.client.Stat(remotePath)
	if err != nil {
		return whenErrorCall(err)
	}
	if info.IsDir() {
		return d.downloadFolder(remotePath, localPath, processCallback, finishedCallback)
	}
	return d.downloadFile(remotePath, localPath, processCallback, finishedCallback)
}

// downloadFile copies the single remote file remotePath into the local
// directory localPath, keeping its base name. After a successful copy it
// increments d.downloaded and, if processCallback is non-nil, reports progress
// on a separate goroutine. finishedCallback, if non-nil, receives the result
// (nil on success) on a separate goroutine. A pending Cancel is honoured both
// before the transfer starts and after it completes.
func (d *downloader) downloadFile(remotePath, localPath string,
	processCallback func(from, to string, downloadNumber, downloaded uint),
	finishedCallback func(err error)) (err error) {
	// NOTE(review): localPath is a local OS path joined with slash-separated
	// path.Join; Windows accepts forward slashes, but filepath.Join would be
	// the OS-native choice.
	localFileName := path.Join(localPath, path.Base(remotePath))

	finish := func(e error) error {
		if finishedCallback != nil {
			go finishedCallback(e)
		}
		return e
	}
	canceled := func() bool {
		select {
		case <-d.canceled:
			return true
		default:
			return false
		}
	}

	if canceled() {
		return finish(fmt.Errorf("user canceled"))
	}

	info, err := d.client.Stat(remotePath)
	if err != nil {
		return finish(err)
	}

	if info.Size() > 0 {
		err = d.copyRemoteFile(remotePath, localFileName)
	} else {
		// The author observed sftp's Open blocking on 0 KB files, so empty
		// files are simply created (or truncated) locally instead.
		err = createEmptyLocalFile(localFileName)
	}
	if err != nil {
		return finish(err)
	}

	if canceled() {
		return finish(fmt.Errorf("user canceled"))
	}

	// Empty files count towards progress too. downloadFolderCount counts
	// every non-directory entry, so skipping them here would keep downloaded
	// from ever reaching downNumber.
	d.downloaded++
	if processCallback != nil {
		go processCallback(remotePath, localFileName, d.downNumber, d.downloaded)
	}
	return finish(nil)
}

// copyRemoteFile streams the remote file remotePath into localFileName,
// creating or truncating it. The local file's Close error is returned, so a
// failed final flush is not reported as success.
func (d *downloader) copyRemoteFile(remotePath, localFileName string) (err error) {
	src, err := d.client.Open(remotePath)
	if err != nil {
		return err
	}
	defer src.Close()

	dst, err := os.OpenFile(localFileName, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0777)
	if err != nil {
		return err
	}
	defer func() {
		if cerr := dst.Close(); cerr != nil && err == nil {
			err = cerr
		}
	}()

	_, err = src.WriteTo(dst)
	return err
}

// createEmptyLocalFile creates localFileName (truncating it if it exists) and
// closes it, returning an error from either step.
func createEmptyLocalFile(localFileName string) error {
	f, err := os.OpenFile(localFileName, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0777)
	if err != nil {
		return err
	}
	return f.Close()
}

// downloadFolder recursively mirrors the remote directory remotePath into the
// local directory localPath, creating local directories as needed. Nested
// calls receive a nil finishedCallback. The caller's finishedCallback, if
// non-nil, is therefore invoked once from this level, on its own goroutine,
// with the overall result. That includes failures of individual files inside
// the tree.
func (d *downloader) downloadFolder(remotePath, localPath string,
	processCallback func(from, to string, downloadNumber, downloaded uint),
	finishedCallback func(err error)) (err error) {

	whenErrorCall := func(e error) error {
		if finishedCallback != nil {
			go finishedCallback(e)
		}
		return e
	}

	if err = os.MkdirAll(localPath, 0777); err != nil {
		return whenErrorCall(err)
	}

	// Check the listing error before iterating. Otherwise a later successful
	// child download would overwrite it and the failure would be lost.
	infos, err := d.client.ReadDir(remotePath)
	if err != nil {
		return whenErrorCall(err)
	}

	for _, info := range infos {
		remoteFilePath := path.Join(remotePath, info.Name())
		if info.IsDir() {
			err = d.downloadFolder(remoteFilePath, path.Join(localPath, info.Name()), processCallback, nil)
		} else {
			err = d.downloadFile(remoteFilePath, localPath, processCallback, nil)
		}
		if err != nil {
			// File failures must reach finishedCallback too; otherwise
			// asynchronous callers waiting on it would block forever.
			return whenErrorCall(err)
		}
	}

	return whenErrorCall(nil)
}

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154
  • 155
  • 156
  • 157
  • 158
  • 159
  • 160
  • 161
  • 162
  • 163
  • 164
  • 165
  • 166
  • 167
  • 168
  • 169
  • 170
  • 171
  • 172
  • 173
  • 174
  • 175
  • 176
  • 177
  • 178
  • 179
  • 180
  • 181
  • 182
  • 183
  • 184
  • 185
  • 186
  • 187
  • 188
  • 189
  • 190
  • 191
  • 192
  • 193
  • 194
  • 195
  • 196
  • 197
  • 198
  • 199
  • 200
  • 201
  • 202
  • 203
  • 204
  • 205
  • 206
  • 207
  • 208
  • 209
  • 210
  • 211
  • 212
  • 213
  • 214
  • 215
  • 216
  • 217
  • 218
  • 219
  • 220
  • 221
  • 222
  • 223
  • 224
  • 225

测试用例

package sshutil

import (
	"fmt"
	"github.com/pkg/sftp"
	"golang.org/x/crypto/ssh"
	"sync"
	"testing"
	"time"
)

// downloaderTest bundles the live connections a downloader test needs.
type downloaderTest struct {
	sshClient  *ssh.Client  // underlying SSH connection
	sftpClient *sftp.Client // SFTP session opened over sshClient
	downloader *downloader  // downloader under test, built on sftpClient
}

// newDownloadTest builds the test fixture in three steps:
//   - dial the test Raspberry Pi over SSH;
//   - open an SFTP session on that connection;
//   - wrap the session in a downloader.
//
// On any failure, every connection opened so far is closed and (nil, err) is
// returned, so callers never receive a half-initialised fixture.
//
// NOTE(review): host, user and password are hard-coded for the author's
// device; adjust them for your environment.
func newDownloadTest() (*downloaderTest, error) {
	config := ssh.ClientConfig{
		User:            "pi",                                      // user name
		Auth:            []ssh.AuthMethod{ssh.Password("a123456")}, // password
		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
		Timeout:         10 * time.Second,
	}

	var err error
	dTest := &downloaderTest{}

	dTest.sshClient, err = ssh.Dial("tcp", "192.168.3.2:22", &config) // IP + port
	if err != nil {
		fmt.Print(err)
		return nil, err
	}
	if dTest.sftpClient, err = sftp.NewClient(dTest.sshClient); err != nil {
		dTest.destroy()
		return nil, err
	}
	if dTest.downloader, err = newDownloader(dTest.sftpClient); err != nil {
		dTest.destroy()
		return nil, err
	}
	return dTest, nil
}

// destroy releases the fixture's connections and clears the fields, so it is
// safe to call on a partially built fixture or more than once. The SFTP
// session runs over the SSH connection, so it is closed first.
func (d *downloaderTest) destroy() {
	if sftpConn := d.sftpClient; sftpConn != nil {
		d.sftpClient = nil
		sftpConn.Close()
	}
	if sshConn := d.sshClient; sshConn != nil {
		d.sshClient = nil
		sshConn.Close()
	}
}

// TestDownloader_Download synchronously downloads /home/pi/ into ./download.
// It is skipped when the test device cannot be reached, and it fails, rather
// than merely printing, when the download returns an error.
func TestDownloader_Download(t *testing.T) {
	dTest, err := newDownloadTest()
	if err != nil {
		t.Skipf("fail to new download test! %v", err)
	}
	defer dTest.destroy()

	if err := dTest.downloader.Download("/home/pi/", "./download"); err != nil {
		t.Fatal(err)
	}
}

// TestDownloader_DownloadWithCallback runs a foreground download of /home/pi/
// into ./download1 with progress and finish callbacks. It fails if either the
// call itself or the finish callback reports an error.
func TestDownloader_DownloadWithCallback(t *testing.T) {
	dTest, err := newDownloadTest()
	if err != nil {
		t.Skipf("fail to new download test! %v", err)
	}
	defer dTest.destroy()

	// The downloader invokes finishedCallback on its own goroutine even in
	// foreground mode, so its result is handed back through a channel.
	finished := make(chan error, 1)
	// fmt is used instead of t.Log because these callbacks may still run
	// after the test function has returned.
	err = dTest.downloader.DownloadWithCallback("/home/pi/", "./download1", func(from, to string, downloadNumber, downloaded uint) {
		fmt.Println(from, to, downloadNumber, downloaded)
	}, func(err error) {
		fmt.Println("finished!!!")
		select {
		case finished <- err:
		default: // keep only the first result; never block the callback
		}
	}, false)
	if err != nil {
		t.Fatal(err)
	}

	select {
	case err = <-finished:
		if err != nil {
			t.Fatalf("finished callback reported: %v", err)
		}
	case <-time.After(5 * time.Second):
		t.Fatal("finished callback was not called")
	}

	fmt.Println("sleping...")
	// processCallback is also invoked asynchronously; give its prints time to
	// complete before the test returns.
	time.Sleep(time.Second * 1)
}

// TestDownloader_DownloadWithCallbackAsync runs a background download of
// /home/pi/ into ./download2/ and waits for the finish callback. It fails if
// that callback reports an error.
func TestDownloader_DownloadWithCallbackAsync(t *testing.T) {
	var waiter sync.WaitGroup
	dTest, err := newDownloadTest()
	if err != nil {
		t.Skipf("fail to new download test! %v", err)
	}
	defer dTest.destroy()

	// finishedErr is written before waiter.Done and read after waiter.Wait,
	// which orders the two accesses.
	var finishedErr error
	waiter.Add(1)
	err = dTest.downloader.DownloadWithCallback("/home/pi/", "./download2/", func(from, to string, downloadNumber, downloaded uint) {
		fmt.Println(from, to, downloadNumber, downloaded)
	}, func(err error) {
		fmt.Println("finished!!!")
		finishedErr = err
		waiter.Done()
	}, true)
	if err != nil {
		t.Fatal(err)
	}

	fmt.Println("waiting....")
	waiter.Wait()
	fmt.Println("done!!!")
	if finishedErr != nil {
		t.Fatalf("background download failed: %v", finishedErr)
	}
	time.Sleep(time.Second * 1) // processCallback runs asynchronously; let its prints finish
}

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116

代码源

https://gitee.com/grayhsu/ssh_remote_access

其他

参考

本文内容由网友自发贡献,转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号