diff --git a/control_update.go b/control_update.go index 606d8771..799c4968 100644 --- a/control_update.go +++ b/control_update.go @@ -178,42 +178,62 @@ func getUpdateInfo(jsonData []byte) (*updateInfo, error) { } // Unpack all files from .zip file to the specified directory -func zipFileUnpack(zipfile, outdir string) error { +// Existing files are overwritten +// Return the list of files (not directories) written +func zipFileUnpack(zipfile, outdir string) ([]string, error) { + r, err := zip.OpenReader(zipfile) if err != nil { - return fmt.Errorf("zip.OpenReader(): %s", err) + return nil, fmt.Errorf("zip.OpenReader(): %s", err) } defer r.Close() + var files []string + var err2 error + var zr io.ReadCloser for _, zf := range r.File { - zr, err := zf.Open() + zr, err = zf.Open() if err != nil { - return fmt.Errorf("zip file Open(): %s", err) + err2 = fmt.Errorf("zip file Open(): %s", err) + break } + fi := zf.FileInfo() + if len(fi.Name()) == 0 { + continue + } + fn := filepath.Join(outdir, fi.Name()) if fi.IsDir() { err = os.Mkdir(fn, fi.Mode()) - if err != nil { - return fmt.Errorf("zip file Read(): %s", err) + if err != nil && !os.IsExist(err) { + err2 = fmt.Errorf("os.Mkdir(): %s", err) + break } + log.Tracef("created directory %s", fn) continue } f, err := os.OpenFile(fn, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, fi.Mode()) if err != nil { - zr.Close() - return fmt.Errorf("os.OpenFile(): %s", err) + err2 = fmt.Errorf("os.OpenFile(): %s", err) + break } _, err = io.Copy(f, zr) if err != nil { - zr.Close() - return fmt.Errorf("io.Copy(): %s", err) + f.Close() + err2 = fmt.Errorf("io.Copy(): %s", err) + break } - zr.Close() + f.Close() + + log.Tracef("created file %s", fn) + files = append(files, fi.Name()) } - return nil + + zr.Close() + return files, err2 } // Unpack all files from .tar.gz file to the specified directory @@ -314,7 +334,7 @@ func doUpdate(u *updateInfo) error { _ = os.Mkdir(u.updateDir, 0755) _, file := filepath.Split(u.pkgName) if strings.HasSuffix(file, ".zip") { - err = zipFileUnpack(u.pkgName, u.updateDir) + _, err = zipFileUnpack(u.pkgName, u.updateDir) if err != nil { return fmt.Errorf("zipFileUnpack() failed: %s", err) } diff --git a/control_update_test.go b/control_update_test.go index dbf9d34f..8c86852c 100644 --- a/control_update_test.go +++ b/control_update_test.go @@ -43,9 +43,10 @@ func testZipFileUnpack(t *testing.T) { fn := "./dist/AdGuardHome_v0.95_Windows_amd64.zip" outdir := "./test-unpack" _ = os.Mkdir(outdir, 0755) - e := zipFileUnpack(fn, outdir) + files, e := zipFileUnpack(fn, outdir) if e != nil { t.Fatalf("FAILED: %s", e) } + t.Logf("%v", files) os.RemoveAll(outdir) }