fix bug
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
package strategy
|
package strategy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
@@ -12,38 +13,54 @@ func Boot() {
|
|||||||
InitCacheByAll()
|
InitCacheByAll()
|
||||||
}
|
}
|
||||||
|
|
||||||
func BootAiStart(key string, ymd int) {
|
// 启动 AI 分析任务
|
||||||
|
func BootAiStart(key string, ymd int) error {
|
||||||
var datas []models.StratModel
|
var datas []models.StratModel
|
||||||
err := impl.DBService.Where("strat_key=? and ymd=? and ai_score=0", key, ymd).Find(&datas).Error
|
err := impl.DBService.Where("strat_key=? and ymd=? and ai_score=0", key, ymd).Find(&datas).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
log.Printf("Failed to query data: %v", err)
|
||||||
|
return fmt.Errorf("query failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a buffered channel with a capacity of 20 to act as a semaphore
|
// 构造任务列表
|
||||||
semaphore := make(chan struct{}, 20)
|
var tasks []func()
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
for _, row := range datas {
|
for _, row := range datas {
|
||||||
|
row := row // 避免闭包捕获循环变量
|
||||||
|
tasks = append(tasks, func() {
|
||||||
|
BootAiTask(row.ID, row.Code)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// 使用并发控制执行任务
|
||||||
|
runWithLimit(20, tasks)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 执行单个 AI 分析任务
|
||||||
|
func BootAiTask(id uint, code string) {
|
||||||
|
result, err := AiAnalysis(code)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("ERROR BootAiTask - ID: %d, Code: %s, Error: %v", id, code, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Printf("SUCCESS BootAiTask - ID: %d, Code: %s", id, code)
|
||||||
|
impl.DBService.Model(&models.StratModel{}).Where("id=?", id).Updates(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 通用并发控制函数
|
||||||
|
func runWithLimit(limit int, tasks []func()) {
|
||||||
|
semaphore := make(chan struct{}, limit)
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
|
for _, task := range tasks {
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
// Acquire a slot in the semaphore
|
semaphore <- struct{}{} // 获取信号量
|
||||||
semaphore <- struct{}{}
|
go func(t func()) {
|
||||||
go func(row models.StratModel) {
|
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
// Release the slot in the semaphore when done
|
defer func() { <-semaphore }() // 释放信号量
|
||||||
defer func() { <-semaphore }()
|
t()
|
||||||
BootAiTask(row.ID, row.Code, &wg)
|
}(task)
|
||||||
}(row)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
func BootAiTask(id uint, code string, wg *sync.WaitGroup) {
|
|
||||||
defer wg.Done()
|
|
||||||
result, err := AiAnalysis(code)
|
|
||||||
if err != nil {
|
|
||||||
log.Println("ERROR BootAiTask", err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
impl.DBService.Model(&models.StratModel{}).Where("id=?", id).Updates(result)
|
|
||||||
}
|
|
||||||
|
|||||||
Reference in New Issue
Block a user