Files
saasapi/cmd/saastool/write.go
2026-01-11 17:24:11 +08:00

179 lines
4.2 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package main
import (
"bufio"
"flag"
"fmt"
"net/http"
"os"
"path"
"strings"
"git.algo.com.cn/public/saasapi"
"git.algo.com.cn/public/saasapi/pkg/saashttp"
"google.golang.org/protobuf/encoding/protojson"
)
// writeParams bundles everything the write command needs while walking and
// uploading the source data. It is passed by value, so doWrite can copy it
// and override sourcePath for each directory entry.
type writeParams struct {
// cfg is the configuration loaded by LoadConfigFile.
cfg *Config
// sourcePath is the file or directory whose contents are uploaded.
sourcePath string
// appid is sent as Write.Appid; RunWrite requires it when ds is "wuid" (case-insensitive).
appid string
// ds is the data space id, sent as Write.DataspaceId.
ds string
// batchSize is the number of write items per request; 0 never triggers an
// intermediate flush, so each file is sent as a single request.
batchSize uint
// clear is sent as Write.IsClearAllFirst on every submitted batch.
clear bool
// saasHttp is the client used to submit Write commands.
saasHttp *saashttp.SaasClient
}
// RunWrite implements the "write" subcommand: it parses args, loads the
// configuration file, and uploads the file or directory named by the source
// path flag through the SaaS write API.
//
// When there are extra positional arguments, or the source path or data
// space id is missing, the flag usage is printed and nil is returned. When
// the data space is "wuid" (case-insensitive) and no appid was given, an
// error message goes to stderr and nil is also returned.
//
// NOTE(review): the http.Client is created without a Timeout, so a stalled
// request blocks indefinitely.
func RunWrite(args ...string) error {
	flags := flag.NewFlagSet("write", flag.ExitOnError)
	cfgFile := paramConfig(flags)
	sourcePath := paramSourcePath(flags)
	appid := paramAppid(flags)
	ds := paramDataSpaceId(flags)
	batchSize := paramBatchSize(flags)
	clearFirst := paramClear(flags)

	// With flag.ExitOnError, Parse terminates the process on bad input
	// itself, so this branch is defensive.
	if err := flags.Parse(args); err != nil {
		return fmt.Errorf("Command line parse error: %w", err)
	}

	// Unexpected positional arguments or a missing required flag: show usage.
	if flags.NArg() > 0 || *sourcePath == "" || *ds == "" {
		flags.PrintDefaults()
		return nil
	}

	// The wuid data space cannot be written without an appid.
	if *appid == "" && strings.ToLower(*ds) == "wuid" {
		fmt.Fprintln(os.Stderr, "appid must be set when data space is wuid")
		return nil
	}

	cfg, err := LoadConfigFile(*cfgFile)
	if err != nil {
		return fmt.Errorf("LoadConfigFile error: %w", err)
	}

	client := &saashttp.SaasClient{
		Client:  &http.Client{},
		ApiUrls: saashttp.InitAPIUrl(&cfg.ApiUrls),
		Auth:    &cfg.Auth,
	}

	return doWrite(writeParams{
		cfg:        cfg,
		sourcePath: *sourcePath,
		appid:      *appid,
		ds:         *ds,
		batchSize:  *batchSize,
		clear:      *clearFirst,
		saasHttp:   client,
	})
}
// doWrite uploads writeParams.sourcePath. A path that is not a directory is
// handed to doLoadFileToWrite. A directory is walked recursively, visiting
// entries in the order os.ReadDir returns them (sorted by filename), and the
// walk stops at the first error.
//
// os.Stat follows symbolic links, so a link to a directory is descended into.
// NOTE(review): nothing guards against symlink cycles, every entry (hidden
// files included) is treated as input, and path.Join always joins with '/'
// (filepath.Join would be the OS-aware choice). Confirm these are intended.
func doWrite(writeParams writeParams) error {
	info, err := os.Stat(writeParams.sourcePath)
	if err != nil {
		return fmt.Errorf("file stat error: %w", err)
	}

	// Anything that is not a directory is loaded directly.
	if !info.IsDir() {
		return doLoadFileToWrite(writeParams)
	}

	dir := writeParams.sourcePath
	entries, err := os.ReadDir(dir)
	if err != nil {
		return fmt.Errorf("read dir error: %w", err)
	}

	for _, entry := range entries {
		// Each child works on its own copy of the parameters; only the path differs.
		child := writeParams
		child.sourcePath = path.Join(dir, entry.Name())
		if err := doWrite(child); err != nil {
			return err
		}
	}
	return nil
}
// doLoadFileToWrite reads writeParams.sourcePath as newline-delimited
// protobuf JSON. Every non-blank line must decode (via protojson) into one
// saasapi.WriteItem. Items are sent with submitWrite in batches of
// writeParams.batchSize, and any remainder is sent after the last line.
// After each submission a progress line is printed to stdout with:
//   - the batch's failed-user count,
//   - the running total of failed users,
//   - the number of lines read so far.
//
// A batchSize of 0 never reaches the flush threshold, so the whole file is
// sent as a single request.
//
// It returns the first error encountered while opening the file, decoding a
// line, reading the file (including a line longer than the scanner limit), or
// submitting a batch. Items decoded but not yet submitted when an error
// occurs are not sent.
func doLoadFileToWrite(writeParams writeParams) error {
	// Upper bound for a single input line. bufio.Scanner defaults to 64 KiB,
	// which one WriteItem carrying many values can exceed.
	const maxLineBytes = 16 * 1024 * 1024

	file, err := os.Open(writeParams.sourcePath)
	if err != nil {
		return fmt.Errorf("open file error: %w. file: %v", err, writeParams.sourcePath)
	}
	defer file.Close()

	scanner := bufio.NewScanner(file)
	scanner.Buffer(make([]byte, 0, bufio.MaxScanTokenSize), maxLineBytes)

	var batch []*saasapi.WriteItem
	errTotal := 0
	total := 0 // lines read so far, blank ones included; also the 1-based line number

	// flush submits the pending batch, reports progress and empties the batch.
	flush := func() error {
		errCount, err := submitWrite(writeParams, batch)
		if err != nil {
			return err
		}
		errTotal += errCount
		fmt.Printf("[%v] err_batch = %v, err_total = %v, total_processed = %v\n", writeParams.sourcePath, errCount, errTotal, total)
		// The request has completed, so the backing array can be reused.
		batch = batch[:0]
		return nil
	}

	for scanner.Scan() {
		total++
		line := scanner.Text()
		if strings.TrimSpace(line) == "" {
			continue
		}
		item := &saasapi.WriteItem{}
		if err := protojson.Unmarshal([]byte(line), item); err != nil {
			return fmt.Errorf("protojson unmarshal error: %w. file: %v line: %v", err, writeParams.sourcePath, total)
		}
		batch = append(batch, item)
		if len(batch) == int(writeParams.batchSize) {
			if err := flush(); err != nil {
				return err
			}
		}
	}
	// Scan also returns false on read failures (e.g. bufio.ErrTooLong).
	// Without this check the rest of the file would be skipped silently.
	if err := scanner.Err(); err != nil {
		return fmt.Errorf("read file error: %w. file: %v after line: %v", err, writeParams.sourcePath, total)
	}

	if len(batch) > 0 {
		return flush()
	}
	return nil
}
// submitWrite sends one Write command carrying saasWriteCmds. The command is
// tagged with the data space id, appid and clear flag taken from writeParams.
//
// On success it returns how many user IDs the response lists in
// WriteRes.FailedUserid. An error from saasHttp.Write, or a response code
// other than ErrorCode_SUCC, is returned as an error with a zero count.
//
// NOTE(review): IsClearAllFirst is set on every batch, not only the first.
// If the server clears the whole data space for each request carrying this
// flag, earlier batches would be lost. Confirm the server-side semantics.
func submitWrite(writeParams writeParams, saasWriteCmds []*saasapi.WriteItem) (errcount int, err error) {
	req := &saasapi.SaasReq{
		Cmd: &saasapi.SaasReq_Write{
			Write: &saasapi.Write{
				DataspaceId:     writeParams.ds,
				Appid:           writeParams.appid,
				IsClearAllFirst: writeParams.clear,
				WriteItems:      saasWriteCmds,
			},
		},
	}

	res, callErr := writeParams.saasHttp.Write(req)
	if callErr != nil {
		return 0, fmt.Errorf("Submit Command error: %w", callErr)
	}

	if code := res.GetCode(); code != saasapi.ErrorCode_SUCC {
		return 0, fmt.Errorf("write failed. code:%v, status:%v", code, res.GetStatus())
	}

	failed := res.GetWriteRes().GetFailedUserid()
	return len(failed), nil
}