package main import ( "crypto/sha256" "encoding/hex" "flag" "fmt" "os" "path" "runtime" "sort" "sync" "e.coding.net/rta/public/saasapi" "google.golang.org/protobuf/encoding/protojson" ) const ( blockSizeMin = 10 * 1024 * 1024 blockSizeMax = 200 * 1024 * 1024 ) type makeTaskParams struct { sourcePath string hashFile string task *saasapi.Task } // 计算任务 type hashTask struct { chunk []byte index int } // 计算结果 type hashResult struct { hash string blockSize uint64 index int } func RunTaskMake(args ...string) error { fs := flag.NewFlagSet("make", flag.ExitOnError) sourcePath := paramSourceConvertedPath(fs) hashFile := paramOutputHashFile(fs) blockSize := paramBlockSize(fs) desc := paramTaskDesc(fs) if err := fs.Parse(args); err != nil { fmt.Println("command line parse error", "err", err) return err } if fs.NArg() > 0 || len(*sourcePath) == 0 || len(*hashFile) == 0 { fs.PrintDefaults() return nil } if blockSize < blockSizeMin || blockSize > blockSizeMax { fmt.Println("block size error", "min", blockSizeMin, "max", blockSizeMax) return nil } makeTaskParams := makeTaskParams{ sourcePath: *sourcePath, hashFile: *hashFile, task: &saasapi.Task{ TaskBlockSize: blockSize, TaskDescription: *desc, }, } return doMakeHash(makeTaskParams) } func doMakeHash(makeTaskParams makeTaskParams) error { fsInfo, err := os.Stat(makeTaskParams.sourcePath) if err != nil { return err } if !fsInfo.IsDir() { // 如果是文件,直接计算 return doTaskMake(makeTaskParams) } // 读取目录下信息 dirEntry, err := os.ReadDir(makeTaskParams.sourcePath) if err != nil { return err } // 遍历目录 for _, dir := range dirEntry { newParam := makeTaskParams newParam.sourcePath = path.Join(makeTaskParams.sourcePath, dir.Name()) if err = doMakeHash(newParam); err != nil { return err } } return saveTaskFile(makeTaskParams) } func doTaskMake(makeTaskParams makeTaskParams) error { sourceFile, err := os.Open(makeTaskParams.sourcePath) if err != nil { return err } defer sourceFile.Close() fi, err := sourceFile.Stat() if err != nil { return err } tasks := make(chan hashTask) results := make(chan hashResult) // 启动工作协程 hashMaxWorker := runtime.GOMAXPROCS(0) for range hashMaxWorker { go hashWorker(tasks, results) } wg := sync.WaitGroup{} wg.Add(1) go func() { index := 0 buffer := make([]byte, makeTaskParams.task.TaskBlockSize) for { n, err := sourceFile.Read(buffer) if n > 0 { wg.Add(1) fmt.Printf("\rhashing file [%v], block [%v]", makeTaskParams.sourcePath, index) tasks <- hashTask{chunk: buffer[:n], index: index} index++ } if err != nil { break } } close(tasks) wg.Done() }() var allResults []hashResult go func() { for r := range results { allResults = append(allResults, r) wg.Done() } }() wg.Wait() close(results) // 按索引排序结果 sort.Slice(allResults, func(i, j int) bool { return allResults[i].index < allResults[j].index }) // 输出结果 fileInfo := &saasapi.FileInfo{ FileName: makeTaskParams.sourcePath, FileSize: uint64(fi.Size()), } for _, r := range allResults { fileInfo.FileBlocks = append(fileInfo.FileBlocks, &saasapi.FileBlock{ BlockSha256: r.hash, BlockLength: r.blockSize, }) } makeTaskParams.task.TaskFileInfos = append(makeTaskParams.task.TaskFileInfos, fileInfo) fmt.Println("") return nil } // hash计算协程 func hashWorker(tasks <-chan hashTask, results chan<- hashResult) { for t := range tasks { h := sha256.New() h.Write(t.chunk) hash := hex.EncodeToString(h.Sum(nil)) results <- hashResult{hash: hash, index: t.index, blockSize: uint64(len(t.chunk))} } } func saveTaskFile(makeTaskParams makeTaskParams) error { taskFile, err := os.Create(makeTaskParams.hashFile) if err != nil { return err } defer taskFile.Close() h := sha256.New() for _, fileInfo := range makeTaskParams.task.TaskFileInfos { for _, fileBlock := range fileInfo.FileBlocks { h.Write([]byte(fileBlock.BlockSha256)) } } makeTaskParams.task.TaskSha256 = hex.EncodeToString(h.Sum(nil)) _, err = taskFile.WriteString(protojson.Format(makeTaskParams.task)) if err != nil { return err } return nil }