diff --git a/AGHTechDoc.md b/AGHTechDoc.md index b91b8586..ed7f48ba 100644 --- a/AGHTechDoc.md +++ b/AGHTechDoc.md @@ -1846,7 +1846,7 @@ Response: } There are also deprecated properties `filter_id` and `rule` on the top level of -the response object. Their usaga should be replaced with +the response object. Their usage should be replaced with `rules[*].filter_list_id` and `rules[*].text` correspondingly. See the _OpenAPI_ documentation and the `./openapi/CHANGELOG.md` file. diff --git a/CHANGELOG.md b/CHANGELOG.md index e63c913c..46c421ba 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -56,6 +56,13 @@ and this project adheres to [#2391]: https://github.com/AdguardTeam/AdGuardHome/issues/2391 [#2394]: https://github.com/AdguardTeam/AdGuardHome/issues/2394 +### Deprecated + +- _Go_ 1.14 support. v0.106.0 will require at least _Go_ 1.15 to build. +- The `darwin/386` port. It will be removed in v0.106.0. +- The `"rule"` and `"filter_id"` fields in `GET /filtering/check_host` and + `GET /querylog` responses. They will be removed in v0.106.0 ([#2102]). + ### Fixed - Inability to set DNS cache TTL limits ([#2459]). diff --git a/HACKING.md b/HACKING.md index a16d6496..237e554f 100644 --- a/HACKING.md +++ b/HACKING.md @@ -70,6 +70,14 @@ The rules are mostly sorted in the alphabetical order. func TestType_Method_suffix(t *testing.T) { /* … */ } ``` + * Name parameters in interface definitions: + + ```go + type Frobulator interface { + Frobulate(f Foo, b Bar) (r Result, err error) + } + ``` + * Name the deferred errors (e.g. when closing something) `cerr`. * No shadowing, since it can often lead to subtle bugs, especially with @@ -172,10 +180,15 @@ The rules are mostly sorted in the alphabetical order. * Put utility flags in the ASCII order and **don't** group them together. For example, `ls -1 -A -q`. - * `snake_case`, not `camelCase`. + * `snake_case`, not `camelCase` for variables. `kebab-case` for filenames. + + * UPPERCASE names for external exported variables, lowercase for local, + unexported ones. * Use `set -e -f -u` and also `set -x` in verbose mode. + * Use `readonly` liberally. + * Use the `"$var"` form instead of the `$var` form, unless word splitting is required. diff --git a/Makefile b/Makefile index 2bdee355..27247fe6 100644 --- a/Makefile +++ b/Makefile @@ -67,10 +67,11 @@ endif # Version properties COMMIT=$(shell git rev-parse --short HEAD) +TAG_NAME=$(shell git describe --abbrev=0) +PRERELEASE_VERSION=$(shell git describe --abbrev=0) # TODO(a.garipov): The cut call is a temporary solution to trim # prerelease versions. See the comment in .goreleaser.yml. -TAG_NAME=$(shell git describe --abbrev=0 | cut -c 1-8) -RELEASE_VERSION=$(TAG_NAME) +RELEASE_VERSION=$(shell git describe --abbrev=0 | cut -c 1-8) SNAPSHOT_VERSION=$(RELEASE_VERSION)-SNAPSHOT-$(COMMIT) # Set proper version @@ -78,6 +79,8 @@ VERSION= ifeq ($(TAG_NAME),$(shell git describe --abbrev=4)) ifeq ($(CHANNEL),edge) VERSION=$(SNAPSHOT_VERSION) + else ifeq ($(CHANNEL),beta) + VERSION=$(PRERELEASE_VERSION) else VERSION=$(RELEASE_VERSION) endif diff --git a/README.md b/README.md index 5af39b8e..dceda824 100644 --- a/README.md +++ b/README.md @@ -123,20 +123,22 @@ AdGuard Home provides a lot of features out-of-the-box with no need to install a > Disclaimer: some of the listed features can be added to Pi-Hole by installing additional software or by manually using SSH terminal and reconfiguring one of the utilities Pi-Hole consists of. However, in our opinion, this cannot be legitimately counted as a Pi-Hole's feature. -| Feature | AdGuard Home | Pi-Hole | -|-------------------------------------------------------------------------|--------------|--------------------------------------------------------| -| Blocking ads and trackers | ✅ | ✅ | -| Customizing blocklists | ✅ | ✅ | -| Built-in DHCP server | ✅ | ✅ | -| HTTPS for the Admin interface | ✅ | Kind of, but you'll need to manually configure lighthttpd | -| Encrypted DNS upstream servers (DNS-over-HTTPS, DNS-over-TLS, DNSCrypt) | ✅ | ❌ (requires additional software) | -| Cross-platform | ✅ | ❌ (not natively, only via Docker) | -| Running as a DNS-over-HTTPS or DNS-over-TLS server | ✅ | ❌ (requires additional software) | -| Blocking phishing and malware domains | ✅ | ❌ (requires non-default blocklists) | -| Parental control (blocking adult domains) | ✅ | ❌ | -| Force Safe search on search engines | ✅ | ❌ | -| Per-client (device) configuration | ✅ | ✅ | -| Access settings (choose who can use AGH DNS) | ✅ | ❌ | +| Feature | AdGuard Home | Pi-Hole | +|-------------------------------------------------------------------------|-------------------|-----------------------------------------------------------| +| Blocking ads and trackers | ✅ | ✅ | +| Customizing blocklists | ✅ | ✅ | +| Built-in DHCP server | ✅ | ✅ | +| HTTPS for the Admin interface | ✅ | Kind of, but you'll need to manually configure lighthttpd | +| Encrypted DNS upstream servers (DNS-over-HTTPS, DNS-over-TLS, DNSCrypt) | ✅ | ❌ (requires additional software) | +| Cross-platform | ✅ | ❌ (not natively, only via Docker) | +| Running as a DNS-over-HTTPS or DNS-over-TLS server | ✅ | ❌ (requires additional software) | +| Blocking phishing and malware domains | ✅ | ❌ (requires non-default blocklists) | +| Parental control (blocking adult domains) | ✅ | ❌ | +| Force Safe search on search engines | ✅ | ❌ | +| Per-client (device) configuration | ✅ | ✅ | +| Access settings (choose who can use AGH DNS) | ✅ | ❌ | +| Written in a memory-safe language | ✅ | ❌ | +| Running without root privileges | ✅ | ❌ | ### How does AdGuard Home compare to traditional ad blockers diff --git a/client/src/__locales/en.json b/client/src/__locales/en.json index 388691b0..afc2f105 100644 --- a/client/src/__locales/en.json +++ b/client/src/__locales/en.json @@ -270,7 +270,7 @@ "source_label": "Source", "found_in_known_domain_db": "Found in the known domains database.", "category_label": "Category", - "rule_label": "Rule", + "rule_label": "Rule(s)", "list_label": "List", "unknown_filter": "Unknown filter {{filterId}}", "known_tracker": "Known tracker", @@ -530,7 +530,6 @@ "check_ip": "IP addresses: {{ip}}", "check_cname": "CNAME: {{cname}}", "check_reason": "Reason: {{reason}}", - "check_rule": "Rule: {{rule}}", "check_service": "Service name: {{service}}", "service_name": "Service name", "check_not_found": "Not found in your filter lists", diff --git a/client/src/actions/index.js b/client/src/actions/index.js index 04054091..58e2d018 100644 --- a/client/src/actions/index.js +++ b/client/src/actions/index.js @@ -287,7 +287,7 @@ export const getDnsStatus = () => async (dispatch) => { try { checkStatus(handleRequestSuccess, handleRequestError); } catch (error) { - handleRequestError(error); + handleRequestError(); } }; diff --git a/client/src/components/App/index.css b/client/src/components/App/index.css index 4f18ee18..ff530784 100644 --- a/client/src/components/App/index.css +++ b/client/src/components/App/index.css @@ -1,9 +1,9 @@ :root { --yellow-pale: rgba(247, 181, 0, 0.1); - --green79: #67B279; + --green79: #67b279; --gray-a5: #a5a5a5; --gray-d8: #d8d8d8; - --gray-f3: #F3F3F3; + --gray-f3: #f3f3f3; --font-family-monospace: Monaco, Menlo, "Ubuntu Mono", Consolas, source-code-pro, monospace; } diff --git a/client/src/components/Dashboard/Dashboard.css b/client/src/components/Dashboard/Dashboard.css index c299464b..9e733044 100644 --- a/client/src/components/Dashboard/Dashboard.css +++ b/client/src/components/Dashboard/Dashboard.css @@ -34,7 +34,7 @@ align-items: center; } -.dashboard-title__button{ +.dashboard-title__button { margin: 0 0.5rem; } @@ -44,7 +44,7 @@ align-items: flex-start; } - .dashboard-title__button{ + .dashboard-title__button { margin: 0.5rem 0; display: block; } diff --git a/client/src/components/Dashboard/index.js b/client/src/components/Dashboard/index.js index 4de5b8dd..d7e8874e 100644 --- a/client/src/components/Dashboard/index.js +++ b/client/src/components/Dashboard/index.js @@ -44,6 +44,7 @@ const Dashboard = ({ const refreshButton = ; }; -const getTitle = (reason) => { +const getTitle = () => { const { t } = useTranslation(); const filters = useSelector((state) => state.filtering.filters, shallowEqual); const whitelistFilters = useSelector((state) => state.filtering.whitelistFilters, shallowEqual); - const filter_id = useSelector((state) => state.filtering.check.filter_id); - - const filterName = getFilterName( - filters, - whitelistFilters, - filter_id, - 'filtered_custom_rules', - (filter) => (filter?.name ? t('query_log_filtered', { filter: filter.name }) : ''), - ); + const rules = useSelector((state) => state.filtering.check.rules, shallowEqual); + const reason = useSelector((state) => state.filtering.check.reason); const getReasonFiltered = (reason) => { const filterKey = reason.replace(FILTERED, ''); return i18next.t('query_log_filtered', { filter: filterKey }); }; + const ruleAndFilterNames = getRulesToFilterList(rules, filters, whitelistFilters); + const REASON_TO_TITLE_MAP = { [FILTERED_STATUS.NOT_FILTERED_NOT_FOUND]: t('check_not_found'), [FILTERED_STATUS.REWRITE]: t('rewrite_applied'), [FILTERED_STATUS.REWRITE_HOSTS]: t('rewrite_hosts_applied'), - [FILTERED_STATUS.FILTERED_BLACK_LIST]: filterName, - [FILTERED_STATUS.NOT_FILTERED_WHITE_LIST]: filterName, + [FILTERED_STATUS.FILTERED_BLACK_LIST]: ruleAndFilterNames, + [FILTERED_STATUS.NOT_FILTERED_WHITE_LIST]: ruleAndFilterNames, [FILTERED_STATUS.FILTERED_SAFE_SEARCH]: getReasonFiltered(reason), [FILTERED_STATUS.FILTERED_SAFE_BROWSING]: getReasonFiltered(reason), [FILTERED_STATUS.FILTERED_PARENTAL]: getReasonFiltered(reason), @@ -78,7 +73,11 @@ const getTitle = (reason) => { return <>
{t('check_reason', { reason })}
-
{filterName}
+
+ {t('rule_label')}: +   + {ruleAndFilterNames} +
; }; @@ -86,14 +85,13 @@ const Info = () => { const { hostname, reason, - rule, service_name, cname, ip_addrs, } = useSelector((state) => state.filtering.check, shallowEqual); const { t } = useTranslation(); - const title = getTitle(reason); + const title = getTitle(); const className = classNames('card mb-0 p-3', { 'logs__row--red': checkFiltered(reason), @@ -112,7 +110,6 @@ const Info = () => {
{title}
{!onlyFiltered && <> - {rule &&
{t('check_rule', { rule })}
} {service_name &&
{t('check_service', { service: service_name })}
} {cname &&
{t('check_cname', { cname })}
} {ip_addrs &&
{t('check_ip', { ip: ip_addrs.join(', ') })}
} diff --git a/client/src/components/Header/index.js b/client/src/components/Header/index.js index 071f950f..554ec603 100644 --- a/client/src/components/Header/index.js +++ b/client/src/components/Header/index.js @@ -46,7 +46,7 @@ const Header = () => {
- + AdGuard Home logo {!processing && isCoreRunning && { const { t } = useTranslation(); @@ -36,7 +36,6 @@ const ResponseCell = ({ const statusLabel = t(isBlockedByResponse ? 'blocked_by_cname_or_ip' : FILTERED_STATUS_TO_META_MAP[reason]?.LABEL || reason); const boldStatusLabel = {statusLabel}; - const filter = getFilterName(filters, whitelistFilters, filterId); const renderResponses = (responseArr) => { if (!responseArr || responseArr.length === 0) { @@ -52,18 +51,23 @@ const ResponseCell = ({ })}
; }; + const rulesList = getRulesToFilterList(rules, filters, whitelistFilters); + const COMMON_CONTENT = { encryption_status: boldStatusLabel, install_settings_dns: upstream, elapsed: formattedElapsedMs, response_code: status, - ...(service_name ? { service_name: getServiceName(service_name) } : { filter }), - rule_label: rule, + ...(service_name + ? { service_name: getServiceName(service_name) } + : { } + ), + rule_label: rulesList, response_table_header: renderResponses(response), original_response: renderResponses(originalResponse), }; - const content = rule + const content = rules.length > 0 ? Object.entries(COMMON_CONTENT) : Object.entries({ ...COMMON_CONTENT, @@ -78,7 +82,8 @@ const ResponseCell = ({ } return getServiceName(service_name); case FILTERED_STATUS.FILTERED_BLACK_LIST: - return filter; + case FILTERED_STATUS.NOT_FILTERED_WHITE_LIST: + return getFilterNames(rules, filters, whitelistFilters).join(', '); default: return formattedElapsedMs; } @@ -113,8 +118,10 @@ ResponseCell.propTypes = { response: propTypes.array.isRequired, status: propTypes.string.isRequired, upstream: propTypes.string.isRequired, - rule: propTypes.string, - filterId: propTypes.number, + rules: propTypes.arrayOf(propTypes.shape({ + text: propTypes.string.isRequired, + filter_list_id: propTypes.number.isRequired, + })), service_name: propTypes.string, }; diff --git a/client/src/components/Logs/Cells/index.js b/client/src/components/Logs/Cells/index.js index 2e2635d9..8435a617 100644 --- a/client/src/components/Logs/Cells/index.js +++ b/client/src/components/Logs/Cells/index.js @@ -6,11 +6,11 @@ import propTypes from 'prop-types'; import { captitalizeWords, checkFiltered, + getRulesToFilterList, formatDateTime, formatElapsedMs, formatTime, getBlockingClientName, - getFilterName, getServiceName, processContent, } from '../../../helpers/helpers'; @@ -70,8 +70,7 @@ const Row = memo(({ upstream, type, client_proto, - filterId, - rule, + rules, originalResponse, status, service_name, @@ -107,8 +106,6 @@ const Row = memo(({ const sourceData = getSourceData(tracker); - const filter = getFilterName(filters, whitelistFilters, filterId); - const { confirmMessage, buttonKey: blockingClientKey, @@ -172,8 +169,8 @@ const Row = memo(({ response_details: 'title', install_settings_dns: upstream, elapsed: formattedElapsedMs, - filter: rule ? filter : null, - rule_label: rule, + rule_label: rules.length > 0 + && getRulesToFilterList(rules, filters, whitelistFilters), response_table_header: response?.join('\n'), response_code: status, client_details: 'title', @@ -235,8 +232,10 @@ Row.propTypes = { upstream: propTypes.string.isRequired, type: propTypes.string.isRequired, client_proto: propTypes.string.isRequired, - filterId: propTypes.number, - rule: propTypes.string, + rules: propTypes.arrayOf(propTypes.shape({ + text: propTypes.string.isRequired, + filter_list_id: propTypes.number.isRequired, + })), originalResponse: propTypes.array, status: propTypes.string.isRequired, service_name: propTypes.string, diff --git a/client/src/components/Logs/Logs.css b/client/src/components/Logs/Logs.css index c4f70827..a48a8d15 100644 --- a/client/src/components/Logs/Logs.css +++ b/client/src/components/Logs/Logs.css @@ -9,21 +9,18 @@ --size-response: 150; --size-client: 123; --gray-216: rgba(216, 216, 216, 0.23); - --gray-4d: #4D4D4D; - --gray-f3: #F3F3F3; + --gray-4d: #4d4d4d; + --gray-f3: #f3f3f3; --gray-8: #888; --gray-3: #333; - --danger: #DF3812; + --danger: #df3812; --white80: rgba(255, 255, 255, 0.8); - - --btn-block: #C23814; - --btn-block-disabled: #E3B3A6; - --btn-block-active: #A62200; - + --btn-block: #c23814; + --btn-block-disabled: #e3b3a6; + --btn-block-active: #a62200; --btn-unblock: #888888; - --btn-unblock-disabled: #D8D8D8; - --btn-unblock-active: #4D4D4D; - + --btn-unblock-disabled: #d8d8d8; + --btn-unblock-active: #4d4d4d; --option-border-radius: 4px; } @@ -87,7 +84,7 @@ } .custom-select__arrow--left { - background: var(--white) url('../ui/svg/chevron-down.svg') no-repeat; + background: var(--white) url("../ui/svg/chevron-down.svg") no-repeat; background-position: 5px 9px; background-size: 22px; } @@ -431,3 +428,13 @@ margin-right: 1px; opacity: 0.5; } + +.filteringRules__rule { + margin-bottom: 0; +} + +.filteringRules__filter { + font-style: italic; + font-weight: normal; + margin-bottom: 1rem; +} diff --git a/client/src/components/ui/Icons.css b/client/src/components/ui/Icons.css index 24d71538..73f4c864 100644 --- a/client/src/components/ui/Icons.css +++ b/client/src/components/ui/Icons.css @@ -6,18 +6,21 @@ .icon--24 { --size: 1.5rem; + width: var(--size); height: var(--size); } .icon--20 { --size: 1.25rem; + width: var(--size); height: var(--size); } .icon--18 { --size: 1.125rem; + width: var(--size); height: var(--size); } diff --git a/client/src/helpers/helpers.js b/client/src/helpers/helpers.js index 0cefcfdf..82f30245 100644 --- a/client/src/helpers/helpers.js +++ b/client/src/helpers/helpers.js @@ -7,6 +7,7 @@ import i18n from 'i18next'; import uniqBy from 'lodash/uniqBy'; import ipaddr from 'ipaddr.js'; import queryString from 'query-string'; +import React from 'react'; import { getTrackerData } from './trackers/trackers'; import { @@ -68,6 +69,7 @@ export const normalizeLogs = (logs) => logs.map((log) => { time, filterId, rule, + rules, service_name, original_answer, upstream, @@ -80,6 +82,15 @@ export const normalizeLogs = (logs) => logs.map((log) => { return `${type}: ${value} (ttl=${ttl})`; }) : []); + let newRules = rules; + /* TODO 'filterId' and 'rule' are deprecated, will be removed in 0.106 */ + if (rule !== undefined && filterId !== undefined && rules !== undefined && rules.length === 0) { + newRules = { + filter_list_id: filterId, + text: rule, + }; + } + return { time, domain, @@ -88,8 +99,10 @@ export const normalizeLogs = (logs) => logs.map((log) => { reason, client, client_proto, + /* TODO 'filterId' and 'rule' are deprecated, will be removed in 0.106 */ filterId, rule, + rules: newRules, status, service_name, originalAnswer: original_answer, @@ -190,7 +203,12 @@ export const getIpList = (interfaces) => Object.values(interfaces) .reduce((acc, curr) => acc.concat(curr.ip_addresses), []) .sort(); -export const getDnsAddress = (ip, port = '') => { +/** + * @param {string} ip + * @param {number} [port] + * @returns {string} + */ +export const getDnsAddress = (ip, port = 0) => { const isStandardDnsPort = port === STANDARD_DNS_PORT; let address = ip; @@ -205,7 +223,12 @@ export const getDnsAddress = (ip, port = '') => { return address; }; -export const getWebAddress = (ip, port = '') => { +/** + * @param {string} ip + * @param {number} [port] + * @returns {string} + */ +export const getWebAddress = (ip, port = 0) => { const isStandardWebPort = port === STANDARD_WEB_PORT; let address = `http://${ip}`; @@ -716,6 +739,75 @@ export const getFilterName = ( return resolveFilterName(filter); }; +/** + * @param {array} rules + * @param {array} filters + * @param {array} whitelistFilters + * @returns {string[]} + */ +export const getFilterNames = (rules, filters, whitelistFilters) => rules.map( + ({ filter_list_id }) => getFilterName(filters, whitelistFilters, filter_list_id), +); + +/** + * @param {array} rules + * @returns {string[]} + */ +export const getRuleNames = (rules) => rules.map(({ text }) => text); + +/** + * @param {array} rules + * @param {array} filters + * @param {array} whitelistFilters + * @returns {object} + */ +export const getFilterNameToRulesMap = (rules, filters, whitelistFilters) => rules.reduce( + (acc, { text, filter_list_id }) => { + const filterName = getFilterName(filters, whitelistFilters, filter_list_id); + + acc[filterName] = (acc[filterName] || []).concat(text); + return acc; + }, {}, +); + +/** + * @param {array} rules + * @param {array} filters + * @param {array} whitelistFilters + * @param {object} classes + * @returns {JSXElement} + */ +export const getRulesToFilterList = (rules, filters, whitelistFilters, classes = { + list: 'filteringRules', + rule: 'filteringRules__rule font-monospace', + filter: 'filteringRules__filter', +}) => { + const filterNameToRulesMap = getFilterNameToRulesMap(rules, filters, whitelistFilters); + + return
+ {Object.entries(filterNameToRulesMap).reduce( + (acc, [filterName, rulesArr]) => acc + .concat(rulesArr.map((rule, i) =>
{rule}
)) + .concat(
{filterName}
), + [], + )} +
; +}; + +/** +* @param {array} rules +* @param {array} filters +* @param {array} whitelistFilters +* @returns {string} +*/ +export const getRulesAndFilterNames = (rules, filters, whitelistFilters) => { + const filterNameToRulesMap = getFilterNameToRulesMap(rules, filters, whitelistFilters); + + return Object.entries(filterNameToRulesMap).map( + ([filterName, filterRules]) => filterRules.concat(filterName).join('\n'), + ).join('\n\n'); +}; + /** * @param ip {string} * @param gateway_ip {string} diff --git a/client/src/helpers/renderFormattedClientCell.js b/client/src/helpers/renderFormattedClientCell.js index d677c4ca..f7e59a84 100644 --- a/client/src/helpers/renderFormattedClientCell.js +++ b/client/src/helpers/renderFormattedClientCell.js @@ -31,7 +31,7 @@ const getFormattedWhois = (whois) => { * @param {object} info.whois_info * @param {boolean} [isDetailed] * @param {boolean} [isLogs] - * @returns {JSX.Element} + * @returns {JSXElement} */ export const renderFormattedClientCell = (value, info, isDetailed = false, isLogs = false) => { let whoisContainer = null; diff --git a/client/src/helpers/validators.js b/client/src/helpers/validators.js index f2fee026..2632df3f 100644 --- a/client/src/helpers/validators.js +++ b/client/src/helpers/validators.js @@ -16,7 +16,7 @@ import { getLastIpv4Octet, isValidAbsolutePath } from './form'; // https://redux-form.com/8.3.0/examples/fieldlevelvalidation/ // If the value is valid, the validation function should return undefined. /** - * @param value {string} + * @param value {string|number} * @returns {undefined|string} */ export const validateRequiredValue = (value) => { diff --git a/client/src/install/Setup/AddressList.js b/client/src/install/Setup/AddressList.js index 15cf7113..ed58127c 100644 --- a/client/src/install/Setup/AddressList.js +++ b/client/src/install/Setup/AddressList.js @@ -41,16 +41,13 @@ const AddressList = ({ AddressList.propTypes = { interfaces: PropTypes.object.isRequired, address: PropTypes.string.isRequired, - port: PropTypes.oneOfType([ - PropTypes.string, - PropTypes.number, - ]), + port: PropTypes.number.isRequired, isDns: PropTypes.bool, }; renderItem.propTypes = { ip: PropTypes.string.isRequired, - port: PropTypes.string.isRequired, + port: PropTypes.number.isRequired, isDns: PropTypes.bool.isRequired, }; diff --git a/client/src/reducers/access.js b/client/src/reducers/access.js index e90bb314..69bf580d 100644 --- a/client/src/reducers/access.js +++ b/client/src/reducers/access.js @@ -24,13 +24,7 @@ const access = handleActions( [actions.setAccessListRequest]: (state) => ({ ...state, processingSet: true }), [actions.setAccessListFailure]: (state) => ({ ...state, processingSet: false }), - [actions.setAccessListSuccess]: (state) => { - const newState = { - ...state, - processingSet: false, - }; - return newState; - }, + [actions.setAccessListSuccess]: (state) => ({ ...state, processingSet: false }), [actions.toggleClientBlockRequest]: (state) => ({ ...state, processingSet: true }), [actions.toggleClientBlockFailure]: (state) => ({ ...state, processingSet: false }), diff --git a/go.mod b/go.mod index 10866d88..61c563e5 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ go 1.14 require ( github.com/AdguardTeam/dnsproxy v0.33.7 github.com/AdguardTeam/golibs v0.4.4 - github.com/AdguardTeam/urlfilter v0.14.0 + github.com/AdguardTeam/urlfilter v0.14.1 github.com/NYTimes/gziphandler v1.1.1 github.com/ameshkov/dnscrypt/v2 v2.0.1 github.com/fsnotify/fsnotify v1.4.9 diff --git a/go.sum b/go.sum index 764a2fa0..d8d6ef4d 100644 --- a/go.sum +++ b/go.sum @@ -26,8 +26,8 @@ github.com/AdguardTeam/golibs v0.4.2/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKU github.com/AdguardTeam/golibs v0.4.4 h1:cM9UySQiYFW79zo5XRwnaIWVzfW4eNXmZktMrWbthpw= github.com/AdguardTeam/golibs v0.4.4/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4= github.com/AdguardTeam/gomitmproxy v0.2.0/go.mod h1:Qdv0Mktnzer5zpdpi5rAwixNJzW2FN91LjKJCkVbYGU= -github.com/AdguardTeam/urlfilter v0.14.0 h1:+aAhOvZDVGzl5gTERB4pOJCL1zxMyw7vLecJJ6TQTCw= -github.com/AdguardTeam/urlfilter v0.14.0/go.mod h1:klx4JbOfc4EaNb5lWLqOwfg+pVcyRukmoJRvO55lL5U= +github.com/AdguardTeam/urlfilter v0.14.1 h1:imYls0fit9ojA6pP1hWFUEIjyoXbDF85ZM+G67bI48c= +github.com/AdguardTeam/urlfilter v0.14.1/go.mod h1:klx4JbOfc4EaNb5lWLqOwfg+pVcyRukmoJRvO55lL5U= github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= diff --git a/internal/dnsfilter/dnsrewrite.go b/internal/dnsfilter/dnsrewrite.go index 1239fbad..66cb5828 100644 --- a/internal/dnsfilter/dnsrewrite.go +++ b/internal/dnsfilter/dnsrewrite.go @@ -7,8 +7,8 @@ import ( // DNSRewriteResult is the result of application of $dnsrewrite rules. type DNSRewriteResult struct { - RCode rules.RCode `json:",omitempty"` Response DNSRewriteResultResponse `json:",omitempty"` + RCode rules.RCode `json:",omitempty"` } // DNSRewriteResultResponse is the collection of DNS response records diff --git a/internal/dnsforward/dns.go b/internal/dnsforward/dns.go index d7208d1c..0e5eb51f 100644 --- a/internal/dnsforward/dns.go +++ b/internal/dnsforward/dns.go @@ -379,7 +379,7 @@ func processFilteringAfterResponse(ctx *dnsContext) int { if len(d.Res.Answer) != 0 { answer := []dns.RR{} - answer = append(answer, s.genCNAMEAnswer(d.Req, res.CanonName)) + answer = append(answer, s.genAnswerCNAME(d.Req, res.CanonName)) answer = append(answer, d.Res.Answer...) d.Res.Answer = answer } diff --git a/internal/dnsforward/dnsrewrite.go b/internal/dnsforward/dnsrewrite.go index 01895323..3741e850 100644 --- a/internal/dnsforward/dnsrewrite.go +++ b/internal/dnsforward/dnsrewrite.go @@ -13,27 +13,55 @@ import ( ) // filterDNSRewriteResponse handles a single DNS rewrite response entry. -// It returns the constructed answer resource record. +// It returns the properly constructed answer resource record. func (s *Server) filterDNSRewriteResponse(req *dns.Msg, rr rules.RRType, v rules.RRValue) (ans dns.RR, err error) { + // TODO(a.garipov): As more types are added, we will probably want to + // use a handler-oriented approach here. So, think of a way to decouple + // the answer generation logic from the Server. + switch rr { case dns.TypeA, dns.TypeAAAA: ip, ok := v.(net.IP) if !ok { - return nil, fmt.Errorf("value has type %T, not net.IP", v) + return nil, fmt.Errorf("value for rr type %d has type %T, not net.IP", rr, v) } if rr == dns.TypeA { - return s.genAAnswer(req, ip.To4()), nil + return s.genAnswerA(req, ip.To4()), nil } - return s.genAAAAAnswer(req, ip), nil - case dns.TypeTXT: + return s.genAnswerAAAA(req, ip), nil + case dns.TypePTR, + dns.TypeTXT: str, ok := v.(string) if !ok { - return nil, fmt.Errorf("value has type %T, not string", v) + return nil, fmt.Errorf("value for rr type %d has type %T, not string", rr, v) } - return s.genTXTAnswer(req, []string{str}), nil + if rr == dns.TypeTXT { + return s.genAnswerTXT(req, []string{str}), nil + } + + return s.genAnswerPTR(req, str), nil + case dns.TypeMX: + mx, ok := v.(*rules.DNSMX) + if !ok { + return nil, fmt.Errorf("value for rr type %d has type %T, not *rules.DNSMX", rr, v) + } + + return s.genAnswerMX(req, mx), nil + case dns.TypeHTTPS, + dns.TypeSVCB: + svcb, ok := v.(*rules.DNSSVCB) + if !ok { + return nil, fmt.Errorf("value for rr type %d has type %T, not *rules.DNSSVCB", rr, v) + } + + if rr == dns.TypeHTTPS { + return s.genAnswerHTTPS(req, svcb), nil + } + + return s.genAnswerSVCB(req, svcb), nil default: log.Debug("don't know how to handle dns rr type %d, skipping", rr) diff --git a/internal/dnsforward/filter.go b/internal/dnsforward/filter.go index 5cd0090a..c6bfb160 100644 --- a/internal/dnsforward/filter.go +++ b/internal/dnsforward/filter.go @@ -87,17 +87,17 @@ func (s *Server) filterDNSRequest(ctx *dnsContext) (*dnsfilter.Result, error) { name := host if len(res.CanonName) != 0 { - resp.Answer = append(resp.Answer, s.genCNAMEAnswer(req, res.CanonName)) + resp.Answer = append(resp.Answer, s.genAnswerCNAME(req, res.CanonName)) name = res.CanonName } for _, ip := range res.IPList { if req.Question[0].Qtype == dns.TypeA { - a := s.genAAnswer(req, ip.To4()) + a := s.genAnswerA(req, ip.To4()) a.Hdr.Name = dns.Fqdn(name) resp.Answer = append(resp.Answer, a) } else if req.Question[0].Qtype == dns.TypeAAAA { - a := s.genAAAAAnswer(req, ip) + a := s.genAnswerAAAA(req, ip) a.Hdr.Name = dns.Fqdn(name) resp.Answer = append(resp.Answer, a) } diff --git a/internal/dnsforward/msg.go b/internal/dnsforward/msg.go index f8200056..28a8ac3a 100644 --- a/internal/dnsforward/msg.go +++ b/internal/dnsforward/msg.go @@ -7,6 +7,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/dnsfilter" "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/urlfilter/rules" "github.com/miekg/dns" ) @@ -92,48 +93,64 @@ func (s *Server) genServerFailure(request *dns.Msg) *dns.Msg { func (s *Server) genARecord(request *dns.Msg, ip net.IP) *dns.Msg { resp := s.makeResponse(request) - resp.Answer = append(resp.Answer, s.genAAnswer(request, ip)) + resp.Answer = append(resp.Answer, s.genAnswerA(request, ip)) return resp } func (s *Server) genAAAARecord(request *dns.Msg, ip net.IP) *dns.Msg { resp := s.makeResponse(request) - resp.Answer = append(resp.Answer, s.genAAAAAnswer(request, ip)) + resp.Answer = append(resp.Answer, s.genAnswerAAAA(request, ip)) return resp } -func (s *Server) genAAnswer(req *dns.Msg, ip net.IP) *dns.A { - answer := new(dns.A) - answer.Hdr = dns.RR_Header{ +func (s *Server) hdr(req *dns.Msg, rrType rules.RRType) (h dns.RR_Header) { + return dns.RR_Header{ Name: req.Question[0].Name, - Rrtype: dns.TypeA, + Rrtype: rrType, Ttl: s.conf.BlockedResponseTTL, Class: dns.ClassINET, } - answer.A = ip - return answer } -func (s *Server) genAAAAAnswer(req *dns.Msg, ip net.IP) *dns.AAAA { - answer := new(dns.AAAA) - answer.Hdr = dns.RR_Header{ - Name: req.Question[0].Name, - Rrtype: dns.TypeAAAA, - Ttl: s.conf.BlockedResponseTTL, - Class: dns.ClassINET, +func (s *Server) genAnswerA(req *dns.Msg, ip net.IP) (ans *dns.A) { + return &dns.A{ + Hdr: s.hdr(req, dns.TypeA), + A: ip, } - answer.AAAA = ip - return answer } -func (s *Server) genTXTAnswer(req *dns.Msg, strs []string) (answer *dns.TXT) { +func (s *Server) genAnswerAAAA(req *dns.Msg, ip net.IP) (ans *dns.AAAA) { + return &dns.AAAA{ + Hdr: s.hdr(req, dns.TypeAAAA), + AAAA: ip, + } +} + +func (s *Server) genAnswerCNAME(req *dns.Msg, cname string) (ans *dns.CNAME) { + return &dns.CNAME{ + Hdr: s.hdr(req, dns.TypeCNAME), + Target: dns.Fqdn(cname), + } +} + +func (s *Server) genAnswerMX(req *dns.Msg, mx *rules.DNSMX) (ans *dns.MX) { + return &dns.MX{ + Hdr: s.hdr(req, dns.TypePTR), + Preference: mx.Preference, + Mx: mx.Exchange, + } +} + +func (s *Server) genAnswerPTR(req *dns.Msg, ptr string) (ans *dns.PTR) { + return &dns.PTR{ + Hdr: s.hdr(req, dns.TypePTR), + Ptr: ptr, + } +} + +func (s *Server) genAnswerTXT(req *dns.Msg, strs []string) (ans *dns.TXT) { return &dns.TXT{ - Hdr: dns.RR_Header{ - Name: req.Question[0].Name, - Rrtype: dns.TypeTXT, - Ttl: s.conf.BlockedResponseTTL, - Class: dns.ClassINET, - }, + Hdr: s.hdr(req, dns.TypeTXT), Txt: strs, } } @@ -198,19 +215,6 @@ func (s *Server) genBlockedHost(request *dns.Msg, newAddr string, d *proxy.DNSCo return resp } -// Make a CNAME response -func (s *Server) genCNAMEAnswer(req *dns.Msg, cname string) *dns.CNAME { - answer := new(dns.CNAME) - answer.Hdr = dns.RR_Header{ - Name: req.Question[0].Name, - Rrtype: dns.TypeCNAME, - Ttl: s.conf.BlockedResponseTTL, - Class: dns.ClassINET, - } - answer.Target = dns.Fqdn(cname) - return answer -} - // Create REFUSED DNS response func (s *Server) makeResponseREFUSED(request *dns.Msg) *dns.Msg { resp := dns.Msg{} diff --git a/internal/dnsforward/svcbmsg.go b/internal/dnsforward/svcbmsg.go new file mode 100644 index 00000000..2a8c27b4 --- /dev/null +++ b/internal/dnsforward/svcbmsg.go @@ -0,0 +1,168 @@ +package dnsforward + +import ( + "encoding/base64" + "net" + "strconv" + + "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/urlfilter/rules" + "github.com/miekg/dns" +) + +// genAnswerHTTPS returns a properly initialized HTTPS resource record. +// +// See the comment on genAnswerSVCB for a list of current restrictions on +// parameter values. +func (s *Server) genAnswerHTTPS(req *dns.Msg, svcb *rules.DNSSVCB) (ans *dns.HTTPS) { + ans = &dns.HTTPS{ + SVCB: *s.genAnswerSVCB(req, svcb), + } + + ans.Hdr.Rrtype = dns.TypeHTTPS + + return ans +} + +// strToSVCBKey is the string-to-svcb-key mapping. +// +// See https://github.com/miekg/dns/blob/23c4faca9d32b0abbb6e179aa1aadc45ac53a916/svcb.go#L27. +// +// TODO(a.garipov): Propose exporting this API or something similar in the +// github.com/miekg/dns module. +var strToSVCBKey = map[string]dns.SVCBKey{ + "alpn": dns.SVCB_ALPN, + "echconfig": dns.SVCB_ECHCONFIG, + "ipv4hint": dns.SVCB_IPV4HINT, + "ipv6hint": dns.SVCB_IPV6HINT, + "mandatory": dns.SVCB_MANDATORY, + "no-default-alpn": dns.SVCB_NO_DEFAULT_ALPN, + "port": dns.SVCB_PORT, +} + +// svcbKeyHandler is a handler for one SVCB parameter key. +type svcbKeyHandler func(valStr string) (val dns.SVCBKeyValue) + +// svcbKeyHandlers are the supported SVCB parameters handlers. +var svcbKeyHandlers = map[string]svcbKeyHandler{ + "alpn": func(valStr string) (val dns.SVCBKeyValue) { + return &dns.SVCBAlpn{ + Alpn: []string{valStr}, + } + }, + + "echconfig": func(valStr string) (val dns.SVCBKeyValue) { + ech, err := base64.StdEncoding.DecodeString(valStr) + if err != nil { + log.Debug("can't parse svcb/https echconfig: %s; ignoring", err) + + return nil + } + + return &dns.SVCBECHConfig{ + ECH: ech, + } + }, + + "ipv4hint": func(valStr string) (val dns.SVCBKeyValue) { + ip := net.ParseIP(valStr) + if ip4 := ip.To4(); ip == nil || ip4 == nil { + log.Debug("can't parse svcb/https ipv4 hint %q; ignoring", valStr) + + return nil + } + + return &dns.SVCBIPv4Hint{ + Hint: []net.IP{ip}, + } + }, + + "ipv6hint": func(valStr string) (val dns.SVCBKeyValue) { + ip := net.ParseIP(valStr) + if ip == nil { + log.Debug("can't parse svcb/https ipv6 hint %q; ignoring", valStr) + + return nil + } + + return &dns.SVCBIPv6Hint{ + Hint: []net.IP{ip}, + } + }, + + "mandatory": func(valStr string) (val dns.SVCBKeyValue) { + code, ok := strToSVCBKey[valStr] + if !ok { + log.Debug("unknown svcb/https mandatory key %q, ignoring", valStr) + + return nil + } + + return &dns.SVCBMandatory{ + Code: []dns.SVCBKey{code}, + } + }, + + "no-default-alpn": func(_ string) (val dns.SVCBKeyValue) { + return &dns.SVCBNoDefaultAlpn{} + }, + + "port": func(valStr string) (val dns.SVCBKeyValue) { + port64, err := strconv.ParseUint(valStr, 10, 16) + if err != nil { + log.Debug("can't parse svcb/https port: %s; ignoring", err) + + return nil + } + + return &dns.SVCBPort{ + Port: uint16(port64), + } + }, +} + +// genAnswerSVCB returns a properly initialized SVCB resource record. +// +// Currently, there are several restrictions on how the parameters are parsed. +// Firstly, the parsing of non-contiguous values isn't supported. Secondly, the +// parsing of value-lists is not supported either. +// +// ipv4hint=127.0.0.1 // Supported. +// ipv4hint="127.0.0.1" // Unsupported. +// ipv4hint=127.0.0.1,127.0.0.2 // Unsupported. +// ipv4hint="127.0.0.1,127.0.0.2" // Unsupported. +// +// TODO(a.garipov): Support all of these. +func (s *Server) genAnswerSVCB(req *dns.Msg, svcb *rules.DNSSVCB) (ans *dns.SVCB) { + ans = &dns.SVCB{ + Hdr: s.hdr(req, dns.TypeSVCB), + Priority: svcb.Priority, + Target: svcb.Target, + } + if len(svcb.Params) == 0 { + return ans + } + + values := make([]dns.SVCBKeyValue, 0, len(svcb.Params)) + for k, valStr := range svcb.Params { + handler, ok := svcbKeyHandlers[k] + if !ok { + log.Debug("unknown svcb/https key %q, ignoring", k) + + continue + } + + val := handler(valStr) + if val == nil { + continue + } + + values = append(values, val) + } + + if len(values) > 0 { + ans.Value = values + } + + return ans +} diff --git a/internal/dnsforward/svcbmsg_test.go b/internal/dnsforward/svcbmsg_test.go new file mode 100644 index 00000000..392e92ac --- /dev/null +++ b/internal/dnsforward/svcbmsg_test.go @@ -0,0 +1,154 @@ +package dnsforward + +import ( + "net" + "testing" + + "github.com/AdguardTeam/urlfilter/rules" + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" +) + +func TestGenAnswerHTTPS_andSVCB(t *testing.T) { + // Preconditions. + + s := &Server{ + conf: ServerConfig{ + FilteringConfig: FilteringConfig{ + BlockedResponseTTL: 3600, + }, + }, + } + + req := &dns.Msg{ + Question: []dns.Question{{ + Name: "abcd", + }}, + } + + // Constants and helper values. + + const host = "example.com" + const prio = 32 + + ip4 := net.IPv4(127, 0, 0, 1) + ip6 := net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1} + + // Helper functions. + + dnssvcb := func(key, value string) (svcb *rules.DNSSVCB) { + svcb = &rules.DNSSVCB{ + Target: host, + Priority: prio, + } + + if key == "" { + return svcb + } + + svcb.Params = map[string]string{ + key: value, + } + + return svcb + } + + wantsvcb := func(kv dns.SVCBKeyValue) (want *dns.SVCB) { + want = &dns.SVCB{ + Hdr: s.hdr(req, dns.TypeSVCB), + Priority: prio, + Target: host, + } + + if kv == nil { + return want + } + + want.Value = []dns.SVCBKeyValue{kv} + + return want + } + + // Tests. + + testCases := []struct { + svcb *rules.DNSSVCB + want *dns.SVCB + name string + }{{ + svcb: dnssvcb("", ""), + want: wantsvcb(nil), + name: "no_params", + }, { + svcb: dnssvcb("foo", "bar"), + want: wantsvcb(nil), + name: "invalid", + }, { + svcb: dnssvcb("alpn", "h3"), + want: wantsvcb(&dns.SVCBAlpn{Alpn: []string{"h3"}}), + name: "alpn", + }, { + svcb: dnssvcb("echconfig", "AAAA"), + want: wantsvcb(&dns.SVCBECHConfig{ECH: []byte{0, 0, 0}}), + name: "echconfig", + }, { + svcb: dnssvcb("echconfig", "%BAD%"), + want: wantsvcb(nil), + name: "echconfig_invalid", + }, { + svcb: dnssvcb("ipv4hint", "127.0.0.1"), + want: wantsvcb(&dns.SVCBIPv4Hint{Hint: []net.IP{ip4}}), + name: "ipv4hint", + }, { + svcb: dnssvcb("ipv4hint", "127.0.01"), + want: wantsvcb(nil), + name: "ipv4hint_invalid", + }, { + svcb: dnssvcb("ipv6hint", "::1"), + want: wantsvcb(&dns.SVCBIPv6Hint{Hint: []net.IP{ip6}}), + name: "ipv6hint", + }, { + svcb: dnssvcb("ipv6hint", ":::1"), + want: wantsvcb(nil), + name: "ipv6hint_invalid", + }, { + svcb: dnssvcb("mandatory", "alpn"), + want: wantsvcb(&dns.SVCBMandatory{Code: []dns.SVCBKey{dns.SVCB_ALPN}}), + name: "mandatory", + }, { + svcb: dnssvcb("mandatory", "alpnn"), + want: wantsvcb(nil), + name: "mandatory_invalid", + }, { + svcb: dnssvcb("no-default-alpn", ""), + want: wantsvcb(&dns.SVCBNoDefaultAlpn{}), + name: "no-default-alpn", + }, { + svcb: dnssvcb("port", "8080"), + want: wantsvcb(&dns.SVCBPort{Port: 8080}), + name: "port", + }, { + svcb: dnssvcb("port", "1005008080"), + want: wantsvcb(nil), + name: "port", + }} + + for _, tc := range testCases { + t.Run("https", func(t *testing.T) { + t.Run(tc.name, func(t *testing.T) { + want := &dns.HTTPS{SVCB: *tc.want} + want.Hdr.Rrtype = dns.TypeHTTPS + + got := s.genAnswerHTTPS(req, tc.svcb) + assert.Equal(t, want, got) + }) + }) + + t.Run("svcb", func(t *testing.T) { + t.Run(tc.name, func(t *testing.T) { + got := s.genAnswerSVCB(req, tc.svcb) + assert.Equal(t, tc.want, got) + }) + }) + } +} diff --git a/internal/home/auth.go b/internal/home/auth.go index 00407fa0..dce17241 100644 --- a/internal/home/auth.go +++ b/internal/home/auth.go @@ -59,10 +59,10 @@ func (s *session) deserialize(data []byte) bool { // Auth - global object type Auth struct { db *bbolt.DB - sessions map[string]*session // session name -> session data - lock sync.Mutex + sessions map[string]*session users []User - sessionTTL uint32 // in seconds + lock sync.Mutex + sessionTTL uint32 } // User object @@ -223,24 +223,35 @@ func (a *Auth) removeSession(sess []byte) { log.Debug("Auth: removed session from DB") } -// CheckSession - check if session is valid -// Return 0 if OK; -1 if session doesn't exist; 1 if session has expired -func (a *Auth) CheckSession(sess string) int { +// checkSessionResult is the result of checking a session. +type checkSessionResult int + +// checkSessionResult constants. +const ( + checkSessionOK checkSessionResult = 0 + checkSessionNotFound checkSessionResult = -1 + checkSessionExpired checkSessionResult = 1 +) + +// checkSession checks if the session is valid. +func (a *Auth) checkSession(sess string) (res checkSessionResult) { now := uint32(time.Now().UTC().Unix()) update := false a.lock.Lock() + defer a.lock.Unlock() + s, ok := a.sessions[sess] if !ok { - a.lock.Unlock() - return -1 + return checkSessionNotFound } + if s.expire <= now { delete(a.sessions, sess) key, _ := hex.DecodeString(sess) a.removeSession(key) - a.lock.Unlock() - return 1 + + return checkSessionExpired } newExpire := now + a.sessionTTL @@ -250,8 +261,6 @@ func (a *Auth) CheckSession(sess string) int { s.expire = newExpire } - a.lock.Unlock() - if update { key, _ := hex.DecodeString(sess) if a.storeSession(key, s) { @@ -259,7 +268,7 @@ func (a *Auth) CheckSession(sess string) int { } } - return 0 + return checkSessionOK } // RemoveSession - remove session @@ -392,8 +401,8 @@ func optionalAuthThird(w http.ResponseWriter, r *http.Request) (authFirst bool) ok = true } else if err == nil { - r := Context.auth.CheckSession(cookie.Value) - if r == 0 { + r := Context.auth.checkSession(cookie.Value) + if r == checkSessionOK { ok = true } else if r < 0 { log.Debug("Auth: invalid cookie value: %s", cookie) @@ -434,12 +443,13 @@ func optionalAuth(handler func(http.ResponseWriter, *http.Request)) func(http.Re authRequired := Context.auth != nil && Context.auth.AuthRequired() cookie, err := r.Cookie(sessionCookieName) if authRequired && err == nil { - r := Context.auth.CheckSession(cookie.Value) - if r == 0 { + r := Context.auth.checkSession(cookie.Value) + if r == checkSessionOK { w.Header().Set("Location", "/") w.WriteHeader(http.StatusFound) + return - } else if r < 0 { + } else if r == checkSessionNotFound { log.Debug("Auth: invalid cookie value: %s", cookie) } } @@ -503,32 +513,34 @@ func (a *Auth) UserFind(login, password string) User { return User{} } -// GetCurrentUser - get the current user -func (a *Auth) GetCurrentUser(r *http.Request) User { +// getCurrentUser returns the current user. It returns an empty User if the +// user is not found. +func (a *Auth) getCurrentUser(r *http.Request) User { cookie, err := r.Cookie(sessionCookieName) if err != nil { - // there's no Cookie, check Basic authentication + // There's no Cookie, check Basic authentication. user, pass, ok := r.BasicAuth() if ok { - u := Context.auth.UserFind(user, pass) - return u + return Context.auth.UserFind(user, pass) } + return User{} } a.lock.Lock() + defer a.lock.Unlock() + s, ok := a.sessions[cookie.Value] if !ok { - a.lock.Unlock() return User{} } + for _, u := range a.users { if u.Name == s.userName { - a.lock.Unlock() return u } } - a.lock.Unlock() + return User{} } diff --git a/internal/home/auth_test.go b/internal/home/auth_test.go index 25db2dd6..0998a2a6 100644 --- a/internal/home/auth_test.go +++ b/internal/home/auth_test.go @@ -38,7 +38,7 @@ func TestAuth(t *testing.T) { user := User{Name: "name"} a.UserAdd(&user, "password") - assert.True(t, a.CheckSession("notfound") == -1) + assert.Equal(t, checkSessionNotFound, a.checkSession("notfound")) a.RemoveSession("notfound") sess, err := getSession(&users[0]) @@ -49,13 +49,13 @@ func TestAuth(t *testing.T) { // check expiration s.expire = uint32(now) a.addSession(sess, &s) - assert.True(t, a.CheckSession(sessStr) == 1) + assert.Equal(t, checkSessionExpired, a.checkSession(sessStr)) // add session with TTL = 2 sec s = session{} s.expire = uint32(time.Now().UTC().Unix() + 2) a.addSession(sess, &s) - assert.True(t, a.CheckSession(sessStr) == 0) + assert.Equal(t, checkSessionOK, a.checkSession(sessStr)) a.Close() @@ -63,8 +63,8 @@ func TestAuth(t *testing.T) { a = InitAuth(fn, users, 60) // the session is still alive - assert.True(t, a.CheckSession(sessStr) == 0) - // reset our expiration time because CheckSession() has just updated it + assert.Equal(t, checkSessionOK, a.checkSession(sessStr)) + // reset our expiration time because checkSession() has just updated it s.expire = uint32(time.Now().UTC().Unix() + 2) a.storeSession(sess, &s) a.Close() @@ -76,7 +76,7 @@ func TestAuth(t *testing.T) { // load and remove expired sessions a = InitAuth(fn, users, 60) - assert.True(t, a.CheckSession(sessStr) == -1) + assert.Equal(t, checkSessionNotFound, a.checkSession(sessStr)) a.Close() os.Remove(fn) @@ -111,7 +111,7 @@ func TestAuthHTTP(t *testing.T) { Context.auth = InitAuth(fn, users, 60) handlerCalled := false - handler := func(w http.ResponseWriter, r *http.Request) { + handler := func(_ http.ResponseWriter, _ *http.Request) { handlerCalled = true } handler2 := optionalAuth(handler) diff --git a/internal/home/control.go b/internal/home/control.go index 3443515a..616557a8 100644 --- a/internal/home/control.go +++ b/internal/home/control.go @@ -89,7 +89,7 @@ type profileJSON struct { func handleGetProfile(w http.ResponseWriter, r *http.Request) { pj := profileJSON{} - u := Context.auth.GetCurrentUser(r) + u := Context.auth.getCurrentUser(r) pj.Name = u.Name data, err := json.Marshal(pj) diff --git a/internal/home/controlupdate.go b/internal/home/controlupdate.go index dcf428b9..ff9dc4ab 100644 --- a/internal/home/controlupdate.go +++ b/internal/home/controlupdate.go @@ -2,13 +2,14 @@ package home import ( "encoding/json" + "errors" "net/http" "os" "os/exec" "path/filepath" "runtime" - "strings" "syscall" + "time" "github.com/AdguardTeam/AdGuardHome/internal/sysutil" "github.com/AdguardTeam/AdGuardHome/internal/update" @@ -19,6 +20,13 @@ type getVersionJSONRequest struct { RecheckNow bool `json:"recheck_now"` } +// temporaryError is the interface for temporary errors from the Go standard +// library. +type temporaryError interface { + error + Temporary() (ok bool) +} + // Get the latest available version from the Internet func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) { if Context.disableUpdate { @@ -41,14 +49,29 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) { var info update.VersionInfo for i := 0; i != 3; i++ { - Context.controlLock.Lock() - info, err = Context.updater.GetVersionResponse(req.RecheckNow) - Context.controlLock.Unlock() - if err != nil && strings.HasSuffix(err.Error(), "i/o timeout") { - // This case may happen while we're restarting DNS server - // https://github.com/AdguardTeam/AdGuardHome/internal/issues/934 - continue + func() { + Context.controlLock.Lock() + defer Context.controlLock.Unlock() + + info, err = Context.updater.GetVersionResponse(req.RecheckNow) + }() + + if err != nil { + var terr temporaryError + if errors.As(err, &terr) && terr.Temporary() { + // Temporary network error. This case may happen while + // we're restarting our DNS server. Log and sleep for + // some time. + // + // See https://github.com/AdguardTeam/AdGuardHome/issues/934. + d := time.Duration(i) * time.Second + log.Info("temp net error: %q; sleeping for %s and retrying", err, d) + time.Sleep(d) + + continue + } } + break } if err != nil { diff --git a/staticcheck.conf b/staticcheck.conf index 43639bf6..146f83cb 100644 --- a/staticcheck.conf +++ b/staticcheck.conf @@ -7,8 +7,11 @@ initialisms = [ , "DOQ" , "DOT" , "EDNS" +, "MX" +, "PTR" , "QUIC" , "SDNS" +, "SVCB" ] dot_import_whitelist = [] http_status_code_whitelist = []