diff --git a/home/config.go b/home/config.go index 05831532..42e08cc2 100644 --- a/home/config.go +++ b/home/config.go @@ -5,7 +5,6 @@ import ( "os" "path/filepath" "sync" - "time" "github.com/AdguardTeam/AdGuardHome/dhcpd" "github.com/AdguardTeam/AdGuardHome/dnsfilter" @@ -40,10 +39,6 @@ type configuration struct { // It's reset after config is parsed fileData []byte - // cached version.json to avoid hammering github.io for each page reload - versionCheckJSON []byte - versionCheckLastTime time.Time - BindHost string `yaml:"bind_host"` // BindHost is the IP address of the HTTP server to bind to BindPort int `yaml:"bind_port"` // BindPort is the port the HTTP server Users []User `yaml:"users"` // Users that can access HTTP server diff --git a/home/control_update.go b/home/control_update.go index e711884d..6e6b1601 100644 --- a/home/control_update.go +++ b/home/control_update.go @@ -1,13 +1,7 @@ package home import ( - "archive/tar" - "archive/zip" - "compress/gzip" "encoding/json" - "fmt" - "io" - "io/ioutil" "net/http" "os" "os/exec" @@ -15,25 +9,12 @@ import ( "runtime" "strings" "syscall" - "time" + "github.com/AdguardTeam/AdGuardHome/update" "github.com/AdguardTeam/AdGuardHome/util" "github.com/AdguardTeam/golibs/log" ) -type updateInfo struct { - pkgURL string // URL for the new package - pkgName string // Full path to package file - newVer string // New version string - updateDir string // Full path to the directory containing unpacked files from the new package - backupDir string // Full path to backup directory - configName string // Full path to the current configuration file - updateConfigName string // Full path to the configuration file to check by the new binary - curBinName string // Full path to the current executable file - bkpBinName string // Full path to the current executable file in backup directory - newBinName string // Full path to the new executable file -} - type getVersionJSONRequest struct { RecheckNow bool `json:"recheck_now"` } @@ -58,28 +39,11 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) { } } - now := time.Now() - if !req.RecheckNow { - Context.controlLock.Lock() - cached := now.Sub(config.versionCheckLastTime) <= versionCheckPeriod && len(config.versionCheckJSON) != 0 - data := config.versionCheckJSON - Context.controlLock.Unlock() - - if cached { - log.Tracef("Returning cached data") - w.Header().Set("Content-Type", "application/json") - _, _ = w.Write(getVersionResp(data)) - return - } - } - - var resp *http.Response + var info update.VersionInfo for i := 0; i != 3; i++ { - log.Tracef("Downloading data from %s", versionCheckURL) - resp, err = Context.client.Get(versionCheckURL) - if resp != nil && resp.Body != nil { - defer resp.Body.Close() - } + Context.controlLock.Lock() + info, err = Context.updater.GetVersionResponse(req.RecheckNow) + Context.controlLock.Unlock() if err != nil && strings.HasSuffix(err.Error(), "i/o timeout") { // This case may happen while we're restarting DNS server // https://github.com/AdguardTeam/AdGuardHome/issues/934 @@ -92,20 +56,8 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) { return } - // read the body entirely - body, err := ioutil.ReadAll(resp.Body) - if err != nil { - httpError(w, http.StatusBadGateway, "Couldn't read response body from %s: %s", versionCheckURL, err) - return - } - - Context.controlLock.Lock() - config.versionCheckLastTime = now - config.versionCheckJSON = body - Context.controlLock.Unlock() - w.Header().Set("Content-Type", "application/json") - _, err = w.Write(getVersionResp(body)) + _, err = w.Write(getVersionResp(info)) if err != nil { httpError(w, http.StatusInternalServerError, "Couldn't write body: %s", err) } @@ -113,18 +65,12 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) { // Perform an update procedure to the latest available version func handleUpdate(w http.ResponseWriter, r *http.Request) { - if len(config.versionCheckJSON) == 0 { + if len(Context.updater.NewVersion) == 0 { httpError(w, http.StatusBadRequest, "/update request isn't allowed now") return } - u, err := getUpdateInfo(config.versionCheckJSON) - if err != nil { - httpError(w, http.StatusInternalServerError, "%s", err) - return - } - - err = doUpdate(u) + err := Context.updater.DoUpdate() if err != nil { httpError(w, http.StatusInternalServerError, "%s", err) return @@ -135,33 +81,18 @@ func handleUpdate(w http.ResponseWriter, r *http.Request) { f.Flush() } - go finishUpdate(u) + go finishUpdate() } // Convert version.json data to our JSON response -func getVersionResp(data []byte) []byte { - versionJSON := make(map[string]interface{}) - err := json.Unmarshal(data, &versionJSON) - if err != nil { - log.Error("version.json: %s", err) - return []byte{} - } - +func getVersionResp(info update.VersionInfo) []byte { ret := make(map[string]interface{}) ret["can_autoupdate"] = false + ret["new_version"] = info.NewVersion + ret["announcement"] = info.Announcement + ret["announcement_url"] = info.AnnouncementURL - var ok1, ok2, ok3 bool - ret["new_version"], ok1 = versionJSON["version"].(string) - ret["announcement"], ok2 = versionJSON["announcement"].(string) - ret["announcement_url"], ok3 = versionJSON["announcement_url"].(string) - selfUpdateMinVersion, ok4 := versionJSON["selfupdate_min_version"].(string) - if !ok1 || !ok2 || !ok3 || !ok4 { - log.Error("version.json: invalid data") - return []byte{} - } - - _, ok := getDownloadURL(versionJSON) - if ok && ret["new_version"] != versionString && versionString >= selfUpdateMinVersion { + if info.CanAutoUpdate { canUpdate := true tlsConf := tlsConfigSettings{} @@ -185,375 +116,18 @@ func getVersionResp(data []byte) []byte { return d } -// Copy file on disk -func copyFile(src, dst string) error { - d, e := ioutil.ReadFile(src) - if e != nil { - return e - } - e = ioutil.WriteFile(dst, d, 0644) - if e != nil { - return e - } - return nil -} - -// Fill in updateInfo object -func getUpdateInfo(jsonData []byte) (*updateInfo, error) { - var u updateInfo - - workDir := Context.workDir - - versionJSON := make(map[string]interface{}) - err := json.Unmarshal(jsonData, &versionJSON) - if err != nil { - return nil, fmt.Errorf("JSON parse: %s", err) - } - - pkgURL, ok := getDownloadURL(versionJSON) - if !ok { - return nil, fmt.Errorf("failed to get download URL") - } - - u.pkgURL = pkgURL - u.newVer = versionJSON["version"].(string) - if len(u.pkgURL) == 0 || len(u.newVer) == 0 { - return nil, fmt.Errorf("invalid JSON") - } - - if u.newVer == versionString { - return nil, fmt.Errorf("no need to update") - } - - u.updateDir = filepath.Join(workDir, fmt.Sprintf("agh-update-%s", u.newVer)) - u.backupDir = filepath.Join(workDir, "agh-backup") - - _, pkgFileName := filepath.Split(u.pkgURL) - if len(pkgFileName) == 0 { - return nil, fmt.Errorf("invalid JSON") - } - u.pkgName = filepath.Join(u.updateDir, pkgFileName) - - u.configName = config.getConfigFilename() - u.updateConfigName = filepath.Join(u.updateDir, "AdGuardHome", "AdGuardHome.yaml") - if strings.HasSuffix(pkgFileName, ".zip") { - u.updateConfigName = filepath.Join(u.updateDir, "AdGuardHome.yaml") - } - - binName := "AdGuardHome" - if runtime.GOOS == "windows" { - binName = "AdGuardHome.exe" - } - - // TODO: This is a mistake, work dir can be different - u.curBinName = filepath.Join(workDir, binName) - if !util.FileExists(u.curBinName) { - return nil, fmt.Errorf("executable file %s doesn't exist", u.curBinName) - } - u.bkpBinName = filepath.Join(u.backupDir, binName) - u.newBinName = filepath.Join(u.updateDir, "AdGuardHome", binName) - if strings.HasSuffix(pkgFileName, ".zip") { - u.newBinName = filepath.Join(u.updateDir, binName) - } - - return &u, nil -} - -// getDownloadURL - gets download URL for the current GOOS/GOARCH -// returns -func getDownloadURL(json map[string]interface{}) (string, bool) { - var key string - - if runtime.GOARCH == "arm" && ARMVersion != "" { - // the key is: - // download_linux_armv5 for ARMv5 - // download_linux_armv6 for ARMv6 - // download_linux_armv7 for ARMv7 - key = fmt.Sprintf("download_%s_%sv%s", runtime.GOOS, runtime.GOARCH, ARMVersion) - } - - u, ok := json[key] - if !ok { - // the key is download_linux_arm or download_linux_arm64 for regular ARM versions - key = fmt.Sprintf("download_%s_%s", runtime.GOOS, runtime.GOARCH) - u, ok = json[key] - } - - if !ok { - return "", false - } - - return u.(string), true -} - -// Unpack all files from .zip file to the specified directory -// Existing files are overwritten -// Return the list of files (not directories) written -func zipFileUnpack(zipfile, outdir string) ([]string, error) { - - r, err := zip.OpenReader(zipfile) - if err != nil { - return nil, fmt.Errorf("zip.OpenReader(): %s", err) - } - defer r.Close() - - var files []string - var err2 error - var zr io.ReadCloser - for _, zf := range r.File { - zr, err = zf.Open() - if err != nil { - err2 = fmt.Errorf("zip file Open(): %s", err) - break - } - - fi := zf.FileInfo() - if len(fi.Name()) == 0 { - continue - } - - fn := filepath.Join(outdir, fi.Name()) - - if fi.IsDir() { - err = os.Mkdir(fn, fi.Mode()) - if err != nil && !os.IsExist(err) { - err2 = fmt.Errorf("os.Mkdir(): %s", err) - break - } - log.Tracef("created directory %s", fn) - continue - } - - f, err := os.OpenFile(fn, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, fi.Mode()) - if err != nil { - err2 = fmt.Errorf("os.OpenFile(): %s", err) - break - } - _, err = io.Copy(f, zr) - if err != nil { - f.Close() - err2 = fmt.Errorf("io.Copy(): %s", err) - break - } - f.Close() - - log.Tracef("created file %s", fn) - files = append(files, fi.Name()) - } - - zr.Close() - return files, err2 -} - -// Unpack all files from .tar.gz file to the specified directory -// Existing files are overwritten -// Return the list of files (not directories) written -func targzFileUnpack(tarfile, outdir string) ([]string, error) { - - f, err := os.Open(tarfile) - if err != nil { - return nil, fmt.Errorf("os.Open(): %s", err) - } - defer f.Close() - - gzReader, err := gzip.NewReader(f) - if err != nil { - return nil, fmt.Errorf("gzip.NewReader(): %s", err) - } - - var files []string - var err2 error - tarReader := tar.NewReader(gzReader) - for { - header, err := tarReader.Next() - if err == io.EOF { - err2 = nil - break - } - if err != nil { - err2 = fmt.Errorf("tarReader.Next(): %s", err) - break - } - if len(header.Name) == 0 { - continue - } - - fn := filepath.Join(outdir, header.Name) - - if header.Typeflag == tar.TypeDir { - err = os.Mkdir(fn, os.FileMode(header.Mode&0777)) - if err != nil && !os.IsExist(err) { - err2 = fmt.Errorf("os.Mkdir(%s): %s", fn, err) - break - } - log.Tracef("created directory %s", fn) - continue - } else if header.Typeflag != tar.TypeReg { - log.Tracef("%s: unknown file type %d, skipping", header.Name, header.Typeflag) - continue - } - - f, err := os.OpenFile(fn, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, os.FileMode(header.Mode&0777)) - if err != nil { - err2 = fmt.Errorf("os.OpenFile(%s): %s", fn, err) - break - } - _, err = io.Copy(f, tarReader) - if err != nil { - f.Close() - err2 = fmt.Errorf("io.Copy(): %s", err) - break - } - f.Close() - - log.Tracef("created file %s", fn) - files = append(files, header.Name) - } - - gzReader.Close() - return files, err2 -} - -func copySupportingFiles(files []string, srcdir, dstdir string, useSrcNameOnly, useDstNameOnly bool) error { - for _, f := range files { - _, name := filepath.Split(f) - if name == "AdGuardHome" || name == "AdGuardHome.exe" || name == "AdGuardHome.yaml" { - continue - } - - src := filepath.Join(srcdir, f) - if useSrcNameOnly { - src = filepath.Join(srcdir, name) - } - - dst := filepath.Join(dstdir, f) - if useDstNameOnly { - dst = filepath.Join(dstdir, name) - } - - err := copyFile(src, dst) - if err != nil && !os.IsNotExist(err) { - return err - } - - log.Tracef("Copied: %s -> %s", src, dst) - } - return nil -} - -// Download package file and save it to disk -func getPackageFile(u *updateInfo) error { - resp, err := Context.client.Get(u.pkgURL) - if err != nil { - return fmt.Errorf("HTTP request failed: %s", err) - } - if resp != nil && resp.Body != nil { - defer resp.Body.Close() - } - - log.Tracef("Reading HTTP body") - body, err := ioutil.ReadAll(resp.Body) - if err != nil { - return fmt.Errorf("ioutil.ReadAll() failed: %s", err) - } - - log.Tracef("Saving package to file") - err = ioutil.WriteFile(u.pkgName, body, 0644) - if err != nil { - return fmt.Errorf("ioutil.WriteFile() failed: %s", err) - } - return nil -} - -// Perform an update procedure -func doUpdate(u *updateInfo) error { - log.Info("Updating from %s to %s. URL:%s Package:%s", - versionString, u.newVer, u.pkgURL, u.pkgName) - - _ = os.Mkdir(u.updateDir, 0755) - - var err error - err = getPackageFile(u) - if err != nil { - return err - } - - log.Tracef("Unpacking the package") - _, file := filepath.Split(u.pkgName) - var files []string - if strings.HasSuffix(file, ".zip") { - files, err = zipFileUnpack(u.pkgName, u.updateDir) - if err != nil { - return fmt.Errorf("zipFileUnpack() failed: %s", err) - } - } else if strings.HasSuffix(file, ".tar.gz") { - files, err = targzFileUnpack(u.pkgName, u.updateDir) - if err != nil { - return fmt.Errorf("targzFileUnpack() failed: %s", err) - } - } else { - return fmt.Errorf("unknown package extension") - } - - log.Tracef("Checking configuration") - err = copyFile(u.configName, u.updateConfigName) - if err != nil { - return fmt.Errorf("copyFile() failed: %s", err) - } - cmd := exec.Command(u.newBinName, "--check-config") - err = cmd.Run() - if err != nil || cmd.ProcessState.ExitCode() != 0 { - return fmt.Errorf("exec.Command(): %s %d", err, cmd.ProcessState.ExitCode()) - } - - log.Tracef("Backing up the current configuration") - _ = os.Mkdir(u.backupDir, 0755) - err = copyFile(u.configName, filepath.Join(u.backupDir, "AdGuardHome.yaml")) - if err != nil { - return fmt.Errorf("copyFile() failed: %s", err) - } - - // ./README.md -> backup/README.md - err = copySupportingFiles(files, Context.workDir, u.backupDir, true, true) - if err != nil { - return fmt.Errorf("copySupportingFiles(%s, %s) failed: %s", - Context.workDir, u.backupDir, err) - } - - // update/[AdGuardHome/]README.md -> ./README.md - err = copySupportingFiles(files, u.updateDir, Context.workDir, false, true) - if err != nil { - return fmt.Errorf("copySupportingFiles(%s, %s) failed: %s", - u.updateDir, Context.workDir, err) - } - - log.Tracef("Renaming: %s -> %s", u.curBinName, u.bkpBinName) - err = os.Rename(u.curBinName, u.bkpBinName) - if err != nil { - return err - } - if runtime.GOOS == "windows" { - // rename fails with "File in use" error - err = copyFile(u.newBinName, u.curBinName) - } else { - err = os.Rename(u.newBinName, u.curBinName) - } - if err != nil { - return err - } - log.Tracef("Renamed: %s -> %s", u.newBinName, u.curBinName) - - _ = os.Remove(u.pkgName) - _ = os.RemoveAll(u.updateDir) - return nil -} - // Complete an update procedure -func finishUpdate(u *updateInfo) { +func finishUpdate() { log.Info("Stopping all tasks") cleanup() cleanupAlways() + exeName := "AdGuardHome" + if runtime.GOOS == "windows" { + exeName = "AdGuardHome.exe" + } + curBinName := filepath.Join(Context.workDir, exeName) + if runtime.GOOS == "windows" { if Context.runningAsService { // Note: @@ -567,7 +141,7 @@ func finishUpdate(u *updateInfo) { os.Exit(0) } - cmd := exec.Command(u.curBinName, os.Args[1:]...) + cmd := exec.Command(curBinName, os.Args[1:]...) log.Info("Restarting: %v", cmd.Args) cmd.Stdin = os.Stdin cmd.Stdout = os.Stdout @@ -579,7 +153,7 @@ func finishUpdate(u *updateInfo) { os.Exit(0) } else { log.Info("Restarting: %v", os.Args) - err := syscall.Exec(u.curBinName, os.Args, os.Environ()) + err := syscall.Exec(curBinName, os.Args, os.Environ()) if err != nil { log.Fatalf("syscall.Exec() failed: %s", err) } diff --git a/home/home.go b/home/home.go index 79cf0843..e2ea4389 100644 --- a/home/home.go +++ b/home/home.go @@ -20,6 +20,7 @@ import ( "gopkg.in/natefinch/lumberjack.v2" + "github.com/AdguardTeam/AdGuardHome/update" "github.com/AdguardTeam/AdGuardHome/util" "github.com/joomcode/errorx" @@ -47,8 +48,6 @@ var ( ARMVersion = "" ) -const versionCheckPeriod = time.Hour * 8 - // Global context type homeContext struct { // Modules @@ -67,6 +66,7 @@ type homeContext struct { web *Web // Web (HTTP, HTTPS) module tls *TLSMod // TLS module autoHosts util.AutoHosts // IP-hostname pairs taken from system configuration (e.g. /etc/hosts) files + updater *update.Updater // Runtime properties // -- @@ -225,6 +225,16 @@ func run(args options) { os.Exit(1) } Context.autoHosts.Init("") + + Context.updater = update.NewUpdater(Context.workDir) + Context.updater.Client = Context.client + Context.updater.VersionURL = versionCheckURL + Context.updater.VersionString = versionString + Context.updater.OS = runtime.GOOS + Context.updater.Arch = runtime.GOARCH + Context.updater.ARMVersion = ARMVersion + Context.updater.ConfigName = config.getConfigFilename() + Context.clients.Init(config.Clients, Context.dhcpServer, &Context.autoHosts) config.Clients = nil diff --git a/update/check.go b/update/check.go index 75b551bb..c5492431 100644 --- a/update/check.go +++ b/update/check.go @@ -71,6 +71,9 @@ func (u *Updater) parseVersionResponse(data []byte) (VersionInfo, error) { info.CanAutoUpdate = true } + u.NewVersion = info.NewVersion + u.PackageURL = info.PackageURL + return info, nil } diff --git a/update/updater.go b/update/updater.go index cb094f79..413bbd0c 100644 --- a/update/updater.go +++ b/update/updater.go @@ -14,6 +14,7 @@ import ( "strings" "time" + "github.com/AdguardTeam/AdGuardHome/util" "github.com/AdguardTeam/golibs/log" ) @@ -110,7 +111,17 @@ func (u *Updater) prepare() error { u.backupExeName = filepath.Join(u.backupDir, exeName) u.updateExeName = filepath.Join(u.updateDir, exeName) - u.currentExeName = os.Args[0] + + log.Info("Updating from %s to %s. URL:%s", + u.VersionString, u.NewVersion, u.PackageURL) + + // If the binary file isn't found in working directory, we won't be able to auto-update + // Getting the full path to the current binary file on UNIX and checking write permissions + // is more difficult. + u.currentExeName = filepath.Join(u.workDir, exeName) + if !util.FileExists(u.currentExeName) { + return fmt.Errorf("executable file %s doesn't exist", u.currentExeName) + } return nil }