Building the DNS cache

Ronit Panda

Before jumping into the code for the cache, let's first break down exactly what we need from the cache class, and how we are going to cater to each of those use cases.

Requirements

  1. Cache the answers for every question name and record type resolved through recursion, storing each answer individually.

  2. Set each record's TTL on our cache store (Redis in our case) according to the record itself.

  3. When reading records back, make sure none of the answer records has expired.

  4. Return the answer records in the same order as they were returned by the recursive resolver.
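To pin down what requirements 1, 2 and 4 mean for the data layout, here is a small standalone sketch of the writes the cache performs for one resolved question. The `Question`/`Answer` shapes, the `data` field (standing in for the decoded RDATA) and `planCacheWrites` are hypothetical simplifications, not the project's real types or functions.

```typescript
// Hypothetical, simplified shapes; the real project uses DNSQuestion/DNSAnswer
// with a Buffer RDATA that is turned into a string via decodeRDATA().
interface Question {
  NAME: string;
  TYPE: number;
}

interface Answer {
  NAME: string;
  TYPE: number;
  TTL: number; // seconds
  data: string; // decoded RDATA, e.g. "93.184.216.34"
}

interface CacheWrite {
  key: string;
  value: Answer;
  ttlSeconds: number;
}

interface CachePlan {
  orderKey: string; // key holding the RDATA strings in response order (requirement 4)
  order: string[];
  writes: CacheWrite[]; // one expiring key per answer (requirements 1 and 2)
}

// Hypothetical helper: computes the keys and TTLs, without talking to Redis.
function planCacheWrites(question: Question, answers: Answer[]): CachePlan {
  const orderKey = `${question.NAME}/${question.TYPE}`;
  return {
    orderKey,
    order: answers.map((a) => a.data),
    writes: answers.map((a) => ({
      key: `${orderKey}:${a.data}`,
      value: a,
      ttlSeconds: a.TTL,
    })),
  };
}

// An A query for www.example.com that resolves through a CNAME.
const plan = planCacheWrites({ NAME: 'www.example.com', TYPE: 1 }, [
  { NAME: 'www.example.com', TYPE: 5, TTL: 3600, data: 'example.com' },
  { NAME: 'example.com', TYPE: 1, TTL: 60, data: '93.184.216.34' },
]);
console.log(plan.orderKey); // www.example.com/1
console.log(plan.writes.map((w) => `${w.key} (ex ${w.ttlSeconds})`).join('\n'));
// www.example.com/1:example.com (ex 3600)
// www.example.com/1:93.184.216.34 (ex 60)
```

Each record keeps its own TTL, so the CNAME can outlive the A record it points to; that is exactly why the read path has to check every record before trusting the cached set.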

Now let's code it out, keeping the above requirements in mind.

Hop on to the src/dns-cache.ts file:

import { Redis } from '@upstash/redis';
import { DNSAnswer, DNSQuestion } from './message/types';
import { decodeRDATA } from './utils';

export class DNSCache {
    constructor(private redis: Redis) {}

    async set(question: DNSQuestion, answers: DNSAnswer[]) {
        try {
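            // A TTL of 0 means "do not cache" (RFC 1035), and Redis rejects EX 0,
            // so skip empty answer sets and any set containing such a record
            if (answers.length === 0 || answers.some((a) => a.TTL <= 0)) {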
                return;
            }

            const baseKey = `${question.NAME}/${question.TYPE}`;
            await this.redis.del(baseKey);
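            // The whole array is pushed as ONE list element ([[RDATA1, RDATA2, ...]]);
            // it records which answers belong to this question and in what order
            await this.redis.rpush(
                baseKey,
                answers.map((a) => decodeRDATA(a.RDATA)),
            );
            // Each answer gets its own key, which expires together with the record's TTL
            const promises: Promise<DNSAnswer | 'OK' | null>[] = [];
            for (const answer of answers) {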
                const key = `${baseKey}:${decodeRDATA(answer.RDATA)}`;
                promises.push(
                    this.redis.set(key, answer, {
                        ex: answer.TTL,
                    }),
                );
            }

            await Promise.all(promises);
        } catch (error) {
            console.error('Error setting cache', error);
            throw error;
        }
    }

    async get(question: DNSQuestion): Promise<DNSAnswer[]> {
        const baseKey = `${question.NAME}/${question.TYPE}`;

        const [_cache] = await this.redis.lrange(baseKey, 0, -1); // get all elements in the list [[RDATA1, RDATA2, ...]]
        if (!_cache) {
            console.log('cache is empty');
            return [];
        }
        const cache = _cache as unknown as string[]; // Upstash already parsed the JSON array; the cast only fixes the type

        if (cache.length === 0) {
            console.log('cache is empty');
            return [];
        }

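        // Answer keys that have not expired yet. If their count differs from the cached RDATA list,
        // at least one record expired (or an old key is left over), so treat this as a cache miss
        const keys = await this.redis.keys(`${baseKey}:*`);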
        if (keys.length !== cache.length) {
            console.log('keys length does not match cache length', { keys, cache });
            return [];
        }

        const answers: DNSAnswer[] = [];
        for (const key of keys) {
            const answer = await this.redis.get<DNSAnswer>(key);
            if (answer) {
                answers.push({
                    ...answer,
                    RDATA: Buffer.from(answer.RDATA), // RDATA comes back in its JSON form ({ type: 'Buffer', data: [...] }); rebuild the Buffer
                });
            }
        }

        // `cache` holds the RDATA values in the original response order. Order matters in a DNS
        // response (e.g. the CNAME first, then the A record it points to), so sort the answers to match it
        return sortDNSAnswers(answers, cache);
    }

    async deleteAll() {
        await this.redis.flushdb();
    }
}

function sortDNSAnswers(answers: DNSAnswer[], cache: string[]): DNSAnswer[] {
    return answers.sort((a, b) => {
        const aIndex = cache.indexOf(decodeRDATA(a.RDATA));
        const bIndex = cache.indexOf(decodeRDATA(b.RDATA));

        if (aIndex === -1 || bIndex === -1) {
            console.warn('RDATA not found in cache:', {
                a: aIndex,
                b: bIndex,
            });
        }

        return aIndex - bIndex;
    });
}
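`sortDNSAnswers` calls `indexOf` inside the comparator, so every comparison rescans the cached list. That is harmless for a handful of records, but a variant that builds a position map once keeps the same ordering with less work. The sketch below is standalone and uses a hypothetical minimal answer shape whose `data` field stands in for `decodeRDATA(a.RDATA)`. Unlike the original (where an unknown RDATA gets index -1 and sorts to the front), it places unknown records at the end and returns a new array instead of sorting in place.

```typescript
// Hypothetical minimal answer shape; `data` stands in for decodeRDATA(a.RDATA).
interface CachedAnswer {
  TYPE: number;
  data: string;
}

// Returns a new array ordered like `order` (the RDATA list saved by set()).
// Answers whose data is not in `order` are placed at the end.
function sortByCachedOrder<T extends { data: string }>(answers: T[], order: string[]): T[] {
  const position = new Map<string, number>();
  order.forEach((data, i) => {
    if (!position.has(data)) position.set(data, i); // first occurrence wins, like indexOf
  });
  const rank = (a: T) => position.get(a.data) ?? Number.MAX_SAFE_INTEGER;
  return [...answers].sort((a, b) => rank(a) - rank(b));
}

// KEYS returns keys in no particular order, so the A record may come back first.
const fromRedis: CachedAnswer[] = [
  { TYPE: 1, data: '93.184.216.34' },
  { TYPE: 5, data: 'example.com' },
];
const sorted = sortByCachedOrder(fromRedis, ['example.com', '93.184.216.34']);
console.log(sorted.map((a) => a.TYPE)); // [ 5, 1 ]
```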

Now let's walk through the class step by step.

  1. Importing Necessary Modules:

    • Import the Upstash Redis client, our own DNS types, and the decodeRDATA helper needed for the cache implementation.

        import { Redis } from '@upstash/redis';
        import { DNSAnswer, DNSQuestion } from './message/types';
        import { decodeRDATA } from './utils';
      
  2. Defining the DNSCache Class:

    • Define the DNSCache class that takes a Redis instance as a parameter.

        export class DNSCache {
            constructor(private redis: Redis) {}
      
  3. Implementing the Set Method:

    • Define the set method to store DNS answers in the cache.

    • Use rpush to store the decoded RDATA values of all answers under the base key questionName/recordType. Because the whole array is passed as one argument, it is stored as a single list element, which get reads back in one go.

      • This list is what lets us check, at retrieval time, that every record resolved by recursion still has a valid TTL (this will become clearer when we implement the get method).

      • It also preserves the order of the answer records from the resolved response. For example, if we query for an A record that goes through a CNAME, we want the CNAME to appear first, followed by the A record it points to.

    • Store each answer under its own key, in the format questionName/recordType:RDATA. The value is the whole record, and the key expires after the record's TTL.

        async set(question: DNSQuestion, answers: DNSAnswer[]) {
            try {
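                // Skip empty answer sets and records with TTL 0 (RFC 1035: do not cache; Redis rejects EX 0)
                if (answers.length === 0 || answers.some((a) => a.TTL <= 0)) {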
                    return;
                }
      
                const baseKey = `${question.NAME}/${question.TYPE}`;
                await this.redis.del(baseKey);
                await this.redis.rpush(
                    baseKey,
                    answers.map((a) => decodeRDATA(a.RDATA)),
                );
                const promises: Promise<DNSAnswer | 'OK' | null>[] = [];
                for (const answer of answers) {
                    const key = `${baseKey}:${decodeRDATA(answer.RDATA)}`;
                    promises.push(
                        this.redis.set(key, answer, {
                            ex: answer.TTL,
                        }),
                    );
                }
      
                await Promise.all(promises);
            } catch (error) {
                console.error('Error setting cache', error);
                throw error;
            }
        }
      
  4. Implementing the Get Method:

    • Define the get method to retrieve DNS answers from the cache.

    • First, construct the base key in the questionName/recordType format, then use lrange to fetch the stored RDATA array. This array lets us check for expired records and restore the correct order of answer records in the response.

    • Use keys with the pattern baseKey:* to find the individual answer keys that still exist. If their count differs from the length of the cached array, at least one record has expired, so return an empty array. The caller treats this as a cache miss and falls back to recursive resolution.

    • Otherwise, fetch each answer by its key and convert its RDATA back into a Buffer.

    • Finally, sort the answers in the same order as the cached array retrieved from the base key and return them as the response.

        async get(question: DNSQuestion): Promise<DNSAnswer[]> {
            const baseKey = `${question.NAME}/${question.TYPE}`;
      
            const [_cache] = await this.redis.lrange(baseKey, 0, -1);
            if (!_cache) {
                console.log('cache is empty');
                return [];
            }
            const cache = _cache as unknown as string[];
      
            if (cache.length === 0) {
                console.log('cache is empty');
                return [];
            }
      
            const keys = await this.redis.keys(`${baseKey}:*`);
            if (keys.length !== cache.length) {
                console.log('keys length does not match cache length', { keys, cache });
                return [];
            }
      
            const answers: DNSAnswer[] = [];
            for (const key of keys) {
                const answer = await this.redis.get<DNSAnswer>(key);
                if (answer) {
                    answers.push({
                        ...answer,
                        RDATA: Buffer.from(answer.RDATA),
                    });
                }
            }
      
            return sortDNSAnswers(answers, cache);
        }
      
  5. Implementing the DeleteAll Method:

    • Define the deleteAll method to clear the entire Redis cache. This is a utility meant for development only: flushdb wipes every key in the database, not just the DNS entries.

        async deleteAll() {
            await this.redis.flushdb();
        }
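One caveat with get as written: KEYS baseKey:* scans the whole keyspace, and it can also match answer keys left over from an earlier set for the same question whose TTL has not run out yet, in which case the length check reports a miss even though every current record is valid. A variant that builds the answer keys from the stored RDATA list avoids both issues and keeps the response order without a sort. The sketch below runs against a hypothetical in-memory `TtlStore` (not Upstash) so it can execute on its own; its `mget` mirrors the Redis MGET command, returning null for missing or expired keys. `TtlStore`, `StoredAnswer` and `lookup` are illustrative names only.

```typescript
// Hypothetical in-memory stand-in for the Redis calls used by DNSCache.
// Expiry is checked lazily on read, using an injectable clock (milliseconds).
class TtlStore {
  private values = new Map<string, { value: unknown; expiresAt: number }>();
  constructor(private now: () => number = Date.now) {}

  set(key: string, value: unknown, exSeconds: number): void {
    this.values.set(key, { value, expiresAt: this.now() + exSeconds * 1000 });
  }

  // Mirrors MGET: one result per key, null for missing or expired keys.
  mget<T>(keys: string[]): (T | null)[] {
    return keys.map((key) => {
      const entry = this.values.get(key);
      if (!entry || entry.expiresAt <= this.now()) return null;
      return entry.value as T;
    });
  }
}

interface StoredAnswer {
  TYPE: number;
  TTL: number;
  data: string; // stands in for decodeRDATA(RDATA)
}

// `order` is the RDATA list saved under `${NAME}/${TYPE}` by set().
function lookup(store: TtlStore, baseKey: string, order: string[]): StoredAnswer[] {
  if (order.length === 0) return [];
  const keys = order.map((data) => `${baseKey}:${data}`); // already in response order
  const answers = store.mget<StoredAnswer>(keys);
  // If any record has expired, treat the whole entry as a miss (requirement 3).
  if (answers.some((a) => a === null)) return [];
  return answers as StoredAnswer[];
}

let clock = 0;
const store = new TtlStore(() => clock);
const baseKey = 'www.example.com/1';
const order = ['example.com', '93.184.216.34'];
store.set(`${baseKey}:example.com`, { TYPE: 5, TTL: 3600, data: 'example.com' }, 3600);
store.set(`${baseKey}:93.184.216.34`, { TYPE: 1, TTL: 60, data: '93.184.216.34' }, 60);

console.log(lookup(store, baseKey, order).length); // 2
clock += 61_000; // the A record's 60 s TTL has passed
console.log(lookup(store, baseKey, order).length); // 0
```

Keys that are not in the current list are simply never read, so a leftover key from an older answer set cannot cause a false miss.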
      

Now that the cache is implemented, let's recap how the primary UDP server uses it:

  • Fetch the cached response.

  • If the cache returns a valid list of answers, use it as the result.

  • Otherwise, resolve the DNS query recursively.

  • Then store the answers in the cache, so the next identical query skips recursive resolution.

let responseObject: DNSObject;
// try to fetch from cache first
const cachedAnswers = await dnsCache.get(question);
if (cachedAnswers.length > 0) {
    responseObject = {
        header: {
            ...reqHeaderPacket,
            QR: QRIndicator.RESPONSE,
            RA: Bool.TRUE,
            ANCOUNT: cachedAnswers.length,
        },
        questions: [question],
        answers: cachedAnswers,
        additional: [],
        authority: [],
    };
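} else {
    // cache miss (nothing cached or a record expired): resolve recursively, then cache the answers
    responseObject = await recursiveLookup(question, reqHeaderPacket);
    if (responseObject.answers) {
        await dnsCache.set(question, responseObject.answers);
    }
}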
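The fetch-or-resolve flow above is the classic cache-aside pattern, and it can be factored into a small helper that is easy to test without a network or a Redis instance. The `AnswerCache` and `Resolver` interfaces and the `resolveWithCache` name are hypothetical; DNSCache's get/set have the shape `AnswerCache` expects, while recursiveLookup would need a thin wrapper. Answers are plain strings here purely for brevity.

```typescript
// Hypothetical interfaces for the two collaborators of the UDP handler.
interface AnswerCache<Q, A> {
  get(question: Q): Promise<A[]>;
  set(question: Q, answers: A[]): Promise<void>;
}
type Resolver<Q, A> = (question: Q) => Promise<A[]>;

// Cache-aside: serve from the cache when it has a complete answer set,
// otherwise resolve and store the result for the next query.
async function resolveWithCache<Q, A>(
  question: Q,
  cache: AnswerCache<Q, A>,
  resolve: Resolver<Q, A>,
): Promise<{ answers: A[]; fromCache: boolean }> {
  const cached = await cache.get(question);
  if (cached.length > 0) return { answers: cached, fromCache: true };

  const answers = await resolve(question);
  if (answers.length > 0) await cache.set(question, answers);
  return { answers, fromCache: false };
}

// Usage with in-memory fakes (no TTLs, for illustration only).
const memory = new Map<string, string[]>();
const fakeCache: AnswerCache<string, string> = {
  get: async (q) => memory.get(q) ?? [],
  set: async (q, a) => {
    memory.set(q, a);
  },
};
let upstreamCalls = 0;
const fakeResolve: Resolver<string, string> = async () => {
  upstreamCalls++;
  return ['93.184.216.34'];
};

(async () => {
  console.log((await resolveWithCache('example.com/1', fakeCache, fakeResolve)).fromCache); // false
  console.log((await resolveWithCache('example.com/1', fakeCache, fakeResolve)).fromCache); // true
  console.log(upstreamCalls); // 1
})();
```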

We are almost at the end. Currently it's awkward to query our UDP server for a DNS response; we have to use tools like nslookup. So in the next blog, we will set up a simple HTTP server through which we can easily look up any DNS record.

Written by

Ronit Panda

Founding full stack engineer at dimension.dev