+ app: unix, windows: require root user on first launch

This commit is contained in:
Simon Zolin 2019-04-01 12:22:54 +03:00
parent 1611057852
commit e4b53db558
3 changed files with 76 additions and 9 deletions

56
app.go
View File

@ -1,16 +1,20 @@
package main package main
import ( import (
"bufio"
"crypto/tls" "crypto/tls"
"fmt" "fmt"
"io"
"io/ioutil" "io/ioutil"
"net" "net"
"net/http" "net/http"
"os" "os"
"os/exec"
"os/signal" "os/signal"
"path/filepath" "path/filepath"
"runtime" "runtime"
"strconv" "strconv"
"strings"
"sync" "sync"
"syscall" "syscall"
@ -45,15 +49,6 @@ func main() {
return return
} }
signalChannel := make(chan os.Signal)
signal.Notify(signalChannel, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGQUIT)
go func() {
<-signalChannel
cleanup()
cleanupAlways()
os.Exit(0)
}()
// run the protection // run the protection
run(args) run(args)
} }
@ -83,6 +78,18 @@ func run(args options) {
} }
config.firstRun = detectFirstRun() config.firstRun = detectFirstRun()
if config.firstRun {
requireAdminRights()
}
signalChannel := make(chan os.Signal)
signal.Notify(signalChannel, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGQUIT)
go func() {
<-signalChannel
cleanup()
cleanupAlways()
os.Exit(0)
}()
// Do the upgrade if necessary // Do the upgrade if necessary
err := upgradeConfig() err := upgradeConfig()
@ -228,6 +235,37 @@ func run(args options) {
} }
} }
// Check if the current user has root (administrator) rights
// and if not, ask and try to run as root
func requireAdminRights() {
admin, _ := haveAdminRights()
if admin {
return
}
if runtime.GOOS == "windows" {
log.Fatal("This is the first launch of AdGuard Home. You must run it as Administrator.")
} else {
log.Error("This is the first launch of AdGuard Home. You must run it as root.")
_, _ = io.WriteString(os.Stdout, "Do you want to start AdGuard Home as root user? [y/n] ")
stdin := bufio.NewReader(os.Stdin)
buf, _ := stdin.ReadString('\n')
buf = strings.TrimSpace(buf)
if buf != "y" {
os.Exit(1)
}
cmd := exec.Command("sudo", os.Args...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
_ = cmd.Run()
os.Exit(1)
}
}
// Write PID to a file // Write PID to a file
func writePIDFile(fn string) bool { func writePIDFile(fn string) bool {
data := fmt.Sprintf("%d", os.Getpid()) data := fmt.Sprintf("%d", os.Getpid())

View File

@ -3,6 +3,7 @@
package main package main
import ( import (
"os"
"syscall" "syscall"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
@ -19,3 +20,8 @@ func setRlimit(val uint) {
log.Error("Setrlimit() failed: %v", err) log.Error("Setrlimit() failed: %v", err)
} }
} }
// Check if the current user has root (administrator) rights
func haveAdminRights() (bool, error) {
return os.Getuid() == 0, nil
}

View File

@ -1,5 +1,28 @@
package main package main
import "golang.org/x/sys/windows"
// Set user-specified limit of how many fd's we can use // Set user-specified limit of how many fd's we can use
func setRlimit(val uint) { func setRlimit(val uint) {
} }
func haveAdminRights() (bool, error) {
var token windows.Token
h, _ := windows.GetCurrentProcess()
err := windows.OpenProcessToken(h, windows.TOKEN_QUERY, &token)
if err != nil {
return false, err
}
info := make([]byte, 4)
var returnedLen uint32
err = windows.GetTokenInformation(token, windows.TokenElevation, &info[0], uint32(len(info)), &returnedLen)
token.Close()
if err != nil {
return false, err
}
if info[0] == 0 {
return false, nil
}
return true, nil
}