增加命令功能

This commit is contained in:
2025-04-07 10:42:08 +08:00
parent 0344c09ce7
commit 74b0033e45
18 changed files with 1844 additions and 494 deletions

View File

@@ -2,27 +2,38 @@ package main
import (
"bufio"
"bytes"
"flag"
"fmt"
"os"
"path"
"runtime"
"strings"
"sync"
"e.coding.net/rta/public/saasapi"
"google.golang.org/protobuf/encoding/protojson"
)
// TODO 转换加速
const (
convertBatchSize = 100000
convertedExt = ".converted"
)
type convertParams struct {
targetCfg *TargetConfig
mapCfg *MapConfig
sourcePath string
destPath string
}
type convertResult struct {
resultBuf bytes.Buffer
convertedLines int
}
func RunConvert(args ...string) error {
fs := flag.NewFlagSet("convert", flag.ExitOnError)
targetCfgFile := paramTargets(fs)
mapCfgFile := paramMap(fs)
sourcePath := paramSourcePath(fs)
destPath := paramDestPath(fs)
@@ -31,19 +42,19 @@ func RunConvert(args ...string) error {
return err
}
if fs.NArg() > 0 || *targetCfgFile == "" || len(*sourcePath) == 0 || len(*destPath) == 0 {
if fs.NArg() > 0 || *mapCfgFile == "" || len(*sourcePath) == 0 || len(*destPath) == 0 {
fs.PrintDefaults()
return nil
}
targetCfg, err := LoadTargetFile(*targetCfgFile)
mapCfg, err := LoadMapFile(*mapCfgFile)
if err != nil {
fmt.Println("LoadConfigFile error", "err", err)
return err
}
convertParams := convertParams{
targetCfg: targetCfg,
mapCfg: mapCfg,
sourcePath: *sourcePath,
destPath: *destPath,
}
@@ -97,7 +108,7 @@ func doFileConvert(convertParams convertParams) error {
os.MkdirAll(convertParams.destPath, os.ModePerm)
}
destName := path.Join(convertParams.destPath, path.Base(convertParams.sourcePath)+".converted")
destName := path.Join(convertParams.destPath, path.Base(convertParams.sourcePath)+convertedExt)
destFile, err := os.Create(destName)
if err != nil {
return err
@@ -108,15 +119,92 @@ func doFileConvert(convertParams convertParams) error {
destWriter := bufio.NewWriter(destFile)
defer destWriter.Flush()
jasonMarshal := protojson.MarshalOptions{Multiline: false, Indent: ""}
// 启动处理协程
workers := []chan []string{}
results := []chan convertResult{}
processedLine := 0
wg := sync.WaitGroup{}
convertMaxWorkers := runtime.GOMAXPROCS(0)
for range convertMaxWorkers {
workerChan := make(chan []string)
workers = append(workers, workerChan)
resultChan := make(chan convertResult)
results = append(results, resultChan)
go func(workerChan <-chan []string, resultChan chan<- convertResult) {
for lines := range workerChan {
convertBatch(lines, convertParams, resultChan)
}
}(workerChan, resultChan)
}
// 启动写入协程
go func() {
i := 0
// TIP: 不要改成for range
for {
select {
case result, ok := <-results[i%convertMaxWorkers]:
if !ok {
return
}
destWriter.Write(result.resultBuf.Bytes())
destWriter.Flush()
processedLine += result.convertedLines
fmt.Printf("\rconverted records: %v [%v]", processedLine, destName)
i++
wg.Done()
}
}
}()
// 读取文件并塞给协程处理
batch := []string{}
batchCount := 0
for scaner.Scan() {
line := scaner.Text()
if line == "" {
continue
}
batch = append(batch, line)
if len(batch) == convertBatchSize {
// 将batch写入协程
wg.Add(1)
workers[batchCount%convertMaxWorkers] <- batch
batch = nil
batchCount++
}
}
if len(batch) > 0 {
wg.Add(1)
workers[batchCount%convertMaxWorkers] <- batch
}
wg.Wait()
// 关闭所有工作协程的通道
for _, workerChan := range workers {
close(workerChan)
}
for _, resultChan := range results {
close(resultChan)
}
fmt.Println("")
return nil
}
func convertBatch(lines []string, convertParams convertParams, resultChan chan<- convertResult) {
byteBuf := bytes.Buffer{}
byteBuf.Grow(1024 * 1024 * 10)
jasonMarshal := protojson.MarshalOptions{Multiline: false, Indent: ""}
for _, line := range lines {
// 按\t分割为两列
parts := strings.Split(line, "\t")
if len(parts) != 2 {
@@ -125,65 +213,63 @@ func doFileConvert(convertParams convertParams) error {
// 读取userid
userid := parts[0]
if len(userid) == 0 {
continue
}
value := parts[1]
value = strings.ReplaceAll(value, "[", "")
value = strings.ReplaceAll(value, "]", "")
// 第二列解析为string数组
targets := strings.Split(value, " ")
saasWriteCmd := &saasapi.WriteCmd{
saasWriteItem := &saasapi.WriteItem{
Userid: userid,
}
if len(userid) == 0 || len(targets) == 0 {
continue
}
// 遍历targets转换成saasapi.WriteCmd
for _, target := range targets {
if targetinfo, ok := convertParams.targetCfg.Targets[target]; ok {
if targetinfo, ok := convertParams.mapCfg.Targets[target]; ok {
if targetinfo.WriteByte != nil {
if saasWriteCmd.WriteBytes == nil {
saasWriteCmd.WriteBytes = &saasapi.Bytes{}
// 转换byte区
if saasWriteItem.WriteBytes == nil {
saasWriteItem.WriteBytes = &saasapi.Bytes{}
}
saasWriteCmd.WriteBytes.Bytes = append(saasWriteCmd.WriteBytes.Bytes, *targetinfo.WriteByte)
saasWriteItem.WriteBytes.Bytes = append(saasWriteItem.WriteBytes.Bytes, *targetinfo.WriteByte)
if targetinfo.WriteBytePos < 64 {
saasWriteCmd.WriteBytes.Index_1 |= 1 << targetinfo.WriteBytePos
saasWriteItem.WriteBytes.Index_1 |= 1 << targetinfo.WriteBytePos
} else if targetinfo.WriteBytePos < 128 {
saasWriteCmd.WriteBytes.Index_2 |= 1 << (targetinfo.WriteBytePos - 64)
saasWriteItem.WriteBytes.Index_2 |= 1 << (targetinfo.WriteBytePos - 64)
}
}
if targetinfo.WriteUint32 != nil {
if saasWriteCmd.WriteUint32S == nil {
saasWriteCmd.WriteUint32S = &saasapi.Uint32S{}
// 转换uint32区
if saasWriteItem.WriteUint32S == nil {
saasWriteItem.WriteUint32S = &saasapi.Uint32S{}
}
saasWriteCmd.WriteUint32S.Uint32S = append(saasWriteCmd.WriteUint32S.Uint32S, *targetinfo.WriteUint32)
saasWriteCmd.WriteUint32S.Index_1 |= 1 << targetinfo.WriteUint32Pos
saasWriteItem.WriteUint32S.Uint32S = append(saasWriteItem.WriteUint32S.Uint32S, *targetinfo.WriteUint32)
saasWriteItem.WriteUint32S.Index_1 |= 1 << targetinfo.WriteUint32Pos
}
if targetinfo.WriteFlag != nil && targetinfo.WriteExpire != nil {
if saasWriteCmd.WriteFlagsWithExpire == nil {
saasWriteCmd.WriteFlagsWithExpire = &saasapi.FlagsWithExpire{}
// 转换flag区
if saasWriteItem.WriteFlagsWithExpire == nil {
saasWriteItem.WriteFlagsWithExpire = &saasapi.FlagsWithExpire{}
}
saasWriteCmd.WriteFlagsWithExpire.FlagsWithExpire = append(
saasWriteCmd.WriteFlagsWithExpire.FlagsWithExpire, &saasapi.FlagWithExpire{
saasWriteItem.WriteFlagsWithExpire.FlagsWithExpire = append(
saasWriteItem.WriteFlagsWithExpire.FlagsWithExpire, &saasapi.FlagWithExpire{
Flag: *targetinfo.WriteFlag,
Expire: *targetinfo.WriteExpire,
})
saasWriteCmd.WriteFlagsWithExpire.Index_1 |= 1 << targetinfo.WriteFlagWithExpirePos
saasWriteItem.WriteFlagsWithExpire.Index_1 |= 1 << targetinfo.WriteFlagWithExpirePos
}
}
}
// 写入文件
destWriter.WriteString(jasonMarshal.Format(saasWriteCmd))
destWriter.WriteString("\n")
processedLine++
if processedLine%100000 == 0 {
fmt.Printf("\rconverted records: %v [%v]", processedLine, destName)
}
byteBuf.WriteString(jasonMarshal.Format(saasWriteItem))
byteBuf.WriteString("\n")
}
fmt.Printf("\rconverted records: %v [%v]\n", processedLine, destName)
return nil
resultChan <- convertResult{byteBuf, len(lines)}
}

View File

@@ -12,21 +12,19 @@ func RunHelp(args ...string) error {
}
const usage = `
Usage: [[command] [arguments]]
The commands are:
Usage: saastool COMMAND [OPTIONS]
Commands:
write Write user's 'bytes / uint32s / flags'
read Read user's 'bytes / uint32s / flags'
columnwrite Write columns for 'deviceid / openid' users
columnwrite Write columns for 'deviceid / openid' users
tasklist List tasks
taskcancel Cancel task
taskdetail Show task detail
convert Convert data to write format
makehash Make file hash for upload task
task Task commands
"help" is the default command.
Use "saastool [command] -help" for more information about a command.
Use "saastool COMMAND -help" for more information about a command.
`
// strip Stripping redundant data from redis

View File

@@ -29,14 +29,13 @@ func Run(args ...string) error {
return RunColumnWrite(args...)
case "convert":
return RunConvert(args...)
case "tasklist":
return RunTaskList(args...)
case "taskcancel":
return RunTaskCancel(args...)
case "taskdetail":
return RunTaskDetail(args...)
case "makehash":
return RunMakeHash(args...)
case "verify":
return RunVerify(args...)
case "task":
return RunTask(args...)
default:
err := fmt.Errorf(`unknown command "%s"`+"\n"+`Run 'saastool help' for usage`, name)
slog.Warn(err.Error())

215
cmd/saastool/make_hash.go Normal file
View File

@@ -0,0 +1,215 @@
package main
import (
"crypto/sha256"
"encoding/hex"
"flag"
"fmt"
"os"
"path"
"runtime"
"sort"
"sync"
"e.coding.net/rta/public/saasapi"
"google.golang.org/protobuf/encoding/protojson"
)
const (
blockSizeMin = 10 * 1024 * 1024
blockSizeMax = 200 * 1024 * 1024
)
type makeHashParams struct {
sourcePath string
destPath string
task *saasapi.Task
}
// 计算任务
type hashTask struct {
chunk []byte
index int
}
// 计算结果
type hashResult struct {
hash string
blockSize uint64
index int
}
func RunMakeHash(args ...string) error {
fs := flag.NewFlagSet("tasklocalmake", flag.ExitOnError)
sourcePath := paramSourcePath(fs)
destPath := paramDestPath(fs)
blockSize := paramBlockSize(fs)
if err := fs.Parse(args); err != nil {
fmt.Println("command line parse error", "err", err)
return err
}
if fs.NArg() > 0 || len(*sourcePath) == 0 || len(*destPath) == 0 {
fs.PrintDefaults()
return nil
}
if blockSize < blockSizeMin || blockSize > blockSizeMax {
fmt.Println("block size error", "min", blockSizeMin, "max", blockSizeMax)
return nil
}
makeHashParams := makeHashParams{
sourcePath: *sourcePath,
destPath: *destPath,
task: &saasapi.Task{
TaskBlockSize: blockSize,
},
}
return doMakeHash(makeHashParams)
}
func doMakeHash(makeHashParams makeHashParams) error {
fsInfo, err := os.Stat(makeHashParams.sourcePath)
if err != nil {
return err
}
if !fsInfo.IsDir() {
// 如果是文件,直接计算
return doFileHash(makeHashParams)
}
// 读取目录下信息
dirEntry, err := os.ReadDir(makeHashParams.sourcePath)
if err != nil {
return err
}
// 遍历目录
for _, dir := range dirEntry {
newParam := makeHashParams
newParam.sourcePath = path.Join(makeHashParams.sourcePath, dir.Name())
if dir.IsDir() {
newParam.destPath = path.Join(makeHashParams.destPath, dir.Name())
}
if err = doMakeHash(newParam); err != nil {
return err
}
}
return saveTaskFile(makeHashParams)
}
func doFileHash(makeHashParams makeHashParams) error {
sourceFile, err := os.Open(makeHashParams.sourcePath)
if err != nil {
return err
}
defer sourceFile.Close()
fi, err := sourceFile.Stat()
if err != nil {
return err
}
tasks := make(chan hashTask)
results := make(chan hashResult)
// 启动工作协程
hashMaxWorker := runtime.GOMAXPROCS(0)
for range hashMaxWorker {
go hashWorker(tasks, results)
}
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
index := 0
buffer := make([]byte, makeHashParams.task.TaskBlockSize)
for {
n, err := sourceFile.Read(buffer)
if n > 0 {
wg.Add(1)
fmt.Printf("\rhashing file [%v], block [%v]", makeHashParams.sourcePath, index)
tasks <- hashTask{chunk: buffer[:n], index: index}
index++
}
if err != nil {
break
}
}
close(tasks)
wg.Done()
}()
var allResults []hashResult
go func() {
for r := range results {
allResults = append(allResults, r)
wg.Done()
}
}()
wg.Wait()
close(results)
// 按索引排序结果
sort.Slice(allResults, func(i, j int) bool {
return allResults[i].index < allResults[j].index
})
// 输出结果
fileInfo := &saasapi.FileInfo{
FileName: makeHashParams.sourcePath,
FileSize: uint64(fi.Size()),
}
for _, r := range allResults {
fileInfo.FileBlocks = append(fileInfo.FileBlocks, &saasapi.FileBlock{
BlockSha256: r.hash,
BlockLength: r.blockSize,
})
}
makeHashParams.task.TaskFileInfos = append(makeHashParams.task.TaskFileInfos, fileInfo)
fmt.Println("")
return nil
}
// hash计算协程
func hashWorker(tasks <-chan hashTask, results chan<- hashResult) {
for t := range tasks {
h := sha256.New()
h.Write(t.chunk)
hash := hex.EncodeToString(h.Sum(nil))
results <- hashResult{hash: hash, index: t.index, blockSize: uint64(len(t.chunk))}
}
}
func saveTaskFile(makeHashParams makeHashParams) error {
taskFile, err := os.Create(makeHashParams.destPath)
if err != nil {
return err
}
defer taskFile.Close()
h := sha256.New()
for _, fileInfo := range makeHashParams.task.TaskFileInfos {
for _, fileBlock := range fileInfo.FileBlocks {
h.Write([]byte(fileBlock.BlockSha256))
}
}
makeHashParams.task.TaskSha256 = hex.EncodeToString(h.Sum(nil))
_, err = taskFile.WriteString(protojson.Format(makeHashParams.task))
if err != nil {
return err
}
return nil
}

View File

@@ -2,31 +2,50 @@ package main
import (
"flag"
"fmt"
"strconv"
"strings"
)
func paramConfig(fs *flag.FlagSet) *string {
return fs.String("config", "cfg.toml", "Config file.")
}
func paramTargets(fs *flag.FlagSet) *string {
return fs.String("targets", "", "target setting")
func paramMap(fs *flag.FlagSet) *string {
return fs.String("map", "", "target map setting")
}
func paramSourcePath(fs *flag.FlagSet) *string {
return fs.String("source", "", "Data path source for write command.")
return fs.String("source", "", "Source path or filename")
}
func paramDestPath(fs *flag.FlagSet) *string {
return fs.String("dest", "", "Data path destination for write command.")
return fs.String("dest", "", "Destination path or filename")
}
func paramAppid(fs *flag.FlagSet) *string {
return fs.String("appid", "", "Wechat appid")
}
func paramUserids(fs *flag.FlagSet) *string {
return fs.String("userids", "", "Device ID or Wechat UserID, separated by comma")
}
func paramBatchSize(fs *flag.FlagSet) *uint {
return fs.Uint("batchsize", 10000, "Batch size to sync")
}
func paramBlockSize(fs *flag.FlagSet) uint64 {
bsize := fs.String("blocksize", "50M", "Block size to make hash. using size mode K, M, G, T")
num, err := ParseByteSize(*bsize)
if err != nil {
fmt.Println("Error parsing block size", "err", err)
fmt.Println("Using default 50M")
num = 50 * 1024 * 1024
}
return num
}
func paramAsync(fs *flag.FlagSet) *bool {
return fs.Bool("async", false, "Async mode")
}
@@ -34,3 +53,55 @@ func paramAsync(fs *flag.FlagSet) *bool {
func paramClear(fs *flag.FlagSet) *bool {
return fs.Bool("clear", false, "Clear all data before write")
}
// ParseByteSize 解析字节大小字符串为字节数
func ParseByteSize(sizeStr string) (uint64, error) {
sizeStr = strings.TrimSpace(sizeStr)
unit := ""
numStr := sizeStr
// 提取单位
if len(sizeStr) > 1 && (sizeStr[len(sizeStr)-1] == 'B' || sizeStr[len(sizeStr)-1] == 'b') {
unit = string(sizeStr[len(sizeStr)-2:])
numStr = sizeStr[:len(sizeStr)-2]
} else if len(sizeStr) > 0 && (sizeStr[len(sizeStr)-1] >= 'A' && sizeStr[len(sizeStr)-1] <= 'Z' ||
sizeStr[len(sizeStr)-1] >= 'a' && sizeStr[len(sizeStr)-1] <= 'z') {
unit = string(sizeStr[len(sizeStr)-1])
numStr = sizeStr[:len(sizeStr)-1]
}
// 解析数字部分
num, err := strconv.ParseFloat(numStr, 64)
if err != nil {
return 0, err
}
// 根据单位计算字节数
switch strings.ToUpper(unit) {
case "":
return uint64(num), nil
case "K", "KB":
return uint64(num * 1024), nil
case "M", "MB":
return uint64(num * 1024 * 1024), nil
case "G", "GB":
return uint64(num * 1024 * 1024 * 1024), nil
case "T", "TB":
return uint64(num * 1024 * 1024 * 1024 * 1024), nil
default:
return 0, fmt.Errorf("unknown unit: %s", unit)
}
}
/*
func main() {
sizes := []string{"1K", "2M", "3G", "4T", "5"}
for _, sizeStr := range sizes {
size, err := ParseByteSize(sizeStr)
if err != nil {
fmt.Printf("Error parsing %s: %v\n", sizeStr, err)
} else {
fmt.Printf("%s = %d bytes\n", sizeStr, size)
}
}
}
*/

View File

@@ -1,5 +1,95 @@
package main
import (
"flag"
"fmt"
"log/slog"
"net/http"
"strings"
"e.coding.net/rta/public/saasapi"
"e.coding.net/rta/public/saasapi/pkg/saashttp"
"google.golang.org/protobuf/encoding/protojson"
)
const (
getIdsMax = 100
)
type readParams struct {
cfg *Config
appid string
userids []string
saasHttp *saashttp.SaasClient
}
func RunRead(args ...string) error {
fs := flag.NewFlagSet("read", flag.ExitOnError)
cfgFile := paramConfig(fs)
appid := paramAppid(fs)
userids := paramUserids(fs)
if err := fs.Parse(args); err != nil {
fmt.Println("command line parse error", "err", err)
return err
}
// 切割字符串
idsSlice := strings.Split(*userids, ",")
if fs.NArg() > 0 || len(idsSlice) == 0 || (len(idsSlice) == 1 && idsSlice[0] == "") || len(idsSlice) > getIdsMax {
fs.PrintDefaults()
return nil
}
cfg, err := LoadConfigFile(*cfgFile)
if err != nil {
slog.Error("LoadConfigFile error", "err", err)
return err
}
readParams := readParams{
cfg: cfg,
userids: idsSlice,
appid: *appid,
saasHttp: &saashttp.SaasClient{
Client: http.Client{},
ApiUrls: cfg.ApiUrls,
Auth: cfg.Auth,
},
}
return doRead(readParams)
}
func doRead(readParams readParams) error {
saasReq := &saasapi.SaasReq{
Cmd: &saasapi.SaasReq_Read{
Read: &saasapi.Read{},
},
}
if readParams.appid != "" {
saasReq.UseridType = saasapi.UserIdType_OPENID
saasReq.Appid = readParams.appid
}
saasReadItems := []*saasapi.ReadItem{}
for _, userid := range readParams.userids {
saasReadItems = append(saasReadItems, &saasapi.ReadItem{
Userid: userid,
})
}
saasReq.Cmd.(*saasapi.SaasReq_Read).Read.ReadItems = saasReadItems
res, err := readParams.saasHttp.Read(saasReq)
if err != nil {
slog.Error("submitRead error", "err", err)
return err
}
fmt.Println(protojson.Format(res))
return nil
}

View File

@@ -5,8 +5,8 @@ import (
"os"
)
// TargetConfig 配置
type TargetConfig struct {
// MapConfig 配置
type MapConfig struct {
Targets map[string]*Target `json:"targets"`
}
@@ -23,7 +23,7 @@ type Target struct {
}
// LoadConfigFile 加载配置文件
func LoadTargetFile(filename string) (*TargetConfig, error) {
func LoadMapFile(filename string) (*MapConfig, error) {
// 打开文件
f, err := os.Open(filename)
if err != nil {
@@ -31,7 +31,7 @@ func LoadTargetFile(filename string) (*TargetConfig, error) {
}
defer f.Close()
sc := &TargetConfig{}
sc := &MapConfig{}
err = json.NewDecoder(f).Decode(sc)
return sc, err

49
cmd/saastool/task.go Normal file
View File

@@ -0,0 +1,49 @@
package main
import (
"fmt"
"log/slog"
"strings"
)
func RunTask(args ...string) error {
name, args := ParseCommandName(args)
// 从参数中解析出命令
switch name {
case "", "help":
return RunTaskHelp(args...)
case "create":
return RunTaskCreate(args...)
case "list":
return RunTaskList(args...)
case "delete":
return RunTaskDelete(args...)
case "info":
return RunTaskInfo(args...)
default:
err := fmt.Errorf(`unknown command "%s"`+"\n"+`Run 'saastool task help' for usage`, name)
slog.Warn(err.Error())
return err
}
}
func RunTaskHelp(args ...string) error {
fmt.Println(strings.TrimSpace(taskUsage))
return nil
}
const taskUsage = `
Usage: saastoola task COMMAND [OPTIONS]
Commands:
create
upload Read user's 'bytes / uint32s / flags'
run
delete
info
"help" is the default command.
Use "saastool task COMMAND -help" for more information about a command.
`

View File

@@ -1,5 +0,0 @@
package main
func RunTaskCancel(args ...string) error {
return nil
}

View File

@@ -0,0 +1,5 @@
package main
func RunTaskCreate(args ...string) error {
return nil
}

View File

@@ -0,0 +1,5 @@
package main
func RunTaskDelete(args ...string) error {
return nil
}

View File

@@ -1,5 +0,0 @@
package main
func RunTaskDetail(args ...string) error {
return nil
}

View File

@@ -0,0 +1,5 @@
package main
func RunTaskInfo(args ...string) error {
return nil
}

5
cmd/saastool/task_run.go Normal file
View File

@@ -0,0 +1,5 @@
package main
func RunTaskRun(args ...string) error {
return nil
}

View File

@@ -19,7 +19,6 @@ type writeParams struct {
sourcePath string
appid string
batchSize uint
async bool
clear bool
saasHttp *saashttp.SaasClient
}
@@ -30,7 +29,6 @@ func RunWrite(args ...string) error {
sourcePath := paramSourcePath(fs)
appid := paramAppid(fs)
batchSize := paramBatchSize(fs)
async := paramAsync(fs)
clear := paramClear(fs)
if err := fs.Parse(args); err != nil {
@@ -53,7 +51,6 @@ func RunWrite(args ...string) error {
sourcePath: *sourcePath,
appid: *appid,
batchSize: *batchSize,
async: *async,
clear: *clear,
saasHttp: &saashttp.SaasClient{
Client: http.Client{},
@@ -105,7 +102,7 @@ func doLoadFileToWrite(writeParams writeParams) error {
scaner := bufio.NewScanner(file)
saasWriteCmds := []*saasapi.WriteCmd{}
saasWriteItems := []*saasapi.WriteItem{}
succ := uint32(0)
succTotal := uint32(0)
@@ -115,29 +112,29 @@ func doLoadFileToWrite(writeParams writeParams) error {
if line == "" {
continue
}
saasWriteCmd := &saasapi.WriteCmd{}
if err = protojson.Unmarshal([]byte(line), saasWriteCmd); err != nil {
saasWriteItem := &saasapi.WriteItem{}
if err = protojson.Unmarshal([]byte(line), saasWriteItem); err != nil {
return err
}
saasWriteCmds = append(saasWriteCmds, saasWriteCmd)
saasWriteItems = append(saasWriteItems, saasWriteItem)
total++
if len(saasWriteCmds) == int(writeParams.batchSize) {
if succ, _, err = submitWrite(writeParams, saasWriteCmds); err != nil {
if len(saasWriteItems) == int(writeParams.batchSize) {
if succ, _, err = submitWrite(writeParams, saasWriteItems); err != nil {
return err
}
succTotal += succ
fmt.Printf("[%v] batch_succ = %v, succ_total = %v, total_processed = %v\n", writeParams.sourcePath, succ, succTotal, total)
saasWriteCmds = saasWriteCmds[:0]
saasWriteItems = saasWriteItems[:0]
}
}
if len(saasWriteCmds) > 0 {
if succ, _, err = submitWrite(writeParams, saasWriteCmds); err != nil {
if len(saasWriteItems) > 0 {
if succ, _, err = submitWrite(writeParams, saasWriteItems); err != nil {
return err
}
succTotal += succ
@@ -147,12 +144,11 @@ func doLoadFileToWrite(writeParams writeParams) error {
return nil
}
func submitWrite(writeParams writeParams, saasWriteCmds []*saasapi.WriteCmd) (succ, total uint32, err error) {
func submitWrite(writeParams writeParams, saasWriteCmds []*saasapi.WriteItem) (succ, total uint32, err error) {
saasReq := &saasapi.SaasReq{
Cmd: &saasapi.SaasReq_Write{
Write: &saasapi.Write{
IsClearAllFirst: writeParams.clear,
Async: writeParams.async,
},
},
}
@@ -162,10 +158,15 @@ func submitWrite(writeParams writeParams, saasWriteCmds []*saasapi.WriteCmd) (su
saasReq.Appid = writeParams.appid
}
saasReq.Cmd.(*saasapi.SaasReq_Write).Write.WriteCmds = saasWriteCmds
saasReq.Cmd.(*saasapi.SaasReq_Write).Write.WriteItems = saasWriteCmds
total = uint32(len(saasWriteCmds))
succ, err = writeParams.saasHttp.Write(saasReq)
res, err := writeParams.saasHttp.Write(saasReq)
return
if err != nil {
slog.Error("submitWrite error", "err", err)
return
}
return res.GetWriteRes().GetSuccCmdCount(), total, nil
}