Merge pull request #154 in DNS/adguard-dns from fix/596 to master

* commit 'b74eded414cf3f31da6feb185619afab85233b5a': (34 commits)
  [change] control: fix issues from review
  [change] config: fix default upstreams list
  * client: typo
  * client: remove log
  * client: fix grammar
  [change] control: add upstreams validation before dns config test
  [change] control: add upstreams validation
  [change] control: update bootstrap DNS check
  * client: remove empty elements from upstream and bootstrap
  * client: locales and pass object to testUpstream
  [change] config, control, openapi: fix issues from reviw
  [fix] control: fix json decode for upstream config
  * client: upstream form
  [change] control, openapi: Handle upstreams test with JSON
  [change] upgrade_test: rework tests
  [change] upgrade_test: add test for upgrade
  [change] control: Remove unuseful check
  [change] control: Fix issues from review
  [change] dnsforward: Add comments for public fields
  [change] control: Handle upstream config with JSON
  ...
This commit is contained in:
Alexey Dmitrievskiy 2019-03-12 14:30:31 +03:00
commit f857ed74ec
20 changed files with 776 additions and 346 deletions

View File

@ -79,7 +79,7 @@
"no_settings": "No settings",
"general_settings": "General settings",
"upstream_dns": "Upstream DNS servers",
"upstream_dns_hint": "If you keep this field empty, AdGuard Home will use <a href='https:\/\/1.1.1.1\/' target='_blank'>Cloudflare DNS<\/a> as an upstream. Use tls:\/\/ prefix for DNS over TLS servers.",
"upstream_dns_hint": "If you keep this field empty, AdGuard Home will use <a href='https:\/\/1.1.1.1\/' target='_blank'>Cloudflare DNS<\/a> as an upstream.",
"test_upstream_btn": "Test upstreams",
"apply_btn": "Apply",
"disabled_filtering_toast": "Disabled filtering",
@ -246,5 +246,8 @@
"form_error_equal": "Shouldn't be equal",
"form_error_password": "Password mismatched",
"reset_settings": "Reset settings",
"update_announcement": "AdGuard Home {{version}} is now available! <0>Click here<\/0> for more info."
"update_announcement": "AdGuard Home {{version}} is now available! <0>Click here</0> for more info.",
"upstream_parallel": "Use parallel queries to speed up resolving by simultaneously querying all upstream servers",
"bootstrap_dns": "Bootstrap DNS servers",
"bootstrap_dns_desc": "Bootstrap DNS servers are used to resolve IP addresses of the DOH/DOT resolvers you specify as upstreams."
}

View File

@ -3,7 +3,7 @@ import round from 'lodash/round';
import { t } from 'i18next';
import { showLoading, hideLoading } from 'react-redux-loading-bar';
import { normalizeHistory, normalizeFilteringStatus, normalizeLogs } from '../helpers/helpers';
import { normalizeHistory, normalizeFilteringStatus, normalizeLogs, normalizeTextarea } from '../helpers/helpers';
import { SETTINGS_NAMES } from '../helpers/constants';
import Api from '../api/Api';
@ -452,10 +452,18 @@ export const setUpstreamRequest = createAction('SET_UPSTREAM_REQUEST');
export const setUpstreamFailure = createAction('SET_UPSTREAM_FAILURE');
export const setUpstreamSuccess = createAction('SET_UPSTREAM_SUCCESS');
export const setUpstream = url => async (dispatch) => {
export const setUpstream = config => async (dispatch) => {
dispatch(setUpstreamRequest());
try {
await apiClient.setUpstream(url);
const values = { ...config };
values.bootstrap_dns = (
values.bootstrap_dns && normalizeTextarea(values.bootstrap_dns)
) || [];
values.upstream_dns = (
values.upstream_dns && normalizeTextarea(values.upstream_dns)
) || [];
await apiClient.setUpstream(values);
dispatch(addSuccessToast('updated_upstream_dns_toast'));
dispatch(setUpstreamSuccess());
} catch (error) {
@ -468,11 +476,18 @@ export const testUpstreamRequest = createAction('TEST_UPSTREAM_REQUEST');
export const testUpstreamFailure = createAction('TEST_UPSTREAM_FAILURE');
export const testUpstreamSuccess = createAction('TEST_UPSTREAM_SUCCESS');
export const testUpstream = servers => async (dispatch) => {
export const testUpstream = config => async (dispatch) => {
dispatch(testUpstreamRequest());
try {
const upstreamResponse = await apiClient.testUpstream(servers);
const values = { ...config };
values.bootstrap_dns = (
values.bootstrap_dns && normalizeTextarea(values.bootstrap_dns)
) || [];
values.upstream_dns = (
values.upstream_dns && normalizeTextarea(values.upstream_dns)
) || [];
const upstreamResponse = await apiClient.testUpstream(values);
const testMessages = Object.keys(upstreamResponse).map((key) => {
const message = upstreamResponse[key];
if (message !== 'OK') {

View File

@ -34,7 +34,7 @@ export default class Api {
GLOBAL_QUERY_LOG = { path: 'querylog', method: 'GET' };
GLOBAL_QUERY_LOG_ENABLE = { path: 'querylog_enable', method: 'POST' };
GLOBAL_QUERY_LOG_DISABLE = { path: 'querylog_disable', method: 'POST' };
GLOBAL_SET_UPSTREAM_DNS = { path: 'set_upstream_dns', method: 'POST' };
GLOBAL_SET_UPSTREAM_DNS = { path: 'set_upstreams_config', method: 'POST' };
GLOBAL_TEST_UPSTREAM_DNS = { path: 'test_upstream_dns', method: 'POST' };
GLOBAL_VERSION = { path: 'version.json', method: 'GET' };
GLOBAL_ENABLE_PROTECTION = { path: 'enable_protection', method: 'POST' };
@ -110,7 +110,7 @@ export default class Api {
const { path, method } = this.GLOBAL_SET_UPSTREAM_DNS;
const config = {
data: url,
header: { 'Content-Type': 'text/plain' },
headers: { 'Content-Type': 'application/json' },
};
return this.makeRequest(path, method, config);
}
@ -119,7 +119,7 @@ export default class Api {
const { path, method } = this.GLOBAL_TEST_UPSTREAM_DNS;
const config = {
data: servers,
header: { 'Content-Type': 'text/plain' },
headers: { 'Content-Type': 'application/json' },
};
return this.makeRequest(path, method, config);
}

View File

@ -1,97 +0,0 @@
import React, { Component } from 'react';
import PropTypes from 'prop-types';
import classnames from 'classnames';
import { Trans, withNamespaces } from 'react-i18next';
import Card from '../ui/Card';
class Upstream extends Component {
handleChange = (e) => {
const { value } = e.currentTarget;
this.props.handleUpstreamChange(value);
};
handleSubmit = (e) => {
e.preventDefault();
this.props.handleUpstreamSubmit();
};
handleTest = () => {
this.props.handleUpstreamTest();
}
render() {
const testButtonClass = classnames({
'btn btn-primary btn-standard mr-2': true,
'btn btn-primary btn-standard mr-2 btn-loading': this.props.processingTestUpstream,
});
const { t } = this.props;
return (
<Card
title={ t('upstream_dns') }
subtitle={ t('upstream_dns_hint') }
bodyType="card-body box-body--settings"
>
<div className="row">
<div className="col">
<form>
<textarea
className="form-control form-control--textarea"
value={this.props.upstreamDns}
onChange={this.handleChange}
/>
<div className="card-actions">
<button
className={testButtonClass}
type="button"
onClick={this.handleTest}
>
<Trans>test_upstream_btn</Trans>
</button>
<button
className="btn btn-success btn-standard"
type="submit"
onClick={this.handleSubmit}
>
<Trans>apply_btn</Trans>
</button>
</div>
</form>
<hr/>
<div className="list leading-loose">
<Trans>examples_title</Trans>:
<ol className="leading-loose">
<li>
<code>1.1.1.1</code> - { t('example_upstream_regular') }
</li>
<li>
<code>tls://1dot1dot1dot1.cloudflare-dns.com</code> - <span dangerouslySetInnerHTML={{ __html: t('example_upstream_dot') }} />
</li>
<li>
<code>https://cloudflare-dns.com/dns-query</code> - <span dangerouslySetInnerHTML={{ __html: t('example_upstream_doh') }} />
</li>
<li>
<code>tcp://1.1.1.1</code> - { t('example_upstream_tcp') }
</li>
<li>
<code>sdns://...</code> - <span dangerouslySetInnerHTML={{ __html: t('example_upstream_sdns') }} />
</li>
</ol>
</div>
</div>
</div>
</Card>
);
}
}
Upstream.propTypes = {
upstreamDns: PropTypes.string,
processingTestUpstream: PropTypes.bool,
handleUpstreamChange: PropTypes.func,
handleUpstreamSubmit: PropTypes.func,
handleUpstreamTest: PropTypes.func,
t: PropTypes.func,
};
export default withNamespaces()(Upstream);

View File

@ -0,0 +1,32 @@
import React from 'react';
import PropTypes from 'prop-types';
import { Trans, withNamespaces } from 'react-i18next';
const Examples = props => (
<div className="list leading-loose">
<Trans>examples_title</Trans>:
<ol className="leading-loose">
<li>
<code>1.1.1.1</code> - { props.t('example_upstream_regular') }
</li>
<li>
<code>tls://1dot1dot1dot1.cloudflare-dns.com</code> - <span dangerouslySetInnerHTML={{ __html: props.t('example_upstream_dot') }} />
</li>
<li>
<code>https://cloudflare-dns.com/dns-query</code> - <span dangerouslySetInnerHTML={{ __html: props.t('example_upstream_doh') }} />
</li>
<li>
<code>tcp://1.1.1.1</code> - { props.t('example_upstream_tcp') }
</li>
<li>
<code>sdns://...</code> - <span dangerouslySetInnerHTML={{ __html: props.t('example_upstream_sdns') }} />
</li>
</ol>
</div>
);
Examples.propTypes = {
t: PropTypes.func.isRequired,
};
export default withNamespaces()(Examples);

View File

@ -0,0 +1,139 @@
import React from 'react';
import { connect } from 'react-redux';
import PropTypes from 'prop-types';
import { Field, reduxForm, formValueSelector } from 'redux-form';
import { Trans, withNamespaces } from 'react-i18next';
import flow from 'lodash/flow';
import classnames from 'classnames';
import { renderSelectField } from '../../../helpers/form';
let Form = (props) => {
const {
t,
handleSubmit,
testUpstream,
upstreamDns,
bootstrapDns,
allServers,
submitting,
invalid,
processingSetUpstream,
processingTestUpstream,
} = props;
const testButtonClass = classnames({
'btn btn-primary btn-standard mr-2': true,
'btn btn-primary btn-standard mr-2 btn-loading': processingTestUpstream,
});
return (
<form onSubmit={handleSubmit}>
<div className="row">
<div className="col-12">
<div className="form__group form__group--settings">
<label className="form__label" htmlFor="upstream_dns">
<Trans>upstream_dns</Trans>
</label>
<Field
id="upstream_dns"
name="upstream_dns"
component="textarea"
type="text"
className="form-control form-control--textarea"
placeholder={t('upstream_dns')}
/>
</div>
</div>
<div className="col-12">
<div className="form__group form__group--settings">
<Field
name="all_servers"
type="checkbox"
component={renderSelectField}
placeholder={t('upstream_parallel')}
/>
</div>
</div>
<div className="col-12">
<div className="form__group">
<label className="form__label" htmlFor="bootstrap_dns">
<Trans>bootstrap_dns</Trans>
</label>
<div className="form__desc form__desc--top">
<Trans>bootstrap_dns_desc</Trans>
</div>
<Field
id="bootstrap_dns"
name="bootstrap_dns"
component="textarea"
type="text"
className="form-control"
placeholder={t('bootstrap_dns')}
/>
</div>
</div>
</div>
<div className="card-actions">
<div className="btn-list">
<button
type="button"
className={testButtonClass}
onClick={() => testUpstream({
upstream_dns: upstreamDns,
bootstrap_dns: bootstrapDns,
all_servers: allServers,
})}
disabled={!upstreamDns || processingTestUpstream}
>
<Trans>test_upstream_btn</Trans>
</button>
<button
type="submit"
className="btn btn-success btn-standard"
disabled={
submitting
|| invalid
|| processingSetUpstream
|| processingTestUpstream
}
>
<Trans>apply_btn</Trans>
</button>
</div>
</div>
</form>
);
};
Form.propTypes = {
handleSubmit: PropTypes.func,
testUpstream: PropTypes.func,
submitting: PropTypes.bool,
invalid: PropTypes.bool,
initialValues: PropTypes.object,
upstreamDns: PropTypes.string,
bootstrapDns: PropTypes.string,
allServers: PropTypes.bool,
processingTestUpstream: PropTypes.bool,
processingSetUpstream: PropTypes.bool,
t: PropTypes.func,
};
const selector = formValueSelector('upstreamForm');
Form = connect((state) => {
const upstreamDns = selector(state, 'upstream_dns');
const bootstrapDns = selector(state, 'bootstrap_dns');
const allServers = selector(state, 'all_servers');
return {
upstreamDns,
bootstrapDns,
allServers,
};
})(Form);
export default flow([
withNamespaces(),
reduxForm({ form: 'upstreamForm' }),
])(Form);

View File

@ -0,0 +1,67 @@
import React, { Component } from 'react';
import PropTypes from 'prop-types';
import { withNamespaces } from 'react-i18next';
import Form from './Form';
import Examples from './Examples';
import Card from '../../ui/Card';
class Upstream extends Component {
handleSubmit = (values) => {
this.props.setUpstream(values);
};
handleTest = (values) => {
this.props.testUpstream(values);
}
render() {
const {
t,
upstreamDns: upstream_dns,
bootstrapDns: bootstrap_dns,
allServers: all_servers,
processingSetUpstream,
processingTestUpstream,
} = this.props;
return (
<Card
title={ t('upstream_dns') }
subtitle={ t('upstream_dns_hint') }
bodyType="card-body box-body--settings"
>
<div className="row">
<div className="col">
<Form
initialValues={{
upstream_dns,
bootstrap_dns,
all_servers,
}}
testUpstream={this.handleTest}
onSubmit={this.handleSubmit}
processingTestUpstream={processingTestUpstream}
processingSetUpstream={processingSetUpstream}
/>
<hr/>
<Examples />
</div>
</div>
</Card>
);
}
}
Upstream.propTypes = {
upstreamDns: PropTypes.string,
bootstrapDns: PropTypes.string,
allServers: PropTypes.bool,
setUpstream: PropTypes.func.isRequired,
testUpstream: PropTypes.func.isRequired,
processingSetUpstream: PropTypes.bool.isRequired,
processingTestUpstream: PropTypes.bool.isRequired,
t: PropTypes.func.isRequired,
};
export default withNamespaces()(Upstream);

View File

@ -41,22 +41,6 @@ class Settings extends Component {
this.props.getTlsStatus();
}
handleUpstreamChange = (value) => {
this.props.handleUpstreamChange({ upstreamDns: value });
};
handleUpstreamSubmit = () => {
this.props.setUpstream(this.props.dashboard.upstreamDns);
};
handleUpstreamTest = () => {
if (this.props.dashboard.upstreamDns.length > 0) {
this.props.testUpstream(this.props.dashboard.upstreamDns);
} else {
this.props.addErrorToast({ error: this.props.t('no_servers_specified') });
}
};
renderSettings = (settings) => {
if (Object.keys(settings).length > 0) {
return Object.keys(settings).map((key) => {
@ -75,8 +59,7 @@ class Settings extends Component {
}
render() {
const { settings, t } = this.props;
const { upstreamDns } = this.props.dashboard;
const { settings, dashboard, t } = this.props;
return (
<Fragment>
<PageTitle title={ t('settings') } />
@ -91,11 +74,13 @@ class Settings extends Component {
</div>
</Card>
<Upstream
upstreamDns={upstreamDns}
upstreamDns={dashboard.upstreamDns}
bootstrapDns={dashboard.bootstrapDns}
allServers={dashboard.allServers}
setUpstream={this.props.setUpstream}
testUpstream={this.props.testUpstream}
processingTestUpstream={settings.processingTestUpstream}
handleUpstreamChange={this.handleUpstreamChange}
handleUpstreamSubmit={this.handleUpstreamSubmit}
handleUpstreamTest={this.handleUpstreamTest}
processingSetUpstream={settings.processingSetUpstream}
/>
<Encryption
encryption={this.props.encryption}
@ -125,7 +110,6 @@ Settings.propTypes = {
toggleSetting: PropTypes.func,
handleUpstreamChange: PropTypes.func,
setUpstream: PropTypes.func,
upstream: PropTypes.string,
t: PropTypes.func,
};

View File

@ -91,6 +91,10 @@
line-height: 1.5;
}
.checkbox__label-text--long {
max-width: initial;
}
.checkbox__label-title {
display: block;
line-height: 1.5;

View File

@ -32,7 +32,7 @@ export const renderSelectField = ({
disabled={disabled}
/>
<span className="checkbox__label">
<span className="checkbox__label-text">
<span className="checkbox__label-text checkbox__label-text--long">
<span className="checkbox__label-title">{placeholder}</span>
</span>
</span>

View File

@ -201,3 +201,5 @@ export const redirectToCurrentProtocol = (values, httpPort = 80) => {
window.location.replace(`http://${hostname}:${httpPort}/${hash}`);
}
};
export const normalizeTextarea = text => text && text.replace(/[;, ]/g, '\n').split('\n').filter(n => n);

View File

@ -51,6 +51,8 @@ const dashboard = handleActions({
dns_address: dnsAddress,
querylog_enabled: queryLogEnabled,
upstream_dns: upstreamDns,
bootstrap_dns: bootstrapDns,
all_servers: allServers,
protection_enabled: protectionEnabled,
language,
http_port: httpPort,
@ -64,6 +66,8 @@ const dashboard = handleActions({
dnsAddress,
queryLogEnabled,
upstreamDns: upstreamDns.join('\n'),
bootstrapDns: bootstrapDns.join('\n'),
allServers,
protectionEnabled,
language,
httpPort,
@ -171,7 +175,9 @@ const dashboard = handleActions({
logStatusProcessing: false,
processingVersion: true,
processingFiltering: true,
upstreamDns: [],
upstreamDns: '',
bootstrapDns: '',
allServers: false,
protectionEnabled: false,
processingProtection: false,
httpPort: 80,

View File

@ -60,7 +60,8 @@ type dnsConfig struct {
UpstreamDNS []string `yaml:"upstream_dns"`
}
var defaultDNS = []string{"tls://1.1.1.1", "tls://1.0.0.1"}
var defaultDNS = []string{"https://dns.cloudflare.com/dns-query"}
var defaultBootstrap = []string{"1.1.1.1"}
type tlsConfigSettings struct {
Enabled bool `yaml:"enabled" json:"enabled"` // Enabled is the encryption (DOT/DOH/HTTPS) status
@ -114,7 +115,8 @@ var config = configuration{
QueryLogEnabled: true,
Ratelimit: 20,
RefuseAny: true,
BootstrapDNS: "8.8.8.8:53",
BootstrapDNS: defaultBootstrap,
AllServers: false,
},
UpstreamDNS: defaultDNS,
},

View File

@ -27,6 +27,8 @@ const updatePeriod = time.Minute * 30
var versionCheckJSON []byte
var versionCheckLastTime time.Time
var protocols = []string{"tls://", "https://", "tcp://", "sdns://"}
const versionCheckURL = "https://adguardteam.github.io/AdGuardHome/version.json"
const versionCheckPeriod = time.Hour * 8
@ -41,9 +43,7 @@ var client = &http.Client{
func returnOK(w http.ResponseWriter) {
_, err := fmt.Fprintf(w, "OK\n")
if err != nil {
errorText := fmt.Sprintf("Couldn't write body: %s", err)
log.Error(errorText)
http.Error(w, errorText, http.StatusInternalServerError)
httpError(w, http.StatusInternalServerError, "Couldn't write body: %s", err)
}
}
@ -85,23 +85,20 @@ func handleStatus(w http.ResponseWriter, r *http.Request) {
"running": isRunning(),
"bootstrap_dns": config.DNS.BootstrapDNS,
"upstream_dns": config.DNS.UpstreamDNS,
"all_servers": config.DNS.AllServers,
"version": VersionString,
"language": config.Language,
}
jsonVal, err := json.Marshal(data)
if err != nil {
errorText := fmt.Sprintf("Unable to marshal status json: %s", err)
log.Error(errorText)
http.Error(w, errorText, 500)
httpError(w, http.StatusInternalServerError, "Unable to marshal status json: %s", err)
return
}
w.Header().Set("Content-Type", "application/json")
_, err = w.Write(jsonVal)
if err != nil {
errorText := fmt.Sprintf("Unable to write response json: %s", err)
log.Error(errorText)
http.Error(w, errorText, 500)
httpError(w, http.StatusInternalServerError, "Unable to write response json: %s", err)
return
}
}
@ -139,18 +136,14 @@ func handleQueryLog(w http.ResponseWriter, r *http.Request) {
jsonVal, err := json.Marshal(data)
if err != nil {
errorText := fmt.Sprintf("Couldn't marshal data into json: %s", err)
log.Error(errorText)
http.Error(w, errorText, http.StatusInternalServerError)
httpError(w, http.StatusInternalServerError, "Couldn't marshal data into json: %s", err)
return
}
w.Header().Set("Content-Type", "application/json")
_, err = w.Write(jsonVal)
if err != nil {
errorText := fmt.Sprintf("Unable to write response json: %s", err)
log.Error(errorText)
http.Error(w, errorText, http.StatusInternalServerError)
httpError(w, http.StatusInternalServerError, "Unable to write response json: %s", err)
}
}
@ -196,9 +189,7 @@ func handleStatsTop(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, err := w.Write(statsJSON.Bytes())
if err != nil {
errorText := fmt.Sprintf("Couldn't write body: %s", err)
log.Error(errorText)
http.Error(w, errorText, http.StatusInternalServerError)
httpError(w, http.StatusInternalServerError, "Couldn't write body: %s", err)
}
}
@ -208,9 +199,7 @@ func handleStatsReset(w http.ResponseWriter, r *http.Request) {
dnsServer.PurgeStats()
_, err := fmt.Fprintf(w, "OK\n")
if err != nil {
errorText := fmt.Sprintf("Couldn't write body: %s", err)
log.Error(errorText)
http.Error(w, errorText, http.StatusInternalServerError)
httpError(w, http.StatusInternalServerError, "Couldn't write body: %s", err)
}
}
@ -221,17 +210,13 @@ func handleStats(w http.ResponseWriter, r *http.Request) {
statsJSON, err := json.Marshal(summed)
if err != nil {
errorText := fmt.Sprintf("Unable to marshal status json: %s", err)
log.Error(errorText)
http.Error(w, errorText, 500)
httpError(w, http.StatusInternalServerError, "Unable to marshal status json: %s", err)
return
}
w.Header().Set("Content-Type", "application/json")
_, err = w.Write(statsJSON)
if err != nil {
errorText := fmt.Sprintf("Unable to write response json: %s", err)
log.Error(errorText)
http.Error(w, errorText, 500)
httpError(w, http.StatusInternalServerError, "Unable to write response json: %s", err)
return
}
}
@ -259,40 +244,31 @@ func handleStatsHistory(w http.ResponseWriter, r *http.Request) {
// parse start and end time
startTime, err := time.Parse(time.RFC3339, r.URL.Query().Get("start_time"))
if err != nil {
errorText := fmt.Sprintf("Must specify valid start_time parameter: %s", err)
log.Error(errorText)
http.Error(w, errorText, http.StatusBadRequest)
httpError(w, http.StatusBadRequest, "Must specify valid start_time parameter: %s", err)
return
}
endTime, err := time.Parse(time.RFC3339, r.URL.Query().Get("end_time"))
if err != nil {
errorText := fmt.Sprintf("Must specify valid end_time parameter: %s", err)
log.Error(errorText)
http.Error(w, errorText, http.StatusBadRequest)
httpError(w, http.StatusBadRequest, "Must specify valid end_time parameter: %s", err)
return
}
data, err := dnsServer.GetStatsHistory(timeUnit, startTime, endTime)
if err != nil {
errorText := fmt.Sprintf("Cannot get stats history: %s", err)
http.Error(w, errorText, http.StatusBadRequest)
httpError(w, http.StatusBadRequest, "Cannot get stats history: %s", err)
return
}
statsJSON, err := json.Marshal(data)
if err != nil {
errorText := fmt.Sprintf("Unable to marshal status json: %s", err)
log.Error(errorText)
http.Error(w, errorText, http.StatusInternalServerError)
httpError(w, http.StatusInternalServerError, "Unable to marshal status json: %s", err)
return
}
w.Header().Set("Content-Type", "application/json")
_, err = w.Write(statsJSON)
if err != nil {
errorText := fmt.Sprintf("Unable to write response json: %s", err)
log.Error(errorText)
http.Error(w, errorText, http.StatusInternalServerError)
httpError(w, http.StatusInternalServerError, "Unable to write response json: %s", err)
return
}
}
@ -322,68 +298,108 @@ func sortByValue(m map[string]int) []string {
// upstreams configuration
// -----------------------
func handleSetUpstreamDNS(w http.ResponseWriter, r *http.Request) {
// TODO this struct will become unnecessary after config file rework
type upstreamConfig struct {
Upstreams []string `json:"upstream_dns"` // Upstreams
BootstrapDNS []string `json:"bootstrap_dns"` // Bootstrap DNS
AllServers bool `json:"all_servers"` // --all-servers param for dnsproxy
}
func handleSetUpstreamConfig(w http.ResponseWriter, r *http.Request) {
log.Tracef("%s %v", r.Method, r.URL)
body, err := ioutil.ReadAll(r.Body)
newconfig := upstreamConfig{}
err := json.NewDecoder(r.Body).Decode(&newconfig)
if err != nil {
errorText := fmt.Sprintf("Failed to read request body: %s", err)
log.Error(errorText)
http.Error(w, errorText, http.StatusBadRequest)
httpError(w, http.StatusBadRequest, "Failed to parse new upstreams config json: %s", err)
return
}
// if empty body -- user is asking for default servers
hosts := strings.Fields(string(body))
if len(hosts) == 0 {
for _, u := range newconfig.Upstreams {
if err = validateUpstream(u); err != nil {
httpError(w, http.StatusBadRequest, "%s can not be used as upstream cause: %s", u, err)
return
}
}
config.DNS.UpstreamDNS = defaultDNS
} else {
config.DNS.UpstreamDNS = hosts
if len(newconfig.Upstreams) > 0 {
config.DNS.UpstreamDNS = newconfig.Upstreams
}
err = writeAllConfigs()
if err != nil {
errorText := fmt.Sprintf("Couldn't write config file: %s", err)
log.Error(errorText)
http.Error(w, errorText, http.StatusInternalServerError)
// bootstrap servers are plain DNS only.
for _, host := range newconfig.BootstrapDNS {
if err := checkPlainDNS(host); err != nil {
httpError(w, http.StatusBadRequest, "%s can not be used as bootstrap dns cause: %s", host, err)
return
}
err = reconfigureDNSServer()
if err != nil {
errorText := fmt.Sprintf("Couldn't reconfigure the DNS server: %s", err)
log.Error(errorText)
http.Error(w, errorText, http.StatusInternalServerError)
return
}
_, err = fmt.Fprintf(w, "OK %d servers\n", len(hosts))
if err != nil {
errorText := fmt.Sprintf("Couldn't write body: %s", err)
log.Error(errorText)
http.Error(w, errorText, http.StatusInternalServerError)
config.DNS.BootstrapDNS = defaultBootstrap
if len(newconfig.BootstrapDNS) > 0 {
config.DNS.BootstrapDNS = newconfig.BootstrapDNS
}
config.DNS.AllServers = newconfig.AllServers
httpUpdateConfigReloadDNSReturnOK(w, r)
}
func validateUpstream(upstream string) error {
for _, proto := range protocols {
if strings.HasPrefix(upstream, proto) {
return nil
}
}
if strings.Contains(upstream, "://") {
return fmt.Errorf("wrong protocol")
}
return checkPlainDNS(upstream)
}
// checkPlainDNS checks if host is plain DNS
func checkPlainDNS(upstream string) error {
// Check if host is ip without port
if net.ParseIP(upstream) != nil {
return nil
}
// Check if host is ip with port
ip, port, err := net.SplitHostPort(upstream)
if err != nil {
return err
}
if net.ParseIP(ip) == nil {
return fmt.Errorf("%s is not a valid IP", ip)
}
_, err = strconv.ParseInt(port, 0, 64)
if err != nil {
return fmt.Errorf("%s is not a valid port: %s", port, err)
}
return nil
}
func handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
log.Tracef("%s %v", r.Method, r.URL)
body, err := ioutil.ReadAll(r.Body)
upstreamConfig := upstreamConfig{}
err := json.NewDecoder(r.Body).Decode(&upstreamConfig)
if err != nil {
errorText := fmt.Sprintf("Failed to read request body: %s", err)
log.Error(errorText)
http.Error(w, errorText, 400)
httpError(w, http.StatusBadRequest, "Failed to read request body: %s", err)
return
}
hosts := strings.Fields(string(body))
if len(hosts) == 0 {
errorText := fmt.Sprintf("No servers specified")
log.Error(errorText)
http.Error(w, errorText, http.StatusBadRequest)
if len(upstreamConfig.Upstreams) == 0 {
httpError(w, http.StatusBadRequest, "No servers specified")
return
}
result := map[string]string{}
for _, host := range hosts {
err = checkDNS(host)
for _, host := range upstreamConfig.Upstreams {
err = checkDNS(host, upstreamConfig.BootstrapDNS)
if err != nil {
log.Info("%v", err)
result[host] = err.Error()
@ -394,24 +410,28 @@ func handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
jsonVal, err := json.Marshal(result)
if err != nil {
errorText := fmt.Sprintf("Unable to marshal status json: %s", err)
log.Error(errorText)
http.Error(w, errorText, http.StatusInternalServerError)
httpError(w, http.StatusInternalServerError, "Unable to marshal status json: %s", err)
return
}
w.Header().Set("Content-Type", "application/json")
_, err = w.Write(jsonVal)
if err != nil {
errorText := fmt.Sprintf("Couldn't write body: %s", err)
log.Error(errorText)
http.Error(w, errorText, http.StatusInternalServerError)
httpError(w, http.StatusInternalServerError, "Couldn't write body: %s", err)
}
}
func checkDNS(input string) error {
func checkDNS(input string, bootstrap []string) error {
if err := validateUpstream(input); err != nil {
return fmt.Errorf("wrong upstream format: %s", err)
}
if len(bootstrap) == 0 {
bootstrap = defaultBootstrap
}
log.Debug("Checking if DNS %s works...", input)
u, err := upstream.AddressToUpstream(input, upstream.Options{Timeout: dnsforward.DefaultTimeout})
u, err := upstream.AddressToUpstream(input, upstream.Options{Bootstrap: bootstrap, Timeout: dnsforward.DefaultTimeout})
if err != nil {
return fmt.Errorf("failed to choose upstream for %s: %s", input, err)
}
@ -451,9 +471,7 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
resp, err := client.Get(versionCheckURL)
if err != nil {
errorText := fmt.Sprintf("Couldn't get version check json from %s: %T %s\n", versionCheckURL, err, err)
log.Error(errorText)
http.Error(w, errorText, http.StatusBadGateway)
httpError(w, http.StatusBadGateway, "Couldn't get version check json from %s: %T %s\n", versionCheckURL, err, err)
return
}
if resp != nil && resp.Body != nil {
@ -463,18 +481,14 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
// read the body entirely
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
errorText := fmt.Sprintf("Couldn't read response body from %s: %s", versionCheckURL, err)
log.Error(errorText)
http.Error(w, errorText, http.StatusBadGateway)
httpError(w, http.StatusBadGateway, "Couldn't read response body from %s: %s", versionCheckURL, err)
return
}
w.Header().Set("Content-Type", "application/json")
_, err = w.Write(body)
if err != nil {
errorText := fmt.Sprintf("Couldn't write body: %s", err)
log.Error(errorText)
http.Error(w, errorText, http.StatusInternalServerError)
httpError(w, http.StatusInternalServerError, "Couldn't write body: %s", err)
}
versionCheckLastTime = now
@ -510,18 +524,14 @@ func handleFilteringStatus(w http.ResponseWriter, r *http.Request) {
config.RUnlock()
if err != nil {
errorText := fmt.Sprintf("Unable to marshal status json: %s", err)
log.Error(errorText)
http.Error(w, errorText, 500)
httpError(w, http.StatusInternalServerError, "Unable to marshal status json: %s", err)
return
}
w.Header().Set("Content-Type", "application/json")
_, err = w.Write(jsonVal)
if err != nil {
errorText := fmt.Sprintf("Unable to write response json: %s", err)
log.Error(errorText)
http.Error(w, errorText, 500)
httpError(w, http.StatusInternalServerError, "Unable to write response json: %s", err)
return
}
}
@ -536,21 +546,19 @@ func handleFilteringAddURL(w http.ResponseWriter, r *http.Request) {
}
if len(f.URL) == 0 {
http.Error(w, "URL parameter was not specified", 400)
http.Error(w, "URL parameter was not specified", http.StatusBadRequest)
return
}
if valid := govalidator.IsRequestURL(f.URL); !valid {
http.Error(w, "URL parameter is not valid request URL", 400)
http.Error(w, "URL parameter is not valid request URL", http.StatusBadRequest)
return
}
// Check for duplicates
for i := range config.Filters {
if config.Filters[i].URL == f.URL {
errorText := fmt.Sprintf("Filter URL already added -- %s", f.URL)
log.Error(errorText)
http.Error(w, errorText, http.StatusBadRequest)
httpError(w, http.StatusBadRequest, "Filter URL already added -- %s", f.URL)
return
}
}
@ -562,30 +570,22 @@ func handleFilteringAddURL(w http.ResponseWriter, r *http.Request) {
// Download the filter contents
ok, err := f.update(true)
if err != nil {
errorText := fmt.Sprintf("Couldn't fetch filter from url %s: %s", f.URL, err)
log.Error(errorText)
http.Error(w, errorText, http.StatusBadRequest)
httpError(w, http.StatusBadRequest, "Couldn't fetch filter from url %s: %s", f.URL, err)
return
}
if f.RulesCount == 0 {
errorText := fmt.Sprintf("Filter at the url %s has no rules (maybe it points to blank page?)", f.URL)
log.Error(errorText)
http.Error(w, errorText, http.StatusBadRequest)
httpError(w, http.StatusBadRequest, "Filter at the url %s has no rules (maybe it points to blank page?)", f.URL)
return
}
if !ok {
errorText := fmt.Sprintf("Filter at the url %s is invalid (maybe it points to blank page?)", f.URL)
log.Error(errorText)
http.Error(w, errorText, http.StatusBadRequest)
httpError(w, http.StatusBadRequest, "Filter at the url %s is invalid (maybe it points to blank page?)", f.URL)
return
}
// Save the filter contents
err = f.save()
if err != nil {
errorText := fmt.Sprintf("Failed to save filter %d due to %s", f.ID, err)
log.Error(errorText)
http.Error(w, errorText, http.StatusBadRequest)
httpError(w, http.StatusBadRequest, "Failed to save filter %d due to %s", f.ID, err)
return
}
@ -594,24 +594,18 @@ func handleFilteringAddURL(w http.ResponseWriter, r *http.Request) {
config.Filters = append(config.Filters, f)
err = writeAllConfigs()
if err != nil {
errorText := fmt.Sprintf("Couldn't write config file: %s", err)
log.Error(errorText)
http.Error(w, errorText, http.StatusInternalServerError)
httpError(w, http.StatusInternalServerError, "Couldn't write config file: %s", err)
return
}
err = reconfigureDNSServer()
if err != nil {
errorText := fmt.Sprintf("Couldn't reconfigure the DNS server: %s", err)
log.Error(errorText)
http.Error(w, errorText, http.StatusInternalServerError)
httpError(w, http.StatusInternalServerError, "Couldn't reconfigure the DNS server: %s", err)
}
_, err = fmt.Fprintf(w, "OK %d rules\n", f.RulesCount)
if err != nil {
errorText := fmt.Sprintf("Couldn't write body: %s", err)
log.Error(errorText)
http.Error(w, errorText, http.StatusInternalServerError)
httpError(w, http.StatusInternalServerError, "Couldn't write body: %s", err)
}
}
@ -619,20 +613,18 @@ func handleFilteringRemoveURL(w http.ResponseWriter, r *http.Request) {
log.Tracef("%s %v", r.Method, r.URL)
parameters, err := parseParametersFromBody(r.Body)
if err != nil {
errorText := fmt.Sprintf("failed to parse parameters from body: %s", err)
log.Error(errorText)
http.Error(w, errorText, 400)
httpError(w, http.StatusBadRequest, "failed to parse parameters from body: %s", err)
return
}
url, ok := parameters["url"]
if !ok {
http.Error(w, "URL parameter was not specified", 400)
http.Error(w, "URL parameter was not specified", http.StatusBadRequest)
return
}
if valid := govalidator.IsRequestURL(url); !valid {
http.Error(w, "URL parameter is not valid request URL", 400)
http.Error(w, "URL parameter is not valid request URL", http.StatusBadRequest)
return
}
@ -645,8 +637,7 @@ func handleFilteringRemoveURL(w http.ResponseWriter, r *http.Request) {
// Remove the filter file
err := os.Remove(filter.Path())
if err != nil && !os.IsNotExist(err) {
errorText := fmt.Sprintf("Couldn't remove the filter file: %s", err)
http.Error(w, errorText, http.StatusInternalServerError)
httpError(w, http.StatusInternalServerError, "Couldn't remove the filter file: %s", err)
return
}
}
@ -660,15 +651,13 @@ func handleFilteringEnableURL(w http.ResponseWriter, r *http.Request) {
log.Tracef("%s %v", r.Method, r.URL)
parameters, err := parseParametersFromBody(r.Body)
if err != nil {
errorText := fmt.Sprintf("failed to parse parameters from body: %s", err)
log.Error(errorText)
http.Error(w, errorText, 400)
httpError(w, http.StatusBadRequest, "failed to parse parameters from body: %s", err)
return
}
url, ok := parameters["url"]
if !ok {
http.Error(w, "URL parameter was not specified", 400)
http.Error(w, "URL parameter was not specified", http.StatusBadRequest)
return
}
@ -700,15 +689,13 @@ func handleFilteringDisableURL(w http.ResponseWriter, r *http.Request) {
log.Tracef("%s %v", r.Method, r.URL)
parameters, err := parseParametersFromBody(r.Body)
if err != nil {
errorText := fmt.Sprintf("failed to parse parameters from body: %s", err)
log.Error(errorText)
http.Error(w, errorText, 400)
httpError(w, http.StatusBadRequest, "failed to parse parameters from body: %s", err)
return
}
url, ok := parameters["url"]
if !ok {
http.Error(w, "URL parameter was not specified", 400)
http.Error(w, "URL parameter was not specified", http.StatusBadRequest)
return
}
@ -738,9 +725,7 @@ func handleFilteringSetRules(w http.ResponseWriter, r *http.Request) {
log.Tracef("%s %v", r.Method, r.URL)
body, err := ioutil.ReadAll(r.Body)
if err != nil {
errorText := fmt.Sprintf("Failed to read request body: %s", err)
log.Error(errorText)
http.Error(w, errorText, 400)
httpError(w, http.StatusBadRequest, "Failed to read request body: %s", err)
return
}
@ -778,17 +763,13 @@ func handleSafeBrowsingStatus(w http.ResponseWriter, r *http.Request) {
}
jsonVal, err := json.Marshal(data)
if err != nil {
errorText := fmt.Sprintf("Unable to marshal status json: %s", err)
log.Error(errorText)
http.Error(w, errorText, 500)
httpError(w, http.StatusInternalServerError, "Unable to marshal status json: %s", err)
}
w.Header().Set("Content-Type", "application/json")
_, err = w.Write(jsonVal)
if err != nil {
errorText := fmt.Sprintf("Unable to write response json: %s", err)
log.Error(errorText)
http.Error(w, errorText, 500)
httpError(w, http.StatusInternalServerError, "Unable to write response json: %s", err)
return
}
}
@ -800,9 +781,7 @@ func handleParentalEnable(w http.ResponseWriter, r *http.Request) {
log.Tracef("%s %v", r.Method, r.URL)
parameters, err := parseParametersFromBody(r.Body)
if err != nil {
errorText := fmt.Sprintf("failed to parse parameters from body: %s", err)
log.Error(errorText)
http.Error(w, errorText, 400)
httpError(w, http.StatusBadRequest, "failed to parse parameters from body: %s", err)
return
}
@ -859,18 +838,14 @@ func handleParentalStatus(w http.ResponseWriter, r *http.Request) {
}
jsonVal, err := json.Marshal(data)
if err != nil {
errorText := fmt.Sprintf("Unable to marshal status json: %s", err)
log.Error(errorText)
http.Error(w, errorText, 500)
httpError(w, http.StatusInternalServerError, "Unable to marshal status json: %s", err)
return
}
w.Header().Set("Content-Type", "application/json")
_, err = w.Write(jsonVal)
if err != nil {
errorText := fmt.Sprintf("Unable to write response json: %s", err)
log.Error(errorText)
http.Error(w, errorText, 500)
httpError(w, http.StatusInternalServerError, "Unable to write response json: %s", err)
return
}
}
@ -898,18 +873,14 @@ func handleSafeSearchStatus(w http.ResponseWriter, r *http.Request) {
}
jsonVal, err := json.Marshal(data)
if err != nil {
errorText := fmt.Sprintf("Unable to marshal status json: %s", err)
log.Error(errorText)
http.Error(w, errorText, 500)
httpError(w, http.StatusInternalServerError, "Unable to marshal status json: %s", err)
return
}
w.Header().Set("Content-Type", "application/json")
_, err = w.Write(jsonVal)
if err != nil {
errorText := fmt.Sprintf("Unable to write response json: %s", err)
log.Error(errorText)
http.Error(w, errorText, 500)
httpError(w, http.StatusInternalServerError, "Unable to write response json: %s", err)
return
}
}
@ -1053,7 +1024,7 @@ func registerControlHandlers() {
http.HandleFunc("/control/querylog", postInstall(optionalAuth(ensureGET(handleQueryLog))))
http.HandleFunc("/control/querylog_enable", postInstall(optionalAuth(ensurePOST(handleQueryLogEnable))))
http.HandleFunc("/control/querylog_disable", postInstall(optionalAuth(ensurePOST(handleQueryLogDisable))))
http.HandleFunc("/control/set_upstream_dns", postInstall(optionalAuth(ensurePOST(handleSetUpstreamDNS))))
http.HandleFunc("/control/set_upstreams_config", postInstall(optionalAuth(ensurePOST(handleSetUpstreamConfig))))
http.HandleFunc("/control/test_upstream_dns", postInstall(optionalAuth(ensurePOST(handleTestUpstreamDNS))))
http.HandleFunc("/control/i18n/change_language", postInstall(optionalAuth(ensurePOST(handleI18nChangeLanguage))))
http.HandleFunc("/control/i18n/current_language", postInstall(optionalAuth(ensureGET(handleI18nCurrentLanguage))))

3
dns.go
View File

@ -61,7 +61,7 @@ func generateServerConfig() dnsforward.ServerConfig {
for _, u := range config.DNS.UpstreamDNS {
opts := upstream.Options{
Timeout: dnsforward.DefaultTimeout,
Bootstrap: []string{config.DNS.BootstrapDNS},
Bootstrap: config.DNS.BootstrapDNS,
}
dnsUpstream, err := upstream.AddressToUpstream(u, opts)
if err != nil {
@ -71,6 +71,7 @@ func generateServerConfig() dnsforward.ServerConfig {
}
newconfig.Upstreams = append(newconfig.Upstreams, dnsUpstream)
}
newconfig.AllServers = config.DNS.AllServers
return newconfig
}

View File

@ -16,8 +16,8 @@ import (
"sync/atomic"
"time"
"github.com/bluele/gcache"
"github.com/AdguardTeam/golibs/log"
"github.com/bluele/gcache"
"golang.org/x/net/publicsuffix"
)

View File

@ -62,11 +62,12 @@ type FilteringConfig struct {
ProtectionEnabled bool `yaml:"protection_enabled"` // whether or not use any of dnsfilter features
FilteringEnabled bool `yaml:"filtering_enabled"` // whether or not use filter lists
BlockedResponseTTL uint32 `yaml:"blocked_response_ttl"` // if 0, then default is used (3600)
QueryLogEnabled bool `yaml:"querylog_enabled"`
Ratelimit int `yaml:"ratelimit"`
RatelimitWhitelist []string `yaml:"ratelimit_whitelist"`
RefuseAny bool `yaml:"refuse_any"`
BootstrapDNS string `yaml:"bootstrap_dns"`
QueryLogEnabled bool `yaml:"querylog_enabled"` // if true, query log is enabled
Ratelimit int `yaml:"ratelimit"` // max number of requests per second from a given IP (0 to disable)
RatelimitWhitelist []string `yaml:"ratelimit_whitelist"` // a list of whitelisted client IP addresses
RefuseAny bool `yaml:"refuse_any"` // if true, refuse ANY requests
BootstrapDNS []string `yaml:"bootstrap_dns"` // a list of bootstrap DNS for DoH and DoT (plain DNS only)
AllServers bool `yaml:"all_servers"` // if true, parallel queries to all configured upstream servers are enabled
dnsfilter.Config `yaml:",inline"`
}
@ -163,6 +164,7 @@ func (s *Server) startInternal(config *ServerConfig) error {
CacheEnabled: true,
Upstreams: s.Upstreams,
Handler: s.handleDNSRequest,
AllServers: s.AllServers,
}
if s.TLSListenAddr != nil && s.CertificateChain != "" && s.PrivateKey != "" {

View File

@ -87,26 +87,21 @@ paths:
200:
description: OK
/set_upstream_dns:
/set_upstreams_config:
post:
tags:
- global
operationId: setUpstreamDNS
summary: 'Set upstream DNS for coredns, empty value will reset it to default values'
operationId: setUpstreamsConfig
summary: "Updates the current upstreams configuration"
consumes:
- text/plain
- application/json
parameters:
- in: body
name: upstream
description: 'Upstream servers, separated by newline or space, port is optional after colon'
- in: "body"
name: "body"
description: "Upstreams configuration JSON"
required: true
schema:
# TODO: use JSON
type: string
example: |
1.1.1.1
1.0.0.1
8.8.8.8 8.8.4.4
192.168.1.104:53535
$ref: "#/definitions/UpstreamsConfig"
responses:
200:
description: OK
@ -116,21 +111,15 @@ paths:
tags:
- global
operationId: testUpstreamDNS
summary: 'Test upstream DNS'
summary: "Test upstream configuration"
consumes:
- text/plain
- application/json
parameters:
- in: body
name: upstream
description: 'Upstream servers, separated by newline or space, port is optional after colon'
- in: "body"
name: "body"
description: "Upstream configuration to be tested"
schema:
# TODO: use JSON
type: string
example: |
1.1.1.1
1.0.0.1
8.8.8.8 8.8.4.4
192.168.1.104:53535
$ref: "#/definitions/UpstreamsConfig"
responses:
200:
description: 'Status of testing each requested server, with "OK" meaning that server works, any other text means an error.'
@ -798,6 +787,33 @@ definitions:
language:
type: "string"
example: "en"
UpstreamsConfig:
type: "object"
description: "Upstreams configuration"
required:
- "bootstrap_dns"
- "upstream_dns"
- "all_servers"
properties:
bootstrap_dns:
type: "array"
description: 'Bootstrap servers, port is optional after colon. Empty value will reset it to default values'
items:
type: "string"
example:
- "8.8.8.8:53"
- "1.1.1.1:53"
upstream_dns:
type: "array"
description: 'Upstream servers, port is optional after colon. Empty value will reset it to default values'
items:
type: "string"
example:
- "tls://1.1.1.1"
- "tls://1.0.0.1"
all_servers:
type: "boolean"
description: "If true, parallel queries to all configured upstream servers are enabled"
Filter:
type: "object"
description: "Filter subscription info"

View File

@ -10,7 +10,7 @@ import (
yaml "gopkg.in/yaml.v2"
)
const currentSchemaVersion = 2 // used for upgrading from old configs to new config
const currentSchemaVersion = 3 // used for upgrading from old configs to new config
// Performs necessary upgrade operations if needed
func upgradeConfig() error {
@ -59,12 +59,17 @@ func upgradeConfig() error {
func upgradeConfigSchema(oldVersion int, diskConfig *map[string]interface{}) error {
switch oldVersion {
case 0:
err := upgradeSchema0to2(diskConfig)
err := upgradeSchema0to3(diskConfig)
if err != nil {
return err
}
case 1:
err := upgradeSchema1to2(diskConfig)
err := upgradeSchema1to3(diskConfig)
if err != nil {
return err
}
case 2:
err := upgradeSchema2to3(diskConfig)
if err != nil {
return err
}
@ -135,12 +140,60 @@ func upgradeSchema1to2(diskConfig *map[string]interface{}) error {
return nil
}
// jump two schemas at once -- this time we just do it sequentially
func upgradeSchema0to2(diskConfig *map[string]interface{}) error {
// Third schema upgrade:
// Bootstrap DNS becomes an array
func upgradeSchema2to3(diskConfig *map[string]interface{}) error {
log.Printf("%s(): called", _Func())
// Let's read dns configuration from diskConfig
dnsConfig, ok := (*diskConfig)["dns"]
if !ok {
return fmt.Errorf("no DNS configuration in config file")
}
// Convert interface{} to map[string]interface{}
newDNSConfig := make(map[string]interface{})
switch v := dnsConfig.(type) {
case map[interface{}]interface{}:
for k, v := range v {
newDNSConfig[fmt.Sprint(k)] = v
}
default:
return fmt.Errorf("DNS configuration is not a map")
}
// Replace bootstrap_dns value filed with new array contains old bootstrap_dns inside
if bootstrapDNS, ok := (newDNSConfig)["bootstrap_dns"]; ok {
newBootstrapConfig := []string{fmt.Sprint(bootstrapDNS)}
(newDNSConfig)["bootstrap_dns"] = newBootstrapConfig
(*diskConfig)["dns"] = newDNSConfig
} else {
return fmt.Errorf("no bootstrap DNS in DNS config")
}
// Bump schema version
(*diskConfig)["schema_version"] = 3
return nil
}
// jump three schemas at once -- this time we just do it sequentially
func upgradeSchema0to3(diskConfig *map[string]interface{}) error {
err := upgradeSchema0to1(diskConfig)
if err != nil {
return err
}
return upgradeSchema1to2(diskConfig)
return upgradeSchema1to3(diskConfig)
}
// jump two schemas at once -- this time we just do it sequentially
func upgradeSchema1to3(diskConfig *map[string]interface{}) error {
err := upgradeSchema1to2(diskConfig)
if err != nil {
return err
}
return upgradeSchema2to3(diskConfig)
}

230
upgrade_test.go Normal file
View File

@ -0,0 +1,230 @@
package main
import (
"fmt"
"testing"
)
func TestUpgrade1to2(t *testing.T) {
// let's create test config for 1 schema version
diskConfig := createTestDiskConfig(1)
// update config
err := upgradeSchema1to2(&diskConfig)
if err != nil {
t.Fatalf("Can't upgrade schema version from 1 to 2")
}
// ensure that schema version was bumped
compareSchemaVersion(t, diskConfig["schema_version"], 2)
// old coredns entry should be removed
_, ok := diskConfig["coredns"]
if ok {
t.Fatalf("Core DNS config was not removed after upgrade schema version from 1 to 2")
}
// pull out new dns config
dnsMap, ok := diskConfig["dns"]
if !ok {
t.Fatalf("No DNS config after upgrade schema version from 1 to 2")
}
// cast dns configurations to maps and compare them
oldDNSConfig := castInterfaceToMap(t, createTestDNSConfig(1))
newDNSConfig := castInterfaceToMap(t, dnsMap)
compareConfigs(t, &oldDNSConfig, &newDNSConfig)
// exclude dns config and schema version from disk config comparison
oldExcludedEntries := []string{"coredns", "schema_version"}
newExcludedEntries := []string{"dns", "schema_version"}
oldDiskConfig := createTestDiskConfig(1)
compareConfigsWithoutEntries(t, &oldDiskConfig, &diskConfig, oldExcludedEntries, newExcludedEntries)
}
func TestUpgrade2to3(t *testing.T) {
// let's create test config
diskConfig := createTestDiskConfig(2)
// upgrade schema from 2 to 3
err := upgradeSchema2to3(&diskConfig)
if err != nil {
t.Fatalf("Can't update schema version from 2 to 3: %s", err)
}
// check new schema version
compareSchemaVersion(t, diskConfig["schema_version"], 3)
// pull out new dns configuration
dnsMap, ok := diskConfig["dns"]
if !ok {
t.Fatalf("No dns config in new configuration")
}
// cast dns configuration to map
newDNSConfig := castInterfaceToMap(t, dnsMap)
// check if bootstrap DNS becomes an array
bootstrapDNS := newDNSConfig["bootstrap_dns"]
switch v := bootstrapDNS.(type) {
case []string:
if len(v) != 1 {
t.Fatalf("Wrong count of bootsrap DNS servers: %d", len(v))
}
if v[0] != "8.8.8.8:53" {
t.Fatalf("Bootsrap DNS server is not 8.8.8.8:53 : %s", v[0])
}
default:
t.Fatalf("Wrong type for bootsrap DNS: %T", v)
}
// exclude bootstrap DNS from DNS configs comparison
excludedEntries := []string{"bootstrap_dns"}
oldDNSConfig := castInterfaceToMap(t, createTestDNSConfig(2))
compareConfigsWithoutEntries(t, &oldDNSConfig, &newDNSConfig, excludedEntries, excludedEntries)
// excluded dns config and schema version from disk config comparison
excludedEntries = []string{"dns", "schema_version"}
oldDiskConfig := createTestDiskConfig(2)
compareConfigsWithoutEntries(t, &oldDiskConfig, &diskConfig, excludedEntries, excludedEntries)
}
func castInterfaceToMap(t *testing.T, oldConfig interface{}) (newConfig map[string]interface{}) {
newConfig = make(map[string]interface{})
switch v := oldConfig.(type) {
case map[interface{}]interface{}:
for key, value := range v {
newConfig[fmt.Sprint(key)] = value
}
case map[string]interface{}:
for key, value := range v {
newConfig[key] = value
}
default:
t.Fatalf("DNS configuration is not a map")
}
return
}
// compareConfigsWithoutEntry removes entries from configs and returns result of compareConfigs
func compareConfigsWithoutEntries(t *testing.T, oldConfig, newConfig *map[string]interface{}, oldKey, newKey []string) {
for _, k := range oldKey {
delete(*oldConfig, k)
}
for _, k := range newKey {
delete(*newConfig, k)
}
compareConfigs(t, oldConfig, newConfig)
}
// compares configs before and after schema upgrade
func compareConfigs(t *testing.T, oldConfig, newConfig *map[string]interface{}) {
if len(*oldConfig) != len(*newConfig) {
t.Fatalf("wrong config entries count! Before upgrade: %d; After upgrade: %d", len(*oldConfig), len(*oldConfig))
}
// Check old and new entries
for k, v := range *newConfig {
switch value := v.(type) {
case string:
if value != (*oldConfig)[k] {
t.Fatalf("wrong value for string %s. Before update: %s; After update: %s", k, (*oldConfig)[k], value)
}
case int:
if value != (*oldConfig)[k] {
t.Fatalf("wrong value for int %s. Before update: %d; After update: %d", k, (*oldConfig)[k], value)
}
case []string:
for i, line := range value {
if len((*oldConfig)[k].([]string)) != len(value) {
t.Fatalf("wrong array length for %s. Before update: %d; After update: %d", k, len((*oldConfig)[k].([]string)), len(value))
}
if (*oldConfig)[k].([]string)[i] != line {
t.Fatalf("wrong data for string array %s. Before update: %s; After update: %s", k, (*oldConfig)[k].([]string)[i], line)
}
}
case bool:
if v != (*oldConfig)[k].(bool) {
t.Fatalf("wrong boolean value for %s", k)
}
case []filter:
if len((*oldConfig)[k].([]filter)) != len(value) {
t.Fatalf("wrong filters count. Before update: %d; After update: %d", len((*oldConfig)[k].([]filter)), len(value))
}
for i, newFilter := range value {
oldFilter := (*oldConfig)[k].([]filter)[i]
if oldFilter.Enabled != newFilter.Enabled || oldFilter.Name != newFilter.Name || oldFilter.RulesCount != newFilter.RulesCount {
t.Fatalf("old filter %s not equals new filter %s", oldFilter.Name, newFilter.Name)
}
}
default:
t.Fatalf("uknown data type for %s: %T", k, value)
}
}
}
// compareSchemaVersion check if newSchemaVersion equals schemaVersion
func compareSchemaVersion(t *testing.T, newSchemaVersion interface{}, schemaVersion int) {
switch v := newSchemaVersion.(type) {
case int:
if v != schemaVersion {
t.Fatalf("Wrong schema version in new config file")
}
default:
t.Fatalf("Schema version is not an integer after update")
}
}
func createTestDiskConfig(schemaVersion int) (diskConfig map[string]interface{}) {
diskConfig = make(map[string]interface{})
diskConfig["language"] = "en"
diskConfig["filters"] = []filter{
{
URL: "https://filters.adtidy.org/android/filters/111_optimized.txt",
Name: "Latvian filter",
RulesCount: 100,
},
{
URL: "https://easylist.to/easylistgermany/easylistgermany.txt",
Name: "Germany filter",
RulesCount: 200,
},
}
diskConfig["user_rules"] = []string{}
diskConfig["schema_version"] = schemaVersion
diskConfig["bind_host"] = "0.0.0.0"
diskConfig["bind_port"] = 80
diskConfig["auth_name"] = "name"
diskConfig["auth_pass"] = "pass"
dnsConfig := createTestDNSConfig(schemaVersion)
if schemaVersion > 1 {
diskConfig["dns"] = dnsConfig
} else {
diskConfig["coredns"] = dnsConfig
}
return diskConfig
}
func createTestDNSConfig(schemaVersion int) map[interface{}]interface{} {
dnsConfig := make(map[interface{}]interface{})
dnsConfig["port"] = 53
dnsConfig["blocked_response_ttl"] = 10
dnsConfig["querylog_enabled"] = true
dnsConfig["ratelimit"] = 20
dnsConfig["bootstrap_dns"] = "8.8.8.8:53"
if schemaVersion > 2 {
dnsConfig["bootstrap_dns"] = []string{"8.8.8.8:53"}
}
dnsConfig["parental_sensitivity"] = 13
dnsConfig["ratelimit_whitelist"] = []string{}
dnsConfig["upstream_dns"] = []string{"tls://1.1.1.1", "tls://1.0.0.1", "8.8.8.8"}
dnsConfig["filtering_enabled"] = true
dnsConfig["refuse_any"] = true
dnsConfig["parental_enabled"] = true
dnsConfig["bind_host"] = "0.0.0.0"
dnsConfig["protection_enabled"] = true
dnsConfig["safesearch_enabled"] = true
dnsConfig["safebrowsing_enabled"] = true
return dnsConfig
}