由 SuKai March 21, 2022
在前面文章中介绍了golang开发rpc命令行工具,今天继续后续功能的介绍。
场景描述:
1,客户端运行一个daemon程序,执行文件上传任务,使用boltdb数据库记录任务执行状态。支持继续上传未完成文件。
2,客户端命令行通过rpc调用客户端daemon程序API,进行文件上传任务管理,包括任务创建,查看,启停。
通过这个功能开发,我们将golang并发编程的goroutine协程,channel通道,context上下文, Mutex互斥锁, WaitGroup协程同步等技术得到应用。
基本知识:
WaiGroup
等待一组协程执行完成后继续向下执行,WaitGroup内部有一个计数器,从0开始计数,有3个方法:Add(),Done(), Wait()。Add()添加计数,Done()减掉一个计数,Wait()执行阻塞,直到WaitGroup数量变成0。
Select
select和channel配合使用,通过select可以监听多个channel的I/O读写事件。
select {
case <-ctx.Done():
p.log.Error("the context error \n")
return context.Canceled
default:
}
如果没有default分支,select会阻塞在多个channel上,对多个channel进行监控。如果有default分支,多个channel都没有满足,则执行default分支。
Context
Golang的Context称之为上下文,用来跟踪goroutine关系链,传递通知,达到控制他们的目的。主要用法是,传递取消信号,传递数据。
下面是一个传递取消信号的使用过程,首先context.Background()返回一个空的Context,一般用于整个context tree的根节点。context.WithCancel()返回一个ctx可取消的Sub Context,作为Run的参数传入goroutine,这样可以使用ctx跟踪这个Goroutine。cancel()调用Sub Context的取消函数,向关联的Goroutine发送一个"取消"通知。在Run函数中,接收ctx.Done的cancel通知,做相关清理后退出。
ctx, cancel := context.WithCancel(context.Background())
go pool.Run(ctx)
cancel()
for {
select {
case <-ctx.Done():
return
}
}
Mutext互斥锁
Mutext互斥锁在同一时间只被一个goroutine访问,不区分读写。有两个方法:Lock()和Unlock()。当一个goroutine申请了Lock(),那么另一个goroutine申请Lock()时会阻塞等待直到Unlock()释放锁。
multipart/form请求
multipart/form请求是http Post方法,可以发送文件和消息,在请求的Header中包含一个特殊头信息Content-Type: multipart/form-data; boundary=,boundary的值为随机计算生成的值,用于分隔上传多个form-data的间隔。
POST /raw/v1alpha2/rawdatas/6214cb78ff3f2903536f5751/images HTTP/1.1
Host: 127.0.0.1:9090
Content-Length: 546
Content-Type: multipart/form-data; boundary=----WebKitFormBoundary7MA4YWxkTrZu0gW
----WebKitFormBoundary7MA4YWxkTrZu0gW
Content-Disposition: form-data; name="files"; filename="/C:/Users/ycsk0/Pictures/ubuntu-install.png"
Content-Type: image/png
(data)
----WebKitFormBoundary7MA4YWxkTrZu0gW
Content-Disposition: form-data; name="files"; filename="/C:/Users/ycsk0/Pictures/windows store.png"
Content-Type: image/png
(data)
----WebKitFormBoundary7MA4YWxkTrZu0gW
Content-Disposition: form-data; name="files"; filename="/C:/Users/ycsk0/Pictures/wsl-ubuntu.png"
Content-Type: image/png
(data)
----WebKitFormBoundary7MA4YWxkTrZu0gW
工作池
worker pool就是线程池thread pool,在go中对应的就是goroutine协程。在线程池模型中,包括:任务队列,已完成任务队列和线程池。完成任务队列,根据实际情况判断是否需要。
代码
当命令行执行daemon命令时,执行server方法,启动daemon进程。
可以看到:
1,context.Background()作为协程树的根
2,context.WithCancel创建ctx一个可取消的sub context,
3,go pool.Run(ctx)开启一个协程,入参ctx。
4,监听进程信号sigs,关闭时执行cancel()取消函数,所有的goroutine都会同步收到这一取消信号。
func (sc *serverCmd) server(cmd *cobra.Command, args []string) error {
sigs := make(chan os.Signal, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
logger, _ := zap.NewProduction()
serializer := boltdb.JSONSerializer{}
boltdbClient, _ := boltdb.NewClient(sc.database, serializer)
taskStore := store.NewTaskStore(boltdbClient, serializer)
imageStore := store.NewImageStore(boltdbClient, serializer)
pool := grpcserver.NewPool(5, taskStore, imageStore, logger, sc.serverURL)
ctx, cancel := context.WithCancel(context.Background())
go pool.Run(ctx)
server := grpcserver.NewRPCServer(taskStore, imageStore, sc.baseBuilderCmd.port, logger, pool)
go server.Start()
logger.Info("the ailabel daemon is started")
<-sigs
cancel()
pool.Close()
server.Stop()
boltdbClient.Close()
logger.Info("the ailabel daemon is closed")
return nil
}
启动线程池
Pool中包含:mutex互斥锁,tasks任务队列,WaitGroup用于等待所有线程池结束。taskQueue记录线程池中的任务。stopTaskID保存停止的任务。
在Run方法中,开启定时器,每10秒发送一个信号。开启workersCount数量的线程池,每开启一个waitGroup计数器加1。监听ctx.Done()取消信号,监听定时器信号,定时查询boltdb中的任务执行分发任务。
type Pool struct {
sync.Mutex
workersCount int
wg sync.WaitGroup
tasks chan store.Task
stopCmd chan StopCmd
results chan Result
taskQueue []uint64
stopTaskID []uint64
log *zap.Logger
taskStore *store.TaskStore
imageStore *store.ImageStore
aiLabelClient *ailabel.Ailabel
}
func NewPool(workCount int, taskStore *store.TaskStore, imageStore *store.ImageStore, logger *zap.Logger, serverURL string) *Pool {
return &Pool{
workersCount: workCount,
tasks: make(chan store.Task, workCount),
stopCmd: make(chan StopCmd, workCount),
results: make(chan Result, workCount),
log: logger,
taskStore: taskStore,
imageStore: imageStore,
aiLabelClient: ailabel.CreateAilabelClient(nil, serverURL, 10),
}
}
func (p *Pool) Run(ctx context.Context) {
defer p.log.Info("Stopping worker pool")
ticker := time.NewTicker(10 * time.Second)
defer ticker.Stop()
for i := 0; i < p.workersCount; i++ {
p.wg.Add(1)
go p.worker(ctx)
}
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
tasks, err := p.taskStore.ListPendingInprogressTasks()
if err != nil {
p.log.Error("failed to list pending tasks")
continue
}
p.Dispatch(tasks)
}
}
}
分发任务
申请排斥锁,防止多次分发相同任务,task发送给pool的tasks缓冲channel,同时在taskQueue中记录任务。
func (p *Pool) Dispatch(tasks []store.Task) {
p.Lock()
defer p.Unlock()
for _, task := range tasks {
if sliceutil.HasUint64(p.taskQueue, task.ID) {
p.log.Info("task exists, ignore", zap.String("task", fmt.Sprintf("%d", task.ID)))
continue
} else if len(p.taskQueue) == p.workersCount {
p.log.Info("task limit size exceed")
return
}
p.taskQueue = append(p.taskQueue, task.ID)
p.tasks <- task
p.log.Info("task created successfully", zap.String("task", fmt.Sprintf("%d", task.ID)))
}
}
线程池
当线程退出时,执行p.wg.Done()减少waitGroup计数器数量。
这里监听四个channel:
1,p.stopCmd,当命令行停止任务时,在p.stopTaskID记录停止任务的任务ID
2,p.tasks,读取任务队列中一条任务进行处理。
3,p.results,读取任务返回的结果消息进行处理。这里申请了一个p.Lock()锁,保证前一个结果消息处理结束,再处理下一个结果消息,防止结果数据更新错误。
4,ctx.Done(),当根ctx执行取消,接收取消信号,当前线程能出。
func (p *Pool) worker(ctx context.Context) {
defer p.wg.Done()
for {
select {
case stop, ok := <-p.stopCmd:
if !ok {
return
}
p.log.Info("stop task", zap.String("task", fmt.Sprintf("%d", stop.TaskID)))
p.stopTaskID = append(p.stopTaskID, stop.TaskID)
case task, ok := <-p.tasks:
if !ok {
return
}
p.log.Info("recieve task", zap.String("task", fmt.Sprintf("%d", task.ID)))
if task.Status != store.TaskInProgress {
if task.Status == store.TaskLoading || task.Status == store.TaskStoppedLoading {
if err := p.ClearTaskData(ctx, &task); err != nil {
if err != nil {
p.log.Error("clear task data error", zap.String("task", fmt.Sprintf("%d", task.ID)), zap.Error(err))
continue
}
}
}
err := p.LoadTaskData(ctx, &task)
if err != nil {
p.log.Error("Task load error", zap.String("task", fmt.Sprintf("%d", task.ID)), zap.Error(err))
continue
}
}
err := p.ProcessTask(ctx, task.ID)
if err != nil {
p.log.Error("Task process error", zap.String("task", fmt.Sprintf("%d", task.ID)), zap.Error(err))
continue
}
case result, ok := <-p.results:
if !ok {
return
}
p.Lock()
task, err := p.taskStore.GetTask(result.TaskID)
if err != nil {
p.log.Error("read task failed", zap.String("task", fmt.Sprintf("%d", task.ID)))
}
if result.Err != nil {
task.Reason = result.Err.Error()
} else {
task.Reason = ""
}
task.IsPuase = task.Reason == store.TaskStoppedError
if result.Status > 0 {
task.Status = result.Status
} else if result.StartKey > 0 && result.EndKey > 0 {
task.StartKey, task.EndKey = result.StartKey, result.EndKey
}
err = p.taskStore.UpdateTask(task)
if err != nil {
p.log.Error("update task failed", zap.String("task", fmt.Sprintf("%d", task.ID)))
}
if task.Status != store.TaskLoading && task.Status != store.TaskInProgress {
p.taskQueue = sliceutil.RemoveUint64(p.taskQueue, func(item uint64) bool {
return item == task.ID
})
}
p.Unlock()
p.log.Info("recieved task result",
zap.String("task", fmt.Sprintf("%d", result.TaskID)),
zap.String("status", fmt.Sprintf("%d", result.Status)))
case <-ctx.Done():
return
}
}
}
任务处理
这里可以看到:
1,每一个业务处理过程中,都有一个判断ctx.Done(),是否协程被取消。取消则退出。
2,每一个业务处理过程中,都有一个判断sliceutil.HasUint64(p.stopTaskID, task.ID),任务是否在停止任务的slice中,有则退出。因为我们可以通过命令行暂停某一任务的执行,而在pool中无法知道某一任务是哪一个goroutine协程在执行的,所以这里通过查询并判断是否是自己在执行任务。
3,ProcessTask中会查询并等待task.StartKey有数据时再继续业务逻辑,因为LoadTaskData执行后发送到results channel消息处理会有时间延时,所以等待LoadTaskData results中的startkey更新后再继续。
4,在ProcessTask中,读取Boltdb中的图片数据时,因为通过startKey, endKey来进行seek偏移读取数据的,而在boltdb中,key保存的是字节码,所以我们根据ID读取会存在数据变多的情况。比如我们读取整型的749-751会返回749,75,750。
5,批量上传文件PostFiles
func (p *Pool) ProcessTask(ctx context.Context, taskID uint64) error {
var task *store.Task
var lastImage store.Image
var err error
for task, err = p.taskStore.GetTask(taskID); task.StartKey == 0; time.Sleep(time.Microsecond * 500) {
if err != nil {
return err
}
task, err = p.taskStore.GetTask(taskID)
}
p.results <- Result{
TaskID: task.ID,
Status: store.TaskInProgress,
Err: nil,
}
// the boltdb key is seeked by bytes order, range 749 - 751 will return 749 75 750
for startKey, endKey := task.StartKey, task.EndKey; startKey < endKey; startKey = lastImage.ID + 1 {
select {
case <-ctx.Done():
p.log.Error("the context error \n")
return context.Canceled
default:
}
if sliceutil.HasUint64(p.stopTaskID, task.ID) {
return errors.New(store.TaskStoppedError)
}
images, err := p.imageStore.ListBatchPendingImagesByTaskID(task.ID, startKey, endKey)
if err != nil || len(images) == 0 {
return err
}
lastImage = images[len(images)-1]
var params = make(map[string]string)
buildParams := map[string]string{}
buildParams["json"] = string(makeJson(params))
b, _ := json.Marshal(buildParams)
payload := bytes.NewBuffer(b)
_, err = p.aiLabelClient.Requester.PostFiles(fmt.Sprintf("/raw/v1alpha2/rawdatas/%s/images", task.RawDataID), payload, nil, nil, images)
if err != nil {
return err
}
p.SetImageStatus(images, store.ImageUploaded)
}
p.results <- Result{
TaskID: task.ID,
Status: store.TaskFinished,
Err: nil,
}
return nil
}
func (p *Pool) ClearTaskData(ctx context.Context, task *store.Task) error {
return p.imageStore.DeleteImagesByTaskID(task.ID)
}
func (p *Pool) LoadTaskData(ctx context.Context, task *store.Task) error {
var startKey, endKey uint64
var walkFun filepath.WalkFunc = func(f string, info os.FileInfo, err error) error {
select {
case <-ctx.Done():
p.log.Error("the context error \n")
return context.Canceled
default:
}
if sliceutil.HasUint64(p.stopTaskID, task.ID) {
return errors.New(store.TaskStoppedError)
}
if !info.IsDir() {
var format archiver.Extractor
var sourceAchive *os.File
switch filepath.Ext(f) {
case ".rar":
format = archiver.Rar{}
case ".zip":
format = archiver.Zip{}
case ".png", ".jpeg", ".jpg", ".pcd":
var key uint64
image := &store.Image{
TaskID: task.ID,
FileName: f,
IsAchive: false,
Status: store.ImagePending,
}
if key, err = p.imageStore.CreateImage(image); err != nil {
p.log.Error("create image failed", zap.String("image", fmt.Sprintf("%s", image.FileName)))
return err
}
if startKey == 0 {
startKey = key
}
endKey = key
default:
p.log.Info("unknown file extension", zap.String("image", fmt.Sprintf("%s", f)))
}
if format != nil {
sourceAchive, err = os.Open(f)
err = format.Extract(context.Background(), sourceAchive, nil, func(ctx context.Context, f archiver.File) error {
if !f.IsDir() {
if sliceutil.HasString([]string{".png", ".jpeg", ".jpg", ".pcd"}, filepath.Ext(f.Name())) {
var key uint64
image := &store.Image{
TaskID: task.ID,
FileName: f.NameInArchive,
IsAchive: true,
AchiveFileName: sourceAchive.Name(),
Status: store.ImagePending,
}
if key, err = p.imageStore.CreateImage(image); err != nil {
p.log.Error("create image failed", zap.String("image", fmt.Sprintf("%s", image.FileName)))
return err
}
if startKey == 0 {
startKey = key
}
endKey = key
}
}
return nil
})
if err != nil {
return err
}
sourceAchive.Close()
}
}
return nil
}
if _, err := os.Stat(task.Path); err != nil {
p.results <- Result{
TaskID: task.ID,
Status: store.TaskFailed,
Err: err,
}
return err
}
p.results <- Result{
TaskID: task.ID,
Status: store.TaskLoading,
Err: nil,
}
err := filepath.Walk(task.Path, walkFun)
if err != nil {
if err == context.Canceled {
p.results <- Result{
TaskID: task.ID,
Err: err,
}
p.taskQueue = sliceutil.RemoveUint64(p.taskQueue, func(item uint64) bool {
return item == task.ID
})
return err
}
if err.Error() == errors.New(store.TaskStoppedError).Error() {
p.results <- Result{
TaskID: task.ID,
Status: store.TaskStoppedLoading,
Err: err,
}
p.stopTaskID = sliceutil.RemoveUint64(p.stopTaskID, func(item uint64) bool {
return item == task.ID
})
return err
}
p.results <- Result{
TaskID: task.ID,
Status: store.TaskFailed,
Err: err,
}
return err
}
p.results <- Result{
TaskID: task.ID,
StartKey: startKey,
EndKey: endKey + 1,
Err: nil,
}
return nil
}
批量上传
使用mime/multipart库上传文件的基本过程
1,创建http client,ai.Requester.Client = http.DefaultClient
2,生成请求体body,通过multipart.NewWriter创建一个multipart写接口,将formdata数据和文件写入到body缓冲区中。CreateFormFile在字段名为"file"字段中添加一个文件,io.copy将源文件数据写入到CreateFormFile的文件中。writer.WriteField将k/v写入上传数据中。
2,创建http request,req, err = http.NewRequest(ar.Method, URL.String(), body)
3,添加request头信息,req.Header.Set(“Content-Type”, writer.FormDataContentType()),writer会自动生成文件之间的分隔字符。
4,发送请求,response, err := r.Client.Do(req)
func CreateAilabelClient(client *http.Client, base string, maxConnection int, auth ...interface{}) *Ailabel {
ai := &Ailabel{}
if strings.HasSuffix(base, "/") {
base = base[:len(base)-1]
}
ai.Server = base
ai.Requester = &Requester{Base: base, SslVerify: true, Client: client, connControl: make(chan struct{}, maxConnection)}
if ai.Requester.Client == nil {
ai.Requester.Client = http.DefaultClient
}
if len(auth) == 2 {
ai.Requester.BasicAuth = &BasicAuth{Username: auth[0].(string), Password: auth[1].(string)}
}
return ai
}
func (r *Requester) PostFiles(endpoint string, payload io.Reader, responseStruct interface{}, querystring map[string]string, files interface{}) (*http.Response, error) {
ar := NewAPIRequest("POST", endpoint, payload)
return r.Do(ar, responseStruct, querystring, files)
}
func (r *Requester) Do(ar *APIRequest, responseStruct interface{}, options ...interface{}) (*http.Response, error) {
if !strings.HasSuffix(ar.Endpoint, "/") && ar.Method != "POST" {
ar.Endpoint += "/"
}
fileUpload := false
archiveUpload := false
var files []string
var archiveFiles = make(map[string][]string)
URL, err := url.Parse(r.Base + ar.Endpoint + ar.Suffix)
if err != nil {
return nil, err
}
for _, o := range options {
switch v := o.(type) {
case map[string]string:
querystring := make(url.Values)
for key, val := range v {
querystring.Set(key, val)
}
URL.RawQuery = querystring.Encode()
case []string:
fileUpload = true
files = v
case []store.Image:
for _, image := range v {
if image.IsAchive {
archiveUpload = true
archiveFiles[image.AchiveFileName] = append(archiveFiles[image.AchiveFileName], image.FileName)
} else {
fileUpload = true
files = append(files, image.FileName)
}
}
}
}
var req *http.Request
if fileUpload || archiveUpload {
body := &bytes.Buffer{}
writer := multipart.NewWriter(body)
for _, file := range files {
fileData, err := os.Open(file)
if err != nil {
return nil, err
}
part, err := writer.CreateFormFile("file", filepath.Base(file))
if err != nil {
return nil, err
}
if _, err = io.Copy(part, fileData); err != nil {
return nil, err
}
defer fileData.Close()
}
for archiveFile, images := range archiveFiles {
var format archiver.Extractor
var sourceAchive *os.File
switch filepath.Ext(archiveFile) {
case ".rar":
format = archiver.Rar{}
case ".zip":
format = archiver.Zip{}
}
sourceAchive, err = os.Open(archiveFile)
if err != nil {
return nil, err
}
err = format.Extract(context.Background(), sourceAchive, images, func(ctx context.Context, f archiver.File) error {
part, err := writer.CreateFormFile("file", filepath.Base(f.Name()))
if err != nil {
return err
}
fileData, err := f.Open()
if err != nil {
return err
}
if _, err = io.Copy(part, fileData); err != nil {
return err
}
defer fileData.Close()
return nil
})
if err != nil {
return nil, err
}
}
var params map[string]string
json.NewDecoder(ar.Payload).Decode(¶ms)
for key, val := range params {
if err = writer.WriteField(key, val); err != nil {
return nil, err
}
}
if err = writer.Close(); err != nil {
return nil, err
}
req, err = http.NewRequest(ar.Method, URL.String(), body)
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", writer.FormDataContentType())
} else {
req, err = http.NewRequest(ar.Method, URL.String(), ar.Payload)
if err != nil {
return nil, err
}
}
if r.BasicAuth != nil {
req.SetBasicAuth(r.BasicAuth.Username, r.BasicAuth.Password)
}
req.Close = true
req.Header.Add("Accept", "*/*")
for k := range ar.Headers {
req.Header.Add(k, ar.Headers.Get(k))
}
r.connControl <- struct{}{}
if response, err := r.Client.Do(req); err != nil {
<-r.connControl
return nil, err
} else {
<-r.connControl
errorText := response.Header.Get("X-Error")
if errorText != "" {
return nil, errors.New(errorText)
}
err := CheckResponse(response)
if err != nil {
return nil, err
}
switch responseStruct.(type) {
case *string:
return r.ReadRawResponse(response, responseStruct)
default:
return r.ReadJSONResponse(response, responseStruct)
}
}
}
执行
PS G:\ai\ailabel> go run .\cmd\client\client.go task create --path G:\ailabel --rawDataID 6214dd1bb690a882006cb9d1 -p 8080
PS G:\ai\ailabel> go run .\cmd\client\client.go task list -p 8080
ID RAWDATAID PATH STATUS REASON ISPUASE STARTKEY ENDKEY
1 6214dd1bb690a882006cb9d1 G:\ailabel 3 false 1 13
PS G:\ai\ailabel> go run .\cmd\client\client.go task create --path G:\ailabel --rawDataID 6214dd1bb690a882006cb9d1 -p 8080
PS G:\ai\ailabel> go run .\cmd\client\client.go task list -p 8080
ID RAWDATAID PATH STATUS REASON ISPUASE STARTKEY ENDKEY
1 6214dd1bb690a882006cb9d1 G:\ailabel 3 false 1 13
2 6214dd1bb690a882006cb9d1 G:\ailabel 3 false 13 25
PS G:\ai\ailabel>
PS G:\ai\ailabel> go run .\cmd\client\client.go image list -p 8080
ID TASKID FILENAME ISACHIVE ACHIVEFILENAME STATUS
1 1 G:\ailabel\apt.png false 1
2 1 G:\ailabel\test\apt.png false 0
3 1 G:\ailabel\test\ubuntu-install.png false 0
4 1 G:\ailabel\test\windows store.png false 0
5 1 G:\ailabel\test\wsl-ubuntu.png false 0
6 1 test/apt.png true G:\ailabel\test.rar 0
7 1 test/ubuntu-install.png true G:\ailabel\test.rar 0
8 1 test/windows store.png true G:\ailabel\test.rar 0
9 1 test/wsl-ubuntu.png true G:\ailabel\test.rar 0
10 1 G:\ailabel\ubuntu-install.png false 1
11 1 G:\ailabel\windows store.png false 1
12 1 G:\ailabel\wsl-ubuntu.png false 1
13 2 G:\ailabel\apt.png false 1
14 2 G:\ailabel\test\apt.png false 1
15 2 G:\ailabel\test\ubuntu-install.png false 1
16 2 G:\ailabel\test\windows store.png false 1
17 2 G:\ailabel\test\wsl-ubuntu.png false 1
18 2 test/apt.png true G:\ailabel\test.rar 1
19 2 test/ubuntu-install.png true G:\ailabel\test.rar 1
20 2 test/windows store.png true G:\ailabel\test.rar 1
21 2 test/wsl-ubuntu.png true G:\ailabel\test.rar 1
22 2 G:\ailabel\ubuntu-install.png false 1
23 2 G:\ailabel\windows store.png false 1
24 2 G:\ailabel\wsl-ubuntu.png false 1