Files
qsdk/schema/scopes.go
2026-05-01 11:03:19 +08:00

108 lines
2.9 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package schema
import "gorm.io/gorm"
// ScopeTsCode returns a GORM scope that filters rows by an exact ts_code match.
// An empty tsCode leaves the query unchanged.
func ScopeTsCode(tsCode string) func(*gorm.DB) *gorm.DB {
	return func(tx *gorm.DB) *gorm.DB {
		if tsCode != "" {
			tx = tx.Where("ts_code = ?", tsCode)
		}
		return tx
	}
}
// ScopeTsCodes returns a GORM scope that restricts ts_code to the given set
// (ts_code IN (...)). A nil or empty slice leaves the query unchanged.
func ScopeTsCodes(tsCodes []string) func(*gorm.DB) *gorm.DB {
	return func(tx *gorm.DB) *gorm.DB {
		if len(tsCodes) > 0 {
			tx = tx.Where("ts_code IN ?", tsCodes)
		}
		return tx
	}
}
// ScopeTradeDateEQ returns a GORM scope that matches trade_date exactly.
// A tradeDate of 0 leaves the query unchanged.
func ScopeTradeDateEQ(tradeDate int) func(*gorm.DB) *gorm.DB {
	return func(tx *gorm.DB) *gorm.DB {
		if tradeDate != 0 {
			tx = tx.Where("trade_date = ?", tradeDate)
		}
		return tx
	}
}
// ScopeTradeDateBetween returns a GORM scope that limits trade_date to the
// closed range [start, end]. Only positive bounds are honored: with one
// positive bound a one-sided filter is applied, and with neither positive the
// query is left unchanged.
func ScopeTradeDateBetween(start, end int) func(*gorm.DB) *gorm.DB {
	return func(tx *gorm.DB) *gorm.DB {
		hasStart, hasEnd := start > 0, end > 0
		if hasStart && hasEnd {
			return tx.Where("trade_date BETWEEN ? AND ?", start, end)
		}
		if hasStart {
			return tx.Where("trade_date >= ?", start)
		}
		if hasEnd {
			return tx.Where("trade_date <= ?", end)
		}
		return tx
	}
}
// ScopeStockDailyTsDate returns a GORM scope for daily-line queries filtered
// by ts_code and trade_date. Each filter is skipped when its value is empty
// (tsCode == "" or tradeDate == 0).
func ScopeStockDailyTsDate(tsCode string, tradeDate int) func(*gorm.DB) *gorm.DB {
	return func(tx *gorm.DB) *gorm.DB {
		// Apply the component scopes directly instead of via tx.Scopes: the
		// conditions are then added immediately rather than being queued for
		// GORM's deferred, nested scope evaluation.
		tx = ScopeTsCode(tsCode)(tx)
		return ScopeTradeDateEQ(tradeDate)(tx)
	}
}
// ScopeStockIndicatorTsDate returns a GORM scope for indicator-table queries
// filtered by ts_code and trade_date. Each filter is skipped when its value is
// empty (tsCode == "" or tradeDate == 0).
func ScopeStockIndicatorTsDate(tsCode string, tradeDate int) func(*gorm.DB) *gorm.DB {
	return func(tx *gorm.DB) *gorm.DB {
		// Compose the component scopes directly rather than nesting a
		// tx.Scopes call, so the conditions do not depend on GORM
		// re-running scopes queued during scope evaluation.
		tx = ScopeTsCode(tsCode)(tx)
		return ScopeTradeDateEQ(tradeDate)(tx)
	}
}
// ScopeFinaTsPeriod returns a GORM scope for financial-indicator queries
// filtered by ts_code and period, the same column pair as the un_fi_code_date
// unique index. An empty tsCode or a period of 0 skips the corresponding filter.
func ScopeFinaTsPeriod(tsCode string, period int) func(*gorm.DB) *gorm.DB {
	return func(tx *gorm.DB) *gorm.DB {
		// Apply the ts_code scope directly rather than via tx.Scopes, so the
		// condition is added now instead of being deferred to GORM's nested
		// scope evaluation.
		tx = ScopeTsCode(tsCode)(tx)
		if period != 0 {
			tx = tx.Where("period = ?", period)
		}
		return tx
	}
}
// ScopeBlocksIndexCode returns a GORM scope that filters board (sector) rows
// by code. An empty code leaves the query unchanged.
func ScopeBlocksIndexCode(code string) func(*gorm.DB) *gorm.DB {
	return func(tx *gorm.DB) *gorm.DB {
		if code != "" {
			tx = tx.Where("code = ?", code)
		}
		return tx
	}
}
// ScopeBlocksMemberPair returns a GORM scope that filters board-membership rows
// by the board's ti_code and the member's stock_code. Either filter is
// skipped when its value is empty.
func ScopeBlocksMemberPair(tiCode, stockCode string) func(*gorm.DB) *gorm.DB {
	return func(tx *gorm.DB) *gorm.DB {
		// Conditions are applied in this fixed order: ti_code, then stock_code.
		filters := [...]struct {
			clause string
			value  string
		}{
			{"ti_code = ?", tiCode},
			{"stock_code = ?", stockCode},
		}
		for _, f := range filters {
			if f.value == "" {
				continue
			}
			tx = tx.Where(f.clause, f.value)
		}
		return tx
	}
}
// ScopeMoneyTotalCode returns a GORM scope that filters capital-flow rows by
// code. An empty code leaves the query unchanged.
func ScopeMoneyTotalCode(code string) func(*gorm.DB) *gorm.DB {
	return func(tx *gorm.DB) *gorm.DB {
		switch code {
		case "":
			return tx
		default:
			return tx.Where("code = ?", code)
		}
	}
}