« Back to Index

Go: Custom DNS resolution in Golang

View original Gist on GitHub

Tags: #go #dns

1. README.md

$ go run main.go

Querying root server 198.41.0.4 for TLD com...
Selected TLD server: g.gtld-servers.net.
Resolved TLD server IP: 192.42.93.30
Querying TLD server 192.42.93.30 for domain coca-cola.com...
Selected authoritative server: ns4-09.azure-dns.info.
Resolved authoritative server IP: 208.84.5.9
Querying authoritative server 208.84.5.9 for domain coca-cola.com...
resolved A record: coca-cola.com.	3000	IN	A	52.14.144.171

2. main.go

package main

import (
	"crypto/rand"
	"errors"
	"fmt"
	"log"
	"math/big"
	"net"
	"strings"

	"github.com/miekg/dns"
)

// main traces the DNS resolution of a fixed example domain and terminates
// the program with a logged error if any step of the trace fails.
func main() {
	const domain = "coca-cola.com"

	err := traceDNS(domain)
	if err != nil {
		log.Fatal(err)
	}
}

func traceDNS(domain string) error {
	ips, err := net.LookupIP("a.root-servers.net")
	if err != nil {
		return fmt.Errorf("failed to lookup root DNS server: %w", err)
	}

	// Step 1: Query a root DNS server for the TLD (e.g., com)
	// We use `a.root-servers.net` IP address (`dig +short a.root-servers.net`).
	rootServer := ips[0].String()
	tld := domain[strings.LastIndex(domain, ".")+1:]

	fmt.Printf("Querying root server %s for TLD %s...\n", rootServer, tld)
	resp, err := queryDNS(rootServer, tld, dns.TypeNS)
	if err != nil {
		return fmt.Errorf("failed to query root server: %w", err)
	}

	// // Print root server response
	// for _, ns := range resp.Ns {
	// 	fmt.Printf("Root NS: %v\n", ns)
	// }

	// Step 2: Select a random TLD server from the root server response
	tldServers := []string{}
	for _, ns := range resp.Ns {
		if tldNs, ok := ns.(*dns.NS); ok {
			tldServers = append(tldServers, tldNs.Ns)
		}
	}

	if len(tldServers) == 0 {
		return errors.New("no TLD servers found in root server response")
	}

	// Randomly pick a TLD server
	nBig, err := rand.Int(rand.Reader, big.NewInt(int64(len(tldServers))))
	if err != nil {
		return fmt.Errorf("failed to create random int: %w", err)
	}
	selectedTLDServer := tldServers[nBig.Int64()]
	fmt.Printf("Selected TLD server: %s\n", selectedTLDServer)

	// Step 3: Get the IP address of the selected TLD server (need to query its A record)
	tldServerIP, err := resolveNameToIP(selectedTLDServer)
	if err != nil {
		return fmt.Errorf("failed to resolve TLD server %s: %w", selectedTLDServer, err)
	}

	fmt.Printf("Resolved TLD server IP: %s\n", tldServerIP)

	// Step 4: Query the TLD server for the domain's NS records
	fmt.Printf("Querying TLD server %s for domain %s...\n", tldServerIP, domain)
	resp, err = queryDNS(tldServerIP, domain, dns.TypeNS)
	if err != nil {
		return fmt.Errorf("failed to query TLD server: %w", err)
	}

	// Step 5: Print the authoritative name servers for the domain
	// fmt.Printf("Authoritative name servers for %s:\n", domain)
	// for _, ns := range resp.Ns {
	// 	if authNs, ok := ns.(*dns.NS); ok {
	// 		fmt.Printf("NS: %v\n", authNs.Ns)
	// 	}
	// }

	// Randomly pick an authoritative server
	nBig, err = rand.Int(rand.Reader, big.NewInt(int64(len(resp.Ns))))
	if err != nil {
		return fmt.Errorf("failed to create random int: %w", err)
	}
	selectedAuthoritativeServer := resp.Ns[nBig.Int64()]

	// Step 6: Get the IP address of the selected authoritative server
	if authNs, ok := selectedAuthoritativeServer.(*dns.NS); ok {
		fmt.Printf("Selected authoritative server: %s\n", authNs.Ns)
		authoritativeServerIP, err := resolveNameToIP(authNs.Ns)
		if err != nil {
			return fmt.Errorf("failed to resolve authoritative server %s: %w", selectedAuthoritativeServer, err)
		}
		fmt.Printf("Resolved authoritative server IP: %s\n", authoritativeServerIP)

		// Step 7: Query the authoritative server for the domain's A records
		fmt.Printf("Querying authoritative server %s for domain %s...\n", authoritativeServerIP, domain)
		resp, err := queryDNS(authoritativeServerIP, domain, dns.TypeA)
		if err != nil {
			return fmt.Errorf("failed to query TLD server: %w", err)
		}
		for _, rr := range resp.Answer {
			fmt.Printf("resolved A record: %s\n", rr.String())
		}
	}

	return nil
}

// Resolve a domain name (e.g., TLD server) to its IP address by querying for its A record
func resolveNameToIP(name string) (string, error) {
	// Use a public DNS server to resolve the name
	publicDNSServer := "8.8.8.8" // Google DNS

	resp, err := queryDNS(publicDNSServer, name, dns.TypeA)
	if err != nil {
		return "", err
	}

	for _, ans := range resp.Answer {
		if aRecord, ok := ans.(*dns.A); ok {
			return aRecord.A.String(), nil
		}
	}

	return "", fmt.Errorf("no A record found for %s", name)
}

// queryDNS sends a single non-recursive query for domain with record type
// qtype to server on port 53 and returns the server's response.
//
// server is a bare IP address or host name without a port; IPv6 literals are
// accepted. The query goes out over UDP first. If the reply is flagged as
// truncated, the same query is retried over TCP and the TCP reply is returned
// when that retry succeeds; otherwise the truncated UDP reply is returned.
//
// The response code is not inspected: a reply such as NXDOMAIN is returned
// with a nil error and callers must check the message themselves.
func queryDNS(server, domain string, qtype uint16) (*dns.Msg, error) {
	// Build the question. Recursion is switched off so the queried server
	// answers only from its own data or with a referral.
	msg := new(dns.Msg)
	msg.SetQuestion(dns.Fqdn(domain), qtype)
	msg.RecursionDesired = false

	// JoinHostPort brackets IPv6 literals; plain server+":53" produces an
	// unparsable address for them.
	addr := net.JoinHostPort(server, "53")

	// Send the query to the specified server.
	client := new(dns.Client)
	response, _, err := client.Exchange(msg, addr)
	if err != nil {
		return nil, fmt.Errorf("failed to query dns server: %w", err)
	}

	// Exchange does not fall back to TCP by itself. The retry is best effort:
	// a failed TCP attempt is deliberately ignored so the caller still gets
	// the (truncated) UDP reply, as it would have without the retry.
	if response.Truncated {
		tcpClient := &dns.Client{Net: "tcp"}
		if full, _, tcpErr := tcpClient.Exchange(msg, addr); tcpErr == nil {
			response = full
		}
	}

	return response, nil
}