// Copyright 2019 Sorint.lab // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied // See the License for the specific language governing permissions and // limitations under the License. package unarchive import ( "archive/tar" "errors" "fmt" "io" "log" "os" "path/filepath" "runtime" ) const ( defaultDirPerm = 0755 ) func Unarchive(source io.Reader, destDir string, overwrite, removeDestDir bool) error { var err error destDir, err = filepath.Abs(destDir) if err != nil { return fmt.Errorf("failed to calculate destination dir absolute path: %w", err) } // don't follow destdir if it's a symlink fi, err := os.Lstat(destDir) if err != nil && !os.IsNotExist(err) { return fmt.Errorf("failed to lstat destination dir: %w", err) } if fi != nil && !fi.IsDir() { return fmt.Errorf("destination path %q already exists and it's not a directory (mode: %q)", destDir, fi.Mode().String()) } if fi != nil && fi.IsDir() && removeDestDir { if err := os.RemoveAll(destDir); err != nil { return fmt.Errorf("destination path %q already exists and it's not a directory (mode: %q)", destDir, fi.Mode().String()) } } tr := tar.NewReader(source) for { err := untarNext(tr, destDir, overwrite) if errors.Is(err, io.EOF) { break } if err != nil { return fmt.Errorf("error reading file in tar archive: %w", err) } } return nil } func untarNext(tr *tar.Reader, destDir string, overwrite bool) error { hdr, err := tr.Next() if err != nil { return err // don't wrap error; calling loop must break on io.EOF } destPath := filepath.Join(destDir, hdr.Name) log.Printf("file: %q", destPath) // do not overwrite existing files, if configured if !overwrite && fileExists(destPath) { return fmt.Errorf("file already exists: %s", destPath) } // if "to" is a file and now exits and it's not a file then remove it if err := os.RemoveAll(destPath); err != nil { return err } switch hdr.Typeflag { case tar.TypeDir: fi, err := os.Lstat(destPath) if err != nil && !os.IsNotExist(err) { return err } if fi != nil && !fi.IsDir() { if err := os.RemoveAll(destPath); err != nil { return err } } return mkdir(destPath, hdr.FileInfo().Mode()) case tar.TypeReg, tar.TypeRegA, tar.TypeChar, tar.TypeBlock, tar.TypeFifo: fi, err := os.Lstat(destPath) if err != nil && !os.IsNotExist(err) { return err } if fi != nil && !fi.Mode().IsRegular() { if err := os.RemoveAll(destPath); err != nil { return err } } return writeNewFile(destPath, tr, hdr.FileInfo().Mode()) case tar.TypeSymlink: if fileExists(destPath) { if err := os.RemoveAll(destPath); err != nil { return err } } return writeNewSymbolicLink(destPath, hdr.Linkname) case tar.TypeLink: if fileExists(destPath) { if err := os.RemoveAll(destPath); err != nil { return err } } return writeNewHardLink(destPath, filepath.Join(destPath, hdr.Linkname)) case tar.TypeXGlobalHeader: return nil // ignore the pax global header from git-generated tarballs default: return fmt.Errorf("%s: unknown type flag: %c", hdr.Name, hdr.Typeflag) } } func fileExists(name string) bool { _, err := os.Lstat(name) return !os.IsNotExist(err) } func mkdir(dirPath string, mode os.FileMode) error { err := os.MkdirAll(dirPath, mode) if err != nil { return fmt.Errorf("%s: making directory: %w", dirPath, err) } return nil } func writeNewFile(fpath string, in io.Reader, mode os.FileMode) error { err := os.MkdirAll(filepath.Dir(fpath), defaultDirPerm) if err != nil { return fmt.Errorf("%s: making directory for file: %w", fpath, err) } out, err := os.Create(fpath) if err != nil { return fmt.Errorf("%s: creating new file: %w", fpath, err) } defer out.Close() err = out.Chmod(mode) if err != nil && runtime.GOOS != "windows" { return fmt.Errorf("%s: changing file mode: %w", fpath, err) } _, err = io.Copy(out, in) if err != nil { return fmt.Errorf("%s: writing file: %w", fpath, err) } return nil } func writeNewSymbolicLink(fpath string, target string) error { err := os.MkdirAll(filepath.Dir(fpath), defaultDirPerm) if err != nil { return fmt.Errorf("%s: making directory for file: %w", fpath, err) } err = os.Symlink(target, fpath) if err != nil { return fmt.Errorf("%s: making symbolic link for: %w", fpath, err) } return nil } func writeNewHardLink(fpath string, target string) error { err := os.MkdirAll(filepath.Dir(fpath), defaultDirPerm) if err != nil { return fmt.Errorf("%s: making directory for file: %w", fpath, err) } err = os.Link(target, fpath) if err != nil { return fmt.Errorf("%s: making hard link for: %w", fpath, err) } return nil }