Files
saasapi/cmd/saastool/task_make.go
2025-04-22 14:30:30 +08:00

244 lines
4.9 KiB
Go

package main
import (
	"crypto/sha256"
	"encoding/hex"
	"flag"
	"fmt"
	"io"
	"os"
	"path"
	"runtime"
	"sort"
	"sync"

	"e.coding.net/rta/public/saasapi"
	"google.golang.org/protobuf/encoding/protojson"
)
// Inclusive bounds on the task block size accepted by RunTaskMake.
const (
	blockSizeMin = 50 << 20  // 50 MiB
	blockSizeMax = 200 << 20 // 200 MiB
)
// makeTaskParams carries the inputs of one "make" run.
// Copies of this struct share the same task (it is held by pointer), so every
// file hashed through any copy is appended to the same saasapi.Task.
type makeTaskParams struct {
	sourcePath string        // file or directory to hash
	hashFile   string        // path the task file is written to
	task       *saasapi.Task // task being built up and finally serialized
}
// hashTask is one block of a file travelling through the hashing pipeline.
// The reader fills chunk, blockSize and index; a hashWorker fills hash and
// sends the same value back as the result, so no separate result type is
// needed.
type hashTask struct {
	chunk     *[]byte // read buffer; only the first blockSize bytes belong to this block
	hash      string  // hex-encoded SHA-256 of the block, set by hashWorker
	blockSize uint64  // number of valid bytes in *chunk
	index     int     // zero-based block position in the file, used to restore order
}
// RunTaskMake implements the "make" subcommand. It parses the flags in args,
// validates the block size and then hashes the source file or directory into
// a task file via doMakeHash.
//
// Extra positional arguments or a missing source/output path print the flag
// defaults and return nil. An unparsable block size falls back to 200M; a
// block size outside [blockSizeMin, blockSizeMax] is reported on stderr and
// also returns nil.
func RunTaskMake(args ...string) error {
	flags := flag.NewFlagSet("make", flag.ExitOnError)
	src := paramSourceConvertedPath(flags)
	out := paramOutputHashFile(flags)
	sizeArg := paramBlockSize(flags)
	description := paramTaskDesc(flags)

	if err := flags.Parse(args); err != nil {
		fmt.Fprintln(os.Stderr, "command line parse error", "err", err)
		return err
	}

	missingPath := *src == "" || *out == ""
	if flags.NArg() > 0 || missingPath {
		flags.PrintDefaults()
		return nil
	}

	size, err := ParseByteSize(*sizeArg)
	if err != nil {
		fmt.Fprintln(os.Stderr, "Error parsing block size", "err", err)
		fmt.Fprintln(os.Stderr, "Using default 200M")
		size = 200 * 1024 * 1024
	}

	if inRange := size >= blockSizeMin && size <= blockSizeMax; !inRange {
		fmt.Fprintln(os.Stderr, "block size error", "min", blockSizeMin, "max", blockSizeMax)
		return nil
	}

	params := makeTaskParams{
		sourcePath: *src,
		hashFile:   *out,
		task: &saasapi.Task{
			TaskBlockSize:   size,
			TaskDescription: *description,
		},
	}
	return doMakeHash(params)
}
func doMakeHash(makeTaskParams makeTaskParams) error {
fsInfo, err := os.Stat(makeTaskParams.sourcePath)
if err != nil {
return err
}
if !fsInfo.IsDir() {
// 如果是文件,直接计算
return doTaskMake(makeTaskParams)
}
// 读取目录下信息
dirEntry, err := os.ReadDir(makeTaskParams.sourcePath)
if err != nil {
return err
}
// 遍历目录
for _, dir := range dirEntry {
newParam := makeTaskParams
newParam.sourcePath = path.Join(makeTaskParams.sourcePath, dir.Name())
if err = doMakeHash(newParam); err != nil {
return err
}
}
return saveTaskFile(makeTaskParams)
}
// doTaskMake hashes one file in fixed-size blocks and appends the resulting
// FileInfo (name, size, per-block SHA-256 and length) to
// makeTaskParams.task.TaskFileInfos.
//
// The file is split into ceil(size/TaskBlockSize) blocks. Every block except
// possibly the last is exactly TaskBlockSize bytes. One goroutine reads the
// blocks sequentially and up to runtime.GOMAXPROCS(0) workers hash them in
// parallel.
//
// Each in-flight block owns its buffer. A buffer goes back to the pool only
// after its hash has been collected, so the reader never overwrites a block
// that is still being hashed.
//
// A read error, including the file shrinking after it was stat'ed, is returned
// and nothing is appended to the task.
func doTaskMake(makeTaskParams makeTaskParams) error {
	sourceFile, err := os.Open(makeTaskParams.sourcePath)
	if err != nil {
		return err
	}
	defer sourceFile.Close()

	fi, err := sourceFile.Stat()
	if err != nil {
		return err
	}
	totalSize := uint64(fi.Size())
	blockSize := makeTaskParams.task.TaskBlockSize
	if blockSize == 0 {
		// Guards the divisions below; RunTaskMake itself never passes 0.
		return fmt.Errorf("hashing %s: task block size is zero", makeTaskParams.sourcePath)
	}

	// Number of blocks, rounded up so a trailing partial block is included.
	readTimes := int(totalSize / blockSize)
	if totalSize%blockSize != 0 {
		readTimes++
	}

	// Never start more workers, or allocate more block buffers, than there are blocks.
	workers := min(runtime.GOMAXPROCS(0), readTimes)

	// Pool of reusable read buffers. A buffer leaves the pool when the reader
	// fills it and returns only after its hash result has been received below.
	freeBuffers := make(chan []byte, workers)
	for range workers {
		freeBuffers <- make([]byte, blockSize)
	}

	tasks := make(chan *hashTask)
	results := make(chan *hashTask)

	var workerWG sync.WaitGroup
	for range workers {
		workerWG.Add(1)
		go func() {
			defer workerWG.Done()
			hashWorker(tasks, results)
		}()
	}
	// Close results once every worker has exited; this ends the collection loop.
	go func() {
		workerWG.Wait()
		close(results)
	}()

	// Reader: fills a free buffer with exactly one block and hands it to the
	// workers. It sends exactly one value on readErr, before closing tasks.
	readErr := make(chan error, 1)
	go func() {
		defer close(tasks)
		for index := range readTimes {
			want := min(blockSize, totalSize-uint64(index)*blockSize)
			buffer := <-freeBuffers
			// ReadFull avoids short reads shifting block boundaries.
			if _, err := io.ReadFull(sourceFile, buffer[:want]); err != nil {
				readErr <- fmt.Errorf("reading %s block %d: %w", makeTaskParams.sourcePath, index, err)
				return
			}
			tasks <- &hashTask{
				chunk:     &buffer,
				index:     index,
				blockSize: want,
			}
		}
		readErr <- nil
	}()

	// Collect results and recycle each buffer once its worker is done with it.
	allResults := make([]*hashTask, 0, readTimes)
	for r := range results {
		allResults = append(allResults, r)
		freeBuffers <- *r.chunk
		fmt.Printf("\rhashed file [%v], block [%v]", makeTaskParams.sourcePath, r.index)
	}
	fmt.Println("")
	// results is closed only after tasks was closed, so the reader has already sent.
	if err := <-readErr; err != nil {
		return err
	}

	// Workers finish out of order; restore file order.
	sort.Slice(allResults, func(i, j int) bool {
		return allResults[i].index < allResults[j].index
	})

	fileInfo := &saasapi.FileInfo{
		FileName:   makeTaskParams.sourcePath,
		FileSize:   totalSize,
		FileBlocks: make([]*saasapi.FileBlock, 0, len(allResults)),
	}
	for _, r := range allResults {
		fileInfo.FileBlocks = append(fileInfo.FileBlocks, &saasapi.FileBlock{
			BlockSha256: r.hash,
			BlockLength: r.blockSize,
		})
	}
	makeTaskParams.task.TaskFileInfos = append(makeTaskParams.task.TaskFileInfos, fileInfo)
	return nil
}
// hashWorker is a hashing goroutine. It consumes tasks until the channel is
// closed. For each task it stores the hex-encoded SHA-256 of the first
// blockSize bytes of *chunk in the task's hash field, then forwards the same
// task on results.
func hashWorker(tasks <-chan *hashTask, results chan<- *hashTask) {
	for {
		job, ok := <-tasks
		if !ok {
			return
		}
		digest := sha256.Sum256((*job.chunk)[:job.blockSize])
		job.hash = hex.EncodeToString(digest[:])
		results <- job
	}
}
// saveTaskFile sets the task-level digest and writes the task, encoded as
// multi-line protobuf JSON, to makeTaskParams.hashFile. The file is created if
// missing and truncated otherwise.
//
// TaskSha256 is the hex SHA-256 of the concatenated hex block digests, taken
// in file order and then block order.
func saveTaskFile(makeTaskParams makeTaskParams) error {
	h := sha256.New()
	for _, fileInfo := range makeTaskParams.task.TaskFileInfos {
		for _, fileBlock := range fileInfo.FileBlocks {
			h.Write([]byte(fileBlock.BlockSha256))
		}
	}
	makeTaskParams.task.TaskSha256 = hex.EncodeToString(h.Sum(nil))

	// Use Marshal rather than protojson.Format: Format discards encoding errors
	// (e.g. invalid UTF-8 in a file name) and could silently write an empty or
	// partial task file. Encode before touching the output file so a failure
	// leaves any existing file intact.
	data, err := protojson.MarshalOptions{Multiline: true}.Marshal(makeTaskParams.task)
	if err != nil {
		return fmt.Errorf("encoding task file %s: %w", makeTaskParams.hashFile, err)
	}
	// os.WriteFile creates or truncates like os.Create, and also reports the
	// error from Close, which a deferred Close would drop.
	return os.WriteFile(makeTaskParams.hashFile, data, 0o666)
}