diff --git a/docs/configuration.md b/docs/configuration.md index 01afc8ea..a5404d7e 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1786,8 +1786,14 @@ Only required when using AdGuard Home. The username used to log into the admin d ##### `password` Only required when using AdGuard Home. The password used to log into the admin dashboard. Can be specified from an environment variable using the syntax `${VARIABLE_NAME}`. -##### `token` -Only required when using Pi-hole. The API token which can be found in `Settings -> API -> Show API token`. Can be specified from an environment variable using the syntax `${VARIABLE_NAME}`. +##### `token` (Deprecated) +Only required when using Pi-hole major version 5 or earlier. The API token which can be found in `Settings -> API -> Show API token`. Can be specified from an environment variable using the syntax `${VARIABLE_NAME}`. + +##### `app-password` +Only required when using Pi-hole. The App Password can be found in `Settings -> Web Interface / API -> Configure app password`. + +##### `pihole-version` +Only required if using an older version of Pi-hole (major version 5 or earlier). ##### `hide-graph` Whether to hide the graph showing the number of queries over time. diff --git a/internal/glance/widget-dns-stats.go b/internal/glance/widget-dns-stats.go index 833a80d3..de256759 100644 --- a/internal/glance/widget-dns-stats.go +++ b/internal/glance/widget-dns-stats.go @@ -1,12 +1,15 @@ package glance import ( + "bytes" "context" "encoding/json" "errors" + "fmt" "html/template" - "log/slog" + "io" "net/http" + "os" "sort" "strings" "time" @@ -27,6 +30,8 @@ type dnsStatsWidget struct { AllowInsecure bool `yaml:"allow-insecure"` URL string `yaml:"url"` Token string `yaml:"token"` + AppPassword string `yaml:"app-password"` + PiHoleVersion string `yaml:"pihole-version"` Username string `yaml:"username"` Password string `yaml:"password"` } @@ -62,7 +67,7 @@ func (widget *dnsStatsWidget) update(ctx context.Context) { if widget.Service == "adguard" { stats, err = fetchAdguardStats(widget.URL, widget.AllowInsecure, widget.Username, widget.Password, widget.HideGraph) } else { - stats, err = fetchPiholeStats(widget.URL, widget.AllowInsecure, widget.Token, widget.HideGraph) + stats, err = fetchPiholeStats(widget.URL, widget.AllowInsecure, widget.Token, widget.HideGraph, widget.PiHoleVersion, widget.AppPassword) } if !widget.canContinueUpdateAfterHandlingErr(err) { @@ -227,7 +232,8 @@ func fetchAdguardStats(instanceURL string, allowInsecure bool, username, passwor return stats, nil } -type piholeStatsResponse struct { +// Legacy Pi-hole stats response (before v6) +type legacyPiholeStatsResponse struct { TotalQueries int `json:"dns_queries_today"` QueriesSeries piholeQueriesSeries `json:"domains_over_time"` BlockedQueries int `json:"ads_blocked_today"` @@ -237,6 +243,34 @@ type piholeStatsResponse struct { DomainsBlocked int `json:"domains_being_blocked"` } +// Pi-hole v6+ response format +type piholeStatsResponse struct { + Queries struct { + Total int `json:"total"` + Blocked int `json:"blocked"` + PercentBlocked float64 `json:"percent_blocked"` + } `json:"queries"` + Gravity struct { + DomainsBlocked int `json:"domains_being_blocked"` + } `json:"gravity"` + //Note we do not need the full structure. We extract the values needed + //Adding dummy fields to allow easier json parsing. + QueriesSeries piholeQueriesSeries `json:"domains_over_time"` // Will always be empty + BlockedSeries map[int64]int `json:"ads_over_time"` // Will always be empty. +} + +type piholeTopDomainsResponse struct { + Domains []Domains `json:"domains"` + TotalQueries int `json:"total_queries"` + BlockedQueries int `json:"blocked_queries"` + Took float64 `json:"took"` +} + +type Domains struct { + Domain string `json:"domain"` + Count int `json:"count"` +} + // If the user has query logging disabled it's possible for domains_over_time to be returned as an // empty array rather than a map which will prevent unmashalling the rest of the data so we use // custom unmarshal behavior to fallback to an empty map. @@ -275,18 +309,65 @@ func (p *piholeTopBlockedDomains) UnmarshalJSON(data []byte) error { return nil } -func fetchPiholeStats(instanceURL string, allowInsecure bool, token string, noGraph bool) (*dnsStats, error) { - if token == "" { - return nil, errors.New("missing API token") +// piholeGetSID retrieves a new SID from Pi-hole using the app password. +func piholeGetSID(instanceURL, appPassword string, allowInsecure bool) (string, error) { + var client requestDoer + if !allowInsecure { + client = defaultHTTPClient + } else { + client = defaultInsecureHTTPClient + } + + requestURL := strings.TrimRight(instanceURL, "/") + "/api/auth" + requestBody := []byte(`{"password":"` + appPassword + `"}`) + + request, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(requestBody)) + if err != nil { + return "", errors.New("failed to create authentication request: " + err.Error()) } + request.Header.Set("Content-Type", "application/json") - requestURL := strings.TrimRight(instanceURL, "/") + - "/admin/api.php?summaryRaw&topItems&overTimeData10mins&auth=" + token + response, err := client.Do(request) + if err != nil { + return "", errors.New("failed to send authentication request: " + err.Error()) + } + defer response.Body.Close() + + if response.StatusCode != http.StatusOK { + return "", errors.New("authentication failed, received status: " + response.Status) + } + + body, err := io.ReadAll(response.Body) + if err != nil { + return "", errors.New("failed to read authentication response: " + err.Error()) + } + + var jsonResponse struct { + Session struct { + SID string `json:"sid"` + } `json:"session"` + } + + if err := json.Unmarshal(body, &jsonResponse); err != nil { + return "", errors.New("failed to parse authentication response: " + err.Error()) + } + + if jsonResponse.Session.SID == "" { + return "", errors.New("authentication response did not contain a valid SID") + } + + return jsonResponse.Session.SID, nil +} + +// checkPiholeSID checks if the SID is valid by checking HTTP response status code from /api/auth. +func checkPiholeSID(instanceURL string, sid string, allowInsecure bool) error { + requestURL := strings.TrimRight(instanceURL, "/") + "/api/auth" request, err := http.NewRequest("GET", requestURL, nil) if err != nil { - return nil, err + return err } + request.Header.Set("x-ftl-sid", sid) var client requestDoer if !allowInsecure { @@ -295,25 +376,163 @@ func fetchPiholeStats(instanceURL string, allowInsecure bool, token string, noGr client = defaultInsecureHTTPClient } - responseJson, err := decodeJsonFromRequest[piholeStatsResponse](client, request) + response, err := client.Do(request) if err != nil { - return nil, err + return err + } + defer response.Body.Close() + + if response.StatusCode != http.StatusOK { + return errors.New("SID is invalid, received status: " + response.Status) + } + + return nil +} + +// fetchPiholeTopDomains fetches the top blocked domains for Pi-hole v6+. +func fetchPiholeTopDomains(instanceURL string, sid string, allowInsecure bool) (piholeTopDomainsResponse, error) { + requestURL := strings.TrimRight(instanceURL, "/") + "/api/stats/top_domains?blocked=true" + + request, err := http.NewRequest("GET", requestURL, nil) + if err != nil { + return piholeTopDomainsResponse{}, err + } + request.Header.Set("x-ftl-sid", sid) + + var client requestDoer + if !allowInsecure { + client = defaultHTTPClient + } else { + client = defaultInsecureHTTPClient + } + + return decodeJsonFromRequest[piholeTopDomainsResponse](client, request) +} + +// fetchPiholeSeries fetches the series data for Pi-hole v6+ (QueriesSeries and BlockedSeries). +func fetchPiholeSeries(instanceURL string, sid string, allowInsecure bool) (piholeQueriesSeries, map[int64]int, error) { + requestURL := strings.TrimRight(instanceURL, "/") + "/api/history" + + request, err := http.NewRequest("GET", requestURL, nil) + if err != nil { + return nil, nil, err } + request.Header.Set("x-ftl-sid", sid) + + var client requestDoer + if !allowInsecure { + client = defaultHTTPClient + } else { + client = defaultInsecureHTTPClient + } + + // Define the correct struct to match the API response + var responseJson struct { + History []struct { + Timestamp int64 `json:"timestamp"` + Total int `json:"total"` + Blocked int `json:"blocked"` + } `json:"history"` + } + + err = decodeJsonInto(client, request, &responseJson) + if err != nil { + return nil, nil, err + } + + queriesSeries := make(piholeQueriesSeries) + blockedSeries := make(map[int64]int) + + // Populate the series data from history array + for _, entry := range responseJson.History { + queriesSeries[entry.Timestamp] = entry.Total + blockedSeries[entry.Timestamp] = entry.Blocked + } + + return queriesSeries, blockedSeries, nil +} + +// Helper functions to process the responses +func parsePiholeStats(r piholeStatsResponse, topDomains piholeTopDomainsResponse, noGraph bool) *dnsStats { stats := &dnsStats{ - TotalQueries: responseJson.TotalQueries, - BlockedQueries: responseJson.BlockedQueries, - BlockedPercent: int(responseJson.BlockedPercentage), - DomainsBlocked: responseJson.DomainsBlocked, + TotalQueries: r.Queries.Total, + BlockedQueries: r.Queries.Blocked, + BlockedPercent: int(r.Queries.PercentBlocked), + DomainsBlocked: r.Gravity.DomainsBlocked, + } + + if len(topDomains.Domains) > 0 { + domains := make([]dnsStatsBlockedDomain, 0, len(topDomains.Domains)) + for _, d := range topDomains.Domains { + domains = append(domains, dnsStatsBlockedDomain{ + Domain: d.Domain, + PercentBlocked: int(float64(d.Count) / float64(r.Queries.Blocked) * 100), + }) + } + + sort.Slice(domains, func(a, b int) bool { + return domains[a].PercentBlocked > domains[b].PercentBlocked + }) + stats.TopBlockedDomains = domains[:min(len(domains), 5)] + } + if noGraph { + return stats } - if len(responseJson.TopBlockedDomains) > 0 { - domains := make([]dnsStatsBlockedDomain, 0, len(responseJson.TopBlockedDomains)) + // Pihole _should_ return data for the last 24 hours + if len(r.QueriesSeries) != 145 || len(r.BlockedSeries) != 145 { + return stats + } + + + var lowestTimestamp int64 = 0 + for timestamp := range r.QueriesSeries { + if lowestTimestamp == 0 || timestamp < lowestTimestamp { + lowestTimestamp = timestamp + } + } + maxQueriesInSeries := 0 - for domain, count := range responseJson.TopBlockedDomains { + for i := 0; i < 8; i++ { + queries := 0 + blocked := 0 + for j := 0; j < 18; j++ { + index := lowestTimestamp + int64(i*10800+j*600) + queries += r.QueriesSeries[index] + blocked += r.BlockedSeries[index] + } + if queries > maxQueriesInSeries { + maxQueriesInSeries = queries + } + stats.Series[i] = dnsStatsSeries{ + Queries: queries, + Blocked: blocked, + } + if queries > 0 { + stats.Series[i].PercentBlocked = int(float64(blocked) / float64(queries) * 100) + } + } + for i := 0; i < 8; i++ { + stats.Series[i].PercentTotal = int(float64(stats.Series[i].Queries) / float64(maxQueriesInSeries) * 100) + } + return stats +} +func parsePiholeStatsLegacy(r legacyPiholeStatsResponse, noGraph bool) *dnsStats { + + stats := &dnsStats{ + TotalQueries: r.TotalQueries, + BlockedQueries: r.BlockedQueries, + BlockedPercent: int(r.BlockedPercentage), + DomainsBlocked: r.DomainsBlocked, + } + if len(r.TopBlockedDomains) > 0 { + domains := make([]dnsStatsBlockedDomain, 0, len(r.TopBlockedDomains)) + + for domain, count := range r.TopBlockedDomains { domains = append(domains, dnsStatsBlockedDomain{ Domain: domain, - PercentBlocked: int(float64(count) / float64(responseJson.BlockedQueries) * 100), + PercentBlocked: int(float64(count) / float64(r.BlockedQueries) * 100), }) } @@ -323,59 +542,138 @@ func fetchPiholeStats(instanceURL string, allowInsecure bool, token string, noGr stats.TopBlockedDomains = domains[:min(len(domains), 5)] } - if noGraph { - return stats, nil + return stats } // Pihole _should_ return data for the last 24 hours in a 10 minute interval, 6*24 = 144 - if len(responseJson.QueriesSeries) != 144 || len(responseJson.BlockedSeries) != 144 { - slog.Warn( - "DNS stats for pihole: did not get expected 144 data points", - "len(queries)", len(responseJson.QueriesSeries), - "len(blocked)", len(responseJson.BlockedSeries), - ) - return stats, nil + if len(r.QueriesSeries) != 144 || len(r.BlockedSeries) != 144 { + return stats } var lowestTimestamp int64 = 0 - - for timestamp := range responseJson.QueriesSeries { + for timestamp := range r.QueriesSeries { if lowestTimestamp == 0 || timestamp < lowestTimestamp { lowestTimestamp = timestamp } } - maxQueriesInSeries := 0 for i := 0; i < 8; i++ { queries := 0 blocked := 0 - for j := 0; j < 18; j++ { index := lowestTimestamp + int64(i*10800+j*600) - - queries += responseJson.QueriesSeries[index] - blocked += responseJson.BlockedSeries[index] + queries += r.QueriesSeries[index] + blocked += r.BlockedSeries[index] } - if queries > maxQueriesInSeries { maxQueriesInSeries = queries } - stats.Series[i] = dnsStatsSeries{ Queries: queries, Blocked: blocked, } - if queries > 0 { stats.Series[i].PercentBlocked = int(float64(blocked) / float64(queries) * 100) } } - for i := 0; i < 8; i++ { stats.Series[i].PercentTotal = int(float64(stats.Series[i].Queries) / float64(maxQueriesInSeries) * 100) } - - return stats, nil + return stats } + +func fetchPiholeStats(instanceURL string, allowInsecure bool, token string, noGraph bool, version, appPassword string) (*dnsStats, error) { + instanceURL = strings.TrimRight(instanceURL, "/") + var requestURL string + var sid string + isV6 := version == "" || version == "6" + + if isV6 { + if appPassword == "" { + return nil, errors.New("missing app password") + } + + sid = os.Getenv("SID") + if sid == "" { + newSid, err := piholeGetSID(instanceURL, appPassword, allowInsecure) + if err != nil { + return nil, fmt.Errorf("failed to get SID: %w", err) + } + sid = newSid + os.Setenv("SID", sid) + } else { + err := checkPiholeSID(instanceURL, sid, allowInsecure) + if err != nil { + newSid, err := piholeGetSID(instanceURL, appPassword, allowInsecure) + if err != nil { + return nil, fmt.Errorf("failed to get SID after invalid check: %w", err) + } + sid = newSid + os.Setenv("SID", sid) + } + } + + requestURL = instanceURL + "/api/stats/summary" + } else { + if token == "" { + return nil, errors.New("missing API token") + } + requestURL = instanceURL + "/admin/api.php?summaryRaw&topItems&overTimeData10mins&auth=" + token + } + + request, err := http.NewRequest("GET", requestURL, nil) + if err != nil { + return nil, fmt.Errorf("failed to create HTTP request: %w", err) + } + + if isV6 { + request.Header.Set("x-ftl-sid", sid) + } + + var client requestDoer + if !allowInsecure { + client = defaultHTTPClient + } else { + client = defaultInsecureHTTPClient + } + + var responseJson interface{} + if isV6 { + responseJson, err = decodeJsonFromRequest[piholeStatsResponse](client, request) + } else { + responseJson, err = decodeJsonFromRequest[legacyPiholeStatsResponse](client, request) + } + + if err != nil { + return nil, fmt.Errorf("failed to decode JSON response: %w", err) + } + + switch r := responseJson.(type) { + case piholeStatsResponse: + // Fetch top domains separately for v6+ + topDomains, err := fetchPiholeTopDomains(instanceURL, sid, allowInsecure) + if err != nil { + return nil, fmt.Errorf("failed to fetch top domains: %w", err) + } + + // Fetch series data separately for v6+ + queriesSeries, blockedSeries, err := fetchPiholeSeries(instanceURL, sid, allowInsecure) + if err != nil { + return nil, fmt.Errorf("failed to fetch queries series: %w", err) + } + + // Merge series data + r.QueriesSeries = queriesSeries + r.BlockedSeries = blockedSeries + + return parsePiholeStats(r, topDomains, noGraph), nil + + case legacyPiholeStatsResponse: + return parsePiholeStatsLegacy(r, noGraph), nil + + default: + return nil, errors.New("unexpected response type") + } +} \ No newline at end of file diff --git a/internal/glance/widget-utils.go b/internal/glance/widget-utils.go index 8fb76ddb..fa141c19 100644 --- a/internal/glance/widget-utils.go +++ b/internal/glance/widget-utils.go @@ -82,6 +82,16 @@ func decodeJsonFromRequest[T any](client requestDoer, request *http.Request) (T, return result, nil } +func decodeJsonInto[T any](client requestDoer, request *http.Request, out *T) error { + result, err := decodeJsonFromRequest[T](client, request) + if err != nil { + return err + } + + *out = result + return nil +} + func decodeJsonFromRequestTask[T any](client requestDoer) func(*http.Request) (T, error) { return func(request *http.Request) (T, error) { return decodeJsonFromRequest[T](client, request)