Pull request: home: fix migration, imp code

Updates #1401.
Updates #2646.

Squashed commit of the following:

commit 93b025a2184a72283e22748fecfc478fa549c922
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Wed Mar 24 16:41:07 2021 +0300

    home: fix migration, imp code
This commit is contained in:
Ainar Garipov 2021-03-24 17:17:44 +03:00
parent e10a3fa4b3
commit ba3fc242ab
2 changed files with 70 additions and 95 deletions

View File

@ -26,19 +26,19 @@ type (
// Performs necessary upgrade operations if needed // Performs necessary upgrade operations if needed
func upgradeConfig() error { func upgradeConfig() error {
// read a config file into an interface map, so we can manipulate values without losing any // read a config file into an interface map, so we can manipulate values without losing any
diskConfig := yobj{} diskConf := yobj{}
body, err := readConfigFile() body, err := readConfigFile()
if err != nil { if err != nil {
return err return err
} }
err = yaml.Unmarshal(body, &diskConfig) err = yaml.Unmarshal(body, &diskConf)
if err != nil { if err != nil {
log.Printf("Couldn't parse config file: %s", err) log.Printf("Couldn't parse config file: %s", err)
return err return err
} }
schemaVersionInterface, ok := diskConfig["schema_version"] schemaVersionInterface, ok := diskConf["schema_version"]
log.Tracef("got schema version %v", schemaVersionInterface) log.Tracef("got schema version %v", schemaVersionInterface)
if !ok { if !ok {
// no schema version, set it to 0 // no schema version, set it to 0
@ -57,76 +57,51 @@ func upgradeConfig() error {
return nil return nil
} }
return upgradeConfigSchema(schemaVersion, &diskConfig) return upgradeConfigSchema(schemaVersion, diskConf)
} }
// upgradeFunc is a function that upgrades a config and returns an error.
type upgradeFunc = func(diskConf yobj) (err error)
// Upgrade from oldVersion to newVersion // Upgrade from oldVersion to newVersion
func upgradeConfigSchema(oldVersion int, diskConfig *yobj) error { func upgradeConfigSchema(oldVersion int, diskConf yobj) (err error) {
switch oldVersion { upgrades := []upgradeFunc{
case 0: upgradeSchema0to1,
err := upgradeSchema0to1(diskConfig) upgradeSchema1to2,
if err != nil { upgradeSchema2to3,
return err upgradeSchema3to4,
} upgradeSchema4to5,
fallthrough upgradeSchema5to6,
case 1: upgradeSchema6to7,
err := upgradeSchema1to2(diskConfig) upgradeSchema7to8,
if err != nil {
return err
}
fallthrough
case 2:
err := upgradeSchema2to3(diskConfig)
if err != nil {
return err
}
fallthrough
case 3:
err := upgradeSchema3to4(diskConfig)
if err != nil {
return err
}
fallthrough
case 4:
err := upgradeSchema4to5(diskConfig)
if err != nil {
return err
}
fallthrough
case 5:
err := upgradeSchema5to6(diskConfig)
if err != nil {
return err
}
fallthrough
case 6:
err := upgradeSchema6to7(diskConfig)
if err != nil {
return err
}
case 7:
err := upgradeSchema7to8(diskConfig)
if err != nil {
return err
}
default:
err := fmt.Errorf("configuration file contains unknown schema_version, abort")
log.Println(err)
return err
} }
configFile := config.getConfigFilename() n := 0
body, err := yaml.Marshal(diskConfig) for i, u := range upgrades {
if i >= oldVersion {
err = u(diskConf)
if err != nil {
return err
}
n++
}
}
if n == 0 {
return fmt.Errorf("unknown configuration schema version %d", oldVersion)
}
body, err := yaml.Marshal(diskConf)
if err != nil { if err != nil {
log.Printf("Couldn't generate YAML file: %s", err) return fmt.Errorf("generating new config: %w", err)
return err
} }
config.fileData = body config.fileData = body
err = file.SafeWrite(configFile, body) confFile := config.getConfigFilename()
err = file.SafeWrite(confFile, body)
if err != nil { if err != nil {
log.Printf("Couldn't save YAML config: %s", err) return fmt.Errorf("saving new config: %w", err)
return err
} }
return nil return nil
@ -134,7 +109,7 @@ func upgradeConfigSchema(oldVersion int, diskConfig *yobj) error {
// The first schema upgrade: // The first schema upgrade:
// No more "dnsfilter.txt", filters are now kept in data/filters/ // No more "dnsfilter.txt", filters are now kept in data/filters/
func upgradeSchema0to1(diskConfig *yobj) error { func upgradeSchema0to1(diskConf yobj) error {
log.Printf("%s(): called", funcName()) log.Printf("%s(): called", funcName())
dnsFilterPath := filepath.Join(Context.workDir, "dnsfilter.txt") dnsFilterPath := filepath.Join(Context.workDir, "dnsfilter.txt")
@ -147,7 +122,7 @@ func upgradeSchema0to1(diskConfig *yobj) error {
} }
} }
(*diskConfig)["schema_version"] = 1 diskConf["schema_version"] = 1
return nil return nil
} }
@ -155,7 +130,7 @@ func upgradeSchema0to1(diskConfig *yobj) error {
// Second schema upgrade: // Second schema upgrade:
// coredns is now dns in config // coredns is now dns in config
// delete 'Corefile', since we don't use that anymore // delete 'Corefile', since we don't use that anymore
func upgradeSchema1to2(diskConfig *yobj) error { func upgradeSchema1to2(diskConf yobj) error {
log.Printf("%s(): called", funcName()) log.Printf("%s(): called", funcName())
coreFilePath := filepath.Join(Context.workDir, "Corefile") coreFilePath := filepath.Join(Context.workDir, "Corefile")
@ -168,22 +143,22 @@ func upgradeSchema1to2(diskConfig *yobj) error {
} }
} }
if _, ok := (*diskConfig)["dns"]; !ok { if _, ok := diskConf["dns"]; !ok {
(*diskConfig)["dns"] = (*diskConfig)["coredns"] diskConf["dns"] = diskConf["coredns"]
delete((*diskConfig), "coredns") delete(diskConf, "coredns")
} }
(*diskConfig)["schema_version"] = 2 diskConf["schema_version"] = 2
return nil return nil
} }
// Third schema upgrade: // Third schema upgrade:
// Bootstrap DNS becomes an array // Bootstrap DNS becomes an array
func upgradeSchema2to3(diskConfig *yobj) error { func upgradeSchema2to3(diskConf yobj) error {
log.Printf("%s(): called", funcName()) log.Printf("%s(): called", funcName())
// Let's read dns configuration from diskConfig // Let's read dns configuration from diskConf
dnsConfig, ok := (*diskConfig)["dns"] dnsConfig, ok := diskConf["dns"]
if !ok { if !ok {
return fmt.Errorf("no DNS configuration in config file") return fmt.Errorf("no DNS configuration in config file")
} }
@ -208,21 +183,21 @@ func upgradeSchema2to3(diskConfig *yobj) error {
newBootstrapConfig := []string{fmt.Sprint(bootstrapDNS)} newBootstrapConfig := []string{fmt.Sprint(bootstrapDNS)}
newDNSConfig["bootstrap_dns"] = newBootstrapConfig newDNSConfig["bootstrap_dns"] = newBootstrapConfig
(*diskConfig)["dns"] = newDNSConfig diskConf["dns"] = newDNSConfig
// Bump schema version // Bump schema version
(*diskConfig)["schema_version"] = 3 diskConf["schema_version"] = 3
return nil return nil
} }
// Add use_global_blocked_services=true setting for existing "clients" array // Add use_global_blocked_services=true setting for existing "clients" array
func upgradeSchema3to4(diskConfig *yobj) error { func upgradeSchema3to4(diskConf yobj) error {
log.Printf("%s(): called", funcName()) log.Printf("%s(): called", funcName())
(*diskConfig)["schema_version"] = 4 diskConf["schema_version"] = 4
clients, ok := (*diskConfig)["clients"] clients, ok := diskConf["clients"]
if !ok { if !ok {
return nil return nil
} }
@ -253,12 +228,12 @@ func upgradeSchema3to4(diskConfig *yobj) error {
// - name: "..." // - name: "..."
// password: "..." // password: "..."
// ... // ...
func upgradeSchema4to5(diskConfig *yobj) error { func upgradeSchema4to5(diskConf yobj) error {
log.Printf("%s(): called", funcName()) log.Printf("%s(): called", funcName())
(*diskConfig)["schema_version"] = 5 diskConf["schema_version"] = 5
name, ok := (*diskConfig)["auth_name"] name, ok := diskConf["auth_name"]
if !ok { if !ok {
return nil return nil
} }
@ -268,7 +243,7 @@ func upgradeSchema4to5(diskConfig *yobj) error {
return nil return nil
} }
pass, ok := (*diskConfig)["auth_pass"] pass, ok := diskConf["auth_pass"]
if !ok { if !ok {
return nil return nil
} }
@ -292,7 +267,7 @@ func upgradeSchema4to5(diskConfig *yobj) error {
PasswordHash: string(hash), PasswordHash: string(hash),
} }
users := []User{u} users := []User{u}
(*diskConfig)["users"] = users diskConf["users"] = users
return nil return nil
} }
@ -308,12 +283,12 @@ func upgradeSchema4to5(diskConfig *yobj) error {
// ids: // ids:
// - 127.0.0.1 // - 127.0.0.1
// - ... // - ...
func upgradeSchema5to6(diskConfig *yobj) error { func upgradeSchema5to6(diskConf yobj) error {
log.Printf("%s(): called", funcName()) log.Printf("%s(): called", funcName())
(*diskConfig)["schema_version"] = 6 diskConf["schema_version"] = 6
clients, ok := (*diskConfig)["clients"] clients, ok := diskConf["clients"]
if !ok { if !ok {
return nil return nil
} }
@ -378,12 +353,12 @@ func upgradeSchema5to6(diskConfig *yobj) error {
// dhcpv4: // dhcpv4:
// gateway_ip: 192.168.56.1 // gateway_ip: 192.168.56.1
// ... // ...
func upgradeSchema6to7(diskConfig *yobj) error { func upgradeSchema6to7(diskConf yobj) error {
log.Printf("Upgrade yaml: 6 to 7") log.Printf("Upgrade yaml: 6 to 7")
(*diskConfig)["schema_version"] = 7 diskConf["schema_version"] = 7
dhcpVal, ok := (*diskConfig)["dhcp"] dhcpVal, ok := diskConf["dhcp"]
if !ok { if !ok {
return nil return nil
} }
@ -462,12 +437,12 @@ func upgradeSchema6to7(diskConfig *yobj) error {
// 'bind_hosts': // 'bind_hosts':
// - '127.0.0.1' // - '127.0.0.1'
// //
func upgradeSchema7to8(diskConfig *yobj) (err error) { func upgradeSchema7to8(diskConf yobj) (err error) {
log.Printf("Upgrade yaml: 7 to 8") log.Printf("Upgrade yaml: 7 to 8")
(*diskConfig)["schema_version"] = 8 diskConf["schema_version"] = 8
dnsVal, ok := (*diskConfig)["dns"] dnsVal, ok := diskConf["dns"]
if !ok { if !ok {
return nil return nil
} }

View File

@ -12,7 +12,7 @@ import (
func TestUpgradeSchema1to2(t *testing.T) { func TestUpgradeSchema1to2(t *testing.T) {
diskConf := testDiskConf(1) diskConf := testDiskConf(1)
err := upgradeSchema1to2(&diskConf) err := upgradeSchema1to2(diskConf)
require.Nil(t, err) require.Nil(t, err)
require.Equal(t, diskConf["schema_version"], 2) require.Equal(t, diskConf["schema_version"], 2)
@ -35,7 +35,7 @@ func TestUpgradeSchema1to2(t *testing.T) {
func TestUpgradeSchema2to3(t *testing.T) { func TestUpgradeSchema2to3(t *testing.T) {
diskConf := testDiskConf(2) diskConf := testDiskConf(2)
err := upgradeSchema2to3(&diskConf) err := upgradeSchema2to3(diskConf)
require.Nil(t, err) require.Nil(t, err)
require.Equal(t, diskConf["schema_version"], 3) require.Equal(t, diskConf["schema_version"], 3)
@ -73,7 +73,7 @@ func TestUpgradeSchema7to8(t *testing.T) {
"schema_version": 7, "schema_version": 7,
} }
err := upgradeSchema7to8(&oldConf) err := upgradeSchema7to8(oldConf)
require.Nil(t, err) require.Nil(t, err)
require.Equal(t, oldConf["schema_version"], 8) require.Equal(t, oldConf["schema_version"], 8)