package main import ( "crypto/sha256" "encoding/hex" "flag" "fmt" "os" "path" "runtime" "sort" "sync" "e.coding.net/rta/public/saasapi" "google.golang.org/protobuf/encoding/protojson" ) const ( blockSizeMin = 10 * 1024 * 1024 blockSizeMax = 200 * 1024 * 1024 ) type makeHashParams struct { sourcePath string destPath string task *saasapi.Task } // 计算任务 type hashTask struct { chunk []byte index int } // 计算结果 type hashResult struct { hash string blockSize uint64 index int } func RunMakeHash(args ...string) error { fs := flag.NewFlagSet("tasklocalmake", flag.ExitOnError) sourcePath := paramSourcePath(fs) destPath := paramDestPath(fs) blockSize := paramBlockSize(fs) if err := fs.Parse(args); err != nil { fmt.Println("command line parse error", "err", err) return err } if fs.NArg() > 0 || len(*sourcePath) == 0 || len(*destPath) == 0 { fs.PrintDefaults() return nil } if blockSize < blockSizeMin || blockSize > blockSizeMax { fmt.Println("block size error", "min", blockSizeMin, "max", blockSizeMax) return nil } makeHashParams := makeHashParams{ sourcePath: *sourcePath, destPath: *destPath, task: &saasapi.Task{ TaskBlockSize: blockSize, }, } return doMakeHash(makeHashParams) } func doMakeHash(makeHashParams makeHashParams) error { fsInfo, err := os.Stat(makeHashParams.sourcePath) if err != nil { return err } if !fsInfo.IsDir() { // 如果是文件,直接计算 return doFileHash(makeHashParams) } // 读取目录下信息 dirEntry, err := os.ReadDir(makeHashParams.sourcePath) if err != nil { return err } // 遍历目录 for _, dir := range dirEntry { newParam := makeHashParams newParam.sourcePath = path.Join(makeHashParams.sourcePath, dir.Name()) if dir.IsDir() { newParam.destPath = path.Join(makeHashParams.destPath, dir.Name()) } if err = doMakeHash(newParam); err != nil { return err } } return saveTaskFile(makeHashParams) } func doFileHash(makeHashParams makeHashParams) error { sourceFile, err := os.Open(makeHashParams.sourcePath) if err != nil { return err } defer sourceFile.Close() fi, err := sourceFile.Stat() if err != nil { return err } tasks := make(chan hashTask) results := make(chan hashResult) // 启动工作协程 hashMaxWorker := runtime.GOMAXPROCS(0) for range hashMaxWorker { go hashWorker(tasks, results) } wg := sync.WaitGroup{} wg.Add(1) go func() { index := 0 buffer := make([]byte, makeHashParams.task.TaskBlockSize) for { n, err := sourceFile.Read(buffer) if n > 0 { wg.Add(1) fmt.Printf("\rhashing file [%v], block [%v]", makeHashParams.sourcePath, index) tasks <- hashTask{chunk: buffer[:n], index: index} index++ } if err != nil { break } } close(tasks) wg.Done() }() var allResults []hashResult go func() { for r := range results { allResults = append(allResults, r) wg.Done() } }() wg.Wait() close(results) // 按索引排序结果 sort.Slice(allResults, func(i, j int) bool { return allResults[i].index < allResults[j].index }) // 输出结果 fileInfo := &saasapi.FileInfo{ FileName: makeHashParams.sourcePath, FileSize: uint64(fi.Size()), } for _, r := range allResults { fileInfo.FileBlocks = append(fileInfo.FileBlocks, &saasapi.FileBlock{ BlockSha256: r.hash, BlockLength: r.blockSize, }) } makeHashParams.task.TaskFileInfos = append(makeHashParams.task.TaskFileInfos, fileInfo) fmt.Println("") return nil } // hash计算协程 func hashWorker(tasks <-chan hashTask, results chan<- hashResult) { for t := range tasks { h := sha256.New() h.Write(t.chunk) hash := hex.EncodeToString(h.Sum(nil)) results <- hashResult{hash: hash, index: t.index, blockSize: uint64(len(t.chunk))} } } func saveTaskFile(makeHashParams makeHashParams) error { taskFile, err := os.Create(makeHashParams.destPath) if err != nil { return err } defer taskFile.Close() h := sha256.New() for _, fileInfo := range makeHashParams.task.TaskFileInfos { for _, fileBlock := range fileInfo.FileBlocks { h.Write([]byte(fileBlock.BlockSha256)) } } makeHashParams.task.TaskSha256 = hex.EncodeToString(h.Sum(nil)) _, err = taskFile.WriteString(protojson.Format(makeHashParams.task)) if err != nil { return err } return nil }