赞
踩
话接上篇《Golang实现一个批量自动化执行树莓派指令的软件(2)指令》，这次实现文件的下载。
运行环境：Windows，基于 Golang；暂时没有使用不可跨平台的接口，理论上也支持 Linux/macOS。
目标终端：树莓派 DebianOS（主要用它做测试）。
// IDownloader downloads a remote file or directory tree to a local path.
type IDownloader interface {
	// Download is the synchronous download call; it blocks until the
	// download has finished or failed.
	//
	//   from: remote path to download
	//   to:   local path to save into
	Download(from, to string) error

	// DownloadWithCallback downloads either synchronously or asynchronously
	// and reports progress through callbacks.
	//
	//   from:             remote path to download
	//   to:               local path to save into
	//   processCallback:  progress callback, called each time a file has been
	//                     downloaded, with:
	//                       from           - remote path of that file
	//                       to             - local path it was saved to
	//                       downloadNumber - total number of files to download
	//                       downloaded     - number of files downloaded so far
	//   finishedCallback: called when the download completes, with its result
	//   background:       true runs the download asynchronously, false blocks
	DownloadWithCallback(from, to string, processCallback func(from, to string, downloadNumber, downloaded uint), finishedCallback func(err error), background bool) error
}
package sshutil

import (
	"fmt"
	"os"
	"path"
	"path/filepath"
	"sync"
	"time"

	"github.com/pkg/sftp"
)

// downloader copies a remote file or directory tree to the local file system
// through an SFTP client. Its Download and DownloadWithCallback methods match
// the IDownloader interface.
//
// Remote paths are joined with the slash-only "path" package (SFTP paths are
// POSIX-style); local paths use "path/filepath" so they follow the host OS.
type downloader struct {
	client       *sftp.Client
	downloadSize uint // total size in bytes of the current download, computed before copying starts
	downNumber   uint // number of regular files the current download will copy
	downloaded   uint // number of files copied so far in the current download

	mu      sync.Mutex // guards started: Cancel may run on another goroutine than download
	started bool       // true while download is running

	canceled chan struct{} // Cancel sends on it; downloadFile polls it after each file
}

// newDownloader returns a downloader that reads through client.
// The error result is currently always nil.
func newDownloader(client *sftp.Client) (*downloader, error) {
	return &downloader{client: client, canceled: make(chan struct{})}, nil
}

// Download copies from (a remote file or directory) to the local path to and
// blocks until the copy has finished or failed.
func (d *downloader) Download(from, to string) error {
	return d.download(from, to, nil, nil)
}

// DownloadWithCallback copies from to to, reporting each copied file through
// processCallback and the final result through finishedCallback (either may be
// nil). When background is true the download runs on a new goroutine, the
// call returns nil immediately, and the result is only available through
// finishedCallback.
func (d *downloader) DownloadWithCallback(from, to string, processCallback func(from, to string, downloadNumber, downloaded uint), finishedCallback func(err error), background bool) error {
	if background {
		go d.download(from, to, processCallback, finishedCallback)
		return nil
	}
	return d.download(from, to, processCallback, finishedCallback)
}

// Cancel asks a running download to stop. The request is only picked up after
// the file currently being copied has completed; if it is not picked up within
// two seconds, Cancel gives up and returns an error. When no download is
// running, Cancel does nothing.
func (d *downloader) Cancel() error {
	if !d.isStarted() {
		return nil
	}
	select {
	case d.canceled <- struct{}{}:
		return nil
	case <-time.After(2 * time.Second):
		// Cancellation took too long; report failure.
		return fmt.Errorf("time out waiting for cancel")
	}
}

// Destroy cancels any running download and closes the cancel channel.
// The downloader must not be reused afterwards: with the channel closed, every
// later download reports "user canceled" after its first file, and calling
// Destroy a second time panics on the double close.
func (d *downloader) Destroy() error {
	err := d.Cancel()
	close(d.canceled)
	return err
}

// setStarted records whether a download is in progress.
func (d *downloader) setStarted(v bool) {
	d.mu.Lock()
	d.started = v
	d.mu.Unlock()
}

// isStarted reports whether a download is in progress.
func (d *downloader) isStarted() bool {
	d.mu.Lock()
	defer d.mu.Unlock()
	return d.started
}

// notifyDownloadFinished hands err to finishedCallback (when non-nil) on a new
// goroutine and returns err unchanged, so callers can `return` its result.
func notifyDownloadFinished(finishedCallback func(err error), err error) error {
	if finishedCallback != nil {
		go finishedCallback(err)
	}
	return err
}

// downloadFolderCount walks the remote directory remotePath recursively and
// returns the number of non-directory entries and the sum of their sizes.
// Any ReadDir failure is returned instead of being silently skipped.
func (d *downloader) downloadFolderCount(remotePath string) (needDownload, size uint, err error) {
	infos, err := d.client.ReadDir(remotePath)
	if err != nil {
		return 0, 0, err
	}
	for _, info := range infos {
		if info.IsDir() {
			c, s, subErr := d.downloadFolderCount(path.Join(remotePath, info.Name()))
			if subErr != nil {
				return 0, 0, subErr
			}
			needDownload += c
			size += s
			continue
		}
		needDownload++
		size += uint(info.Size())
	}
	return needDownload, size, nil
}

// downloadFileCount returns how many files and bytes downloading remotePath
// involves: 1 and its size for a file, or the recursive totals for a directory.
func (d *downloader) downloadFileCount(remotePath string) (needDownload, size uint, err error) {
	info, err := d.client.Stat(remotePath)
	if err != nil {
		return 0, 0, err
	}
	if info.IsDir() {
		return d.downloadFolderCount(remotePath)
	}
	return 1, uint(info.Size()), nil
}

// download is the shared implementation of Download and DownloadWithCallback.
// It computes the totals for progress reporting, ensures localPath exists, and
// then copies either the single file (to localPath/<base name>) or the
// directory's contents (directly into localPath). finishedCallback, when
// non-nil, receives the final result exactly once.
func (d *downloader) download(remotePath, localPath string, processCallback func(from, to string, downloadNumber, downloaded uint), finishedCallback func(err error)) (err error) {
	d.setStarted(true)
	defer d.setStarted(false)

	// Progress counters are per download; reset so a reused downloader
	// never reports more files downloaded than downNumber.
	d.downloaded = 0
	d.downNumber, d.downloadSize, err = d.downloadFileCount(remotePath)
	if err != nil {
		return notifyDownloadFinished(finishedCallback, err)
	}
	// MkdirAll already returns nil when localPath is an existing directory.
	if err = os.MkdirAll(localPath, 0777); err != nil {
		return notifyDownloadFinished(finishedCallback, err)
	}
	info, err := d.client.Stat(remotePath)
	if err != nil {
		return notifyDownloadFinished(finishedCallback, err)
	}
	if info.IsDir() {
		return d.downloadFolder(remotePath, localPath, processCallback, finishedCallback)
	}
	return d.downloadFile(remotePath, localPath, processCallback, finishedCallback)
}

// downloadFile copies the single remote file remotePath to
// localPath/<base name of remotePath>, overwriting any existing local file.
// After the copy it checks for a pending Cancel, advances the progress counter
// and fires processCallback on a new goroutine. Empty files go through the same
// progress and cancel handling as non-empty ones.
func (d *downloader) downloadFile(remotePath, localPath string, processCallback func(from, to string, downloadNumber, downloaded uint), finishedCallback func(err error)) (err error) {
	localFileName := filepath.Join(localPath, path.Base(remotePath))

	info, err := d.client.Stat(remotePath)
	if err != nil {
		return notifyDownloadFinished(finishedCallback, err)
	}
	if info.Size() <= 0 {
		// Per the original author's note, opening a 0-byte file over sftp
		// blocks indefinitely, so empty files are just created locally.
		err = createEmptyLocalFile(localFileName)
	} else {
		err = d.copyRemoteFile(remotePath, localFileName)
	}
	if err != nil {
		return notifyDownloadFinished(finishedCallback, err)
	}

	select {
	case <-d.canceled:
		return notifyDownloadFinished(finishedCallback, fmt.Errorf("user canceled"))
	default:
	}

	d.downloaded++
	if processCallback != nil {
		go processCallback(remotePath, localFileName, d.downNumber, d.downloaded)
	}
	return notifyDownloadFinished(finishedCallback, nil)
}

// copyRemoteFile streams remotePath into localFileName (created or truncated).
// The local file's Close error is returned so that it is not silently dropped.
func (d *downloader) copyRemoteFile(remotePath, localFileName string) error {
	src, err := d.client.Open(remotePath)
	if err != nil {
		return err
	}
	defer src.Close()

	dst, err := os.OpenFile(localFileName, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0777)
	if err != nil {
		return err
	}
	if _, err := src.WriteTo(dst); err != nil {
		dst.Close()
		return err
	}
	return dst.Close()
}

// createEmptyLocalFile creates (or truncates) name as an empty file.
func createEmptyLocalFile(name string) error {
	f, err := os.OpenFile(name, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0777)
	if err != nil {
		return err
	}
	return f.Close()
}

// downloadFolder recursively copies the contents of the remote directory
// remotePath into localPath, creating local directories as needed. It stops at
// the first error. Nested calls pass a nil finishedCallback, so only the
// outermost call notifies, and it does so on every exit path.
func (d *downloader) downloadFolder(remotePath, localPath string, processCallback func(from, to string, downloadNumber, downloaded uint), finishedCallback func(err error)) (err error) {
	if err = os.MkdirAll(localPath, 0777); err != nil {
		return notifyDownloadFinished(finishedCallback, err)
	}
	infos, err := d.client.ReadDir(remotePath)
	if err != nil {
		return notifyDownloadFinished(finishedCallback, err)
	}
	for _, info := range infos {
		remoteFilePath := path.Join(remotePath, info.Name())
		if info.IsDir() {
			err = d.downloadFolder(remoteFilePath, filepath.Join(localPath, info.Name()), processCallback, nil)
		} else {
			err = d.downloadFile(remoteFilePath, localPath, processCallback, nil)
		}
		if err != nil {
			// Every failure must reach finishedCallback: async callers wait on it.
			return notifyDownloadFinished(finishedCallback, err)
		}
	}
	return notifyDownloadFinished(finishedCallback, nil)
}
package sshutil

import (
	"fmt"
	"os"
	"sync"
	"testing"
	"time"

	"github.com/pkg/sftp"
	"golang.org/x/crypto/ssh"
)

// downloaderTest bundles the SSH and SFTP connections used by the downloader
// integration tests together with the downloader under test.
type downloaderTest struct {
	sshClient  *ssh.Client
	sftpClient *sftp.Client
	downloader *downloader
}

// getenvDefault returns the value of the environment variable key, or def
// when it is unset or empty.
func getenvDefault(key, def string) string {
	if v := os.Getenv(key); v != "" {
		return v
	}
	return def
}

// newDownloadTest connects to the test device and creates a downloader on top
// of the SFTP session. It connects as user "pi" with password "a123456" to
// 192.168.3.2:22 by default. SSHUTIL_TEST_USER, SSHUTIL_TEST_PASSWORD and
// SSHUTIL_TEST_ADDR override those values. On failure every connection opened
// so far is closed.
func newDownloadTest() (*downloaderTest, error) {
	dTest := &downloaderTest{}
	config := ssh.ClientConfig{
		User:            getenvDefault("SSHUTIL_TEST_USER", "pi"),
		Auth:            []ssh.AuthMethod{ssh.Password(getenvDefault("SSHUTIL_TEST_PASSWORD", "a123456"))},
		HostKeyCallback: ssh.InsecureIgnoreHostKey(), // test-only: accepts any host key
		Timeout:         10 * time.Second,
	}

	var err error
	if dTest.sshClient, err = ssh.Dial("tcp", getenvDefault("SSHUTIL_TEST_ADDR", "192.168.3.2:22"), &config); err != nil {
		return nil, err
	}
	if dTest.sftpClient, err = sftp.NewClient(dTest.sshClient); err != nil {
		dTest.destroy()
		return nil, err
	}
	if dTest.downloader, err = newDownloader(dTest.sftpClient); err != nil {
		dTest.destroy()
		return nil, err
	}
	return dTest, nil
}

// destroy closes the SFTP and SSH clients (in that order) if they are open.
func (d *downloaderTest) destroy() {
	if nil != d.sftpClient {
		d.sftpClient.Close()
		d.sftpClient = nil
	}
	if nil != d.sshClient {
		d.sshClient.Close()
		d.sshClient = nil
	}
}

// openDownloadTest returns a connected downloaderTest. It skips t when the
// device is unreachable, so the test shows up as skipped rather than falsely
// passing.
func openDownloadTest(t *testing.T) *downloaderTest {
	t.Helper()
	dTest, err := newDownloadTest()
	if err != nil {
		t.Skipf("fail to new download test: %v", err)
	}
	return dTest
}

func TestDownloader_Download(t *testing.T) {
	dTest := openDownloadTest(t)
	defer dTest.destroy()

	if err := dTest.downloader.Download("/home/pi/", "./download"); err != nil {
		t.Error(err)
	}
}

func TestDownloader_DownloadWithCallback(t *testing.T) {
	dTest := openDownloadTest(t)
	defer dTest.destroy()

	err := dTest.downloader.DownloadWithCallback("/home/pi/", "./download1",
		func(from, to string, downloadNumber, downloaded uint) {
			fmt.Println(from, to, downloadNumber, downloaded)
		},
		func(err error) {
			fmt.Println("finished!!!")
		}, false)
	if err != nil {
		t.Error(err)
	}
	fmt.Println("sleping...")
	// The callbacks are invoked on their own goroutines inside download;
	// give their prints time to run before the test returns.
	time.Sleep(time.Second * 1)
}

func TestDownloader_DownloadWithCallbackAsync(t *testing.T) {
	dTest := openDownloadTest(t)
	defer dTest.destroy()

	var (
		waiter      sync.WaitGroup
		finishedErr error
	)
	waiter.Add(1)
	err := dTest.downloader.DownloadWithCallback("/home/pi/", "./download2/",
		func(from, to string, downloadNumber, downloaded uint) {
			fmt.Println(from, to, downloadNumber, downloaded)
		},
		func(err error) {
			fmt.Println("finished!!!")
			finishedErr = err // read only after waiter.Wait, which orders this write
			waiter.Done()
		}, true)
	if err != nil {
		t.Fatal(err)
	}
	fmt.Println("waiting....")
	waiter.Wait()
	fmt.Println("done!!!")
	if finishedErr != nil {
		t.Error(finishedErr)
	}
	// processCallback goroutines may still be printing; give them time.
	time.Sleep(time.Second * 1)
}
https://gitee.com/grayhsu/ssh_remote_access
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。