« Back to Index

[Python TTL Cache Decorator]

View original Gist on GitHub

Tags: #python #decorator #async #ttl #cache

Python TTL Cache Decorator.py

from datetime import datetime, timedelta
from functools import wraps
from tornado.httpclient import AsyncHTTPClient

# Configure Tornado's AsyncHTTPClient class before any instance is created.
# Passing None as the implementation leaves Tornado's default client class in
# place; `defaults` supplies request attributes (here the User-Agent) applied
# to every request made through the client.
AsyncHTTPClient.configure(None, defaults=dict(user_agent="YourService"))
# Module-level client instance; `get_keys` below fetches through it.
http_client = AsyncHTTPClient()

def network_cache(fn):
    """Cache asynchronous network requests based on cache-control.

    Expects the decorated coroutine function to take an `endpoint` parameter
    (used as the cache key) and to return a dict containing the keys
    `cache_control` (time-to-live in seconds, as a number or numeric string)
    and `response_body` (the value handed back to callers).

    A result is served from the cache until `cache_control` seconds have
    elapsed since it was stored. Results whose `cache_control` is missing,
    non-numeric or not positive are returned but never cached, so the next
    call fetches again. An entry is only dropped when a later call for the
    same endpoint finds it stale.
    """
    import inspect  # local import: only this decorator needs it

    _cache = {}

    try:
        signature = inspect.signature(fn)
    except (TypeError, ValueError):
        # some callables expose no signature; fall back to keyword lookup
        signature = None

    def _cache_key(args, kwargs):
        """Return the `endpoint` value however the caller supplied it."""
        fallback = kwargs.get('endpoint')
        if signature is None:
            return fallback
        try:
            bound = signature.bind(*args, **kwargs)
        except TypeError:
            # invalid call; `fn` raises its own error if it gets invoked
            return fallback
        bound.apply_defaults()
        return bound.arguments.get('endpoint', fallback)

    def _ttl_seconds(value):
        """Coerce a cache-control value to seconds; 0.0 when unusable."""
        try:
            return float(value)
        except (TypeError, ValueError):
            return 0.0

    def _is_fresh(entry):
        """Tell whether a cached entry is still within its TTL."""
        age = datetime.now().timestamp() - entry['timestamp']
        return entry['ttl'] > 0 and age < entry['ttl']

    @wraps(fn)
    async def wrapper(*args, **kwargs):
        key = _cache_key(args, kwargs)

        entry = _cache.get(key)
        if entry is not None and _is_fresh(entry):
            return entry['value']

        # drop the stale entry before awaiting, so a failed refresh cannot
        # leave expired data behind
        _cache.pop(key, None)

        result = await fn(*args, **kwargs)
        body = result.get('response_body')
        ttl = _ttl_seconds(result.get('cache_control', 0))

        # compare in seconds (not whole days) and skip uncacheable results
        if ttl > 0:
            _cache[key] = {'ttl': ttl,
                           'value': body,
                           'timestamp': datetime.now().timestamp()}

        return body
    return wrapper
  
@network_cache
async def get_keys(endpoint=user_pool_jwk):
    """Retrieve JWK (for verifying tokens).

    On success, returns a dict with `state` set to 'success',
    `cache_control` set to the `max-age` value (integer seconds) taken from
    the response's Cache-Control header, and `response_body` set to the list
    under the payload's `keys` entry (empty when absent).

    If the body cannot be parsed as JSON, the failure is reported through
    `instr_exc` and the standard dictionary error format is returned.

    Raises `exceptions.AsyncFetchException` when the Cache-Control header is
    missing or carries no `max-age`, or when the status code is not 200.
    """

    response = await http_client.fetch(endpoint)
    # a missing header yields None, which `re.search` would reject
    cache_control = response.headers.get('Cache-Control') or ''

    match = re.search(r'max-age=(\d+)', cache_control)
    if not match:
        msg = 'JWK_RESPONSE_INVALID'
        gen_exc = exceptions.AsyncFetchException(msg, code=response.code)
        instr_exc(gen_exc, msg, cache_control=cache_control)
        raise gen_exc

    if response.code != 200:
        # the response exposes its status as an attribute, as used above
        raise exceptions.AsyncFetchException('JWK_RESPONSE_INVALID', code=response.code)

    try:
        response_data = json.loads(response.body)
    except (TypeError, ValueError) as exc:
        # ValueError covers JSONDecodeError and undecodable bytes
        msg = 'JSON_PARSING_FAILED'
        instr_exc(exc, msg, endpoint=endpoint, body=response.body)
        return {'state': 'error',
                'code': 500,
                'message': msg}

    # hand back the captured number of seconds, not the regex match object
    return {'state': 'success',
            'cache_control': int(match.group(1)),
            'response_body': response_data.get('keys', [])}

test.py

# standard library modules

import unittest.mock as mock
import sys

from collections import namedtuple
from datetime import datetime, timedelta

# external modules

import tornado.testing

# configuration

sys.path.insert(0, '/app')

# application modules

import app.aws
import app.network


# helpers

def make_coroutine(response):
    """Build an async stub that ignores its arguments and yields `response`.

    Useful for replacing an awaitable method (such as a client's `fetch`)
    with one that resolves to a canned value.
    """
    async def _canned(*_args, **_kwargs):
        canned_value = response
        return canned_value

    stub = _canned
    return stub


# asynchronous tests

class TestPassword(tornado.testing.AsyncTestCase):
    # NOTE(review): despite the class name, the only test here exercises the
    # network cache decorator through `app.aws.get_keys`.

    # Patch decorators are applied bottom-up, so the mocks are passed to the
    # test in the order: http_client, instr, datetime.
    # `wraps=datetime` lets un-configured calls on the datetime mock fall
    # through to the real class.
    # NOTE(review): assumes `app.network` exposes `datetime`, `instr` and
    # `http_client` at module level — confirm against that module.
    @mock.patch('app.network.datetime', wraps=datetime)
    @mock.patch('app.network.instr')
    @mock.patch('app.network.http_client')
    @tornado.testing.gen_test
    def test_network_cache_decorator(self, mock_http_client, mock_instr, mock_datetime):
        """Verify decorated function caches its network request."""

        # canned HTTP response: a JWK payload cacheable for one day (86400s)
        fetch_body = '{"keys":["foo"]}'
        fetch_code = 200
        fetch_headers = {'Cache-Control': 'public, max-age=86400'}
        fetch_response = namedtuple('_', ['body', 'code', 'headers'])(fetch_body, fetch_code, fetch_headers)

        # replace the client's fetch with a coroutine resolving to the canned response
        mock_http_client.fetch = make_coroutine(fetch_response)

        # first request: nothing cached yet, so a cache miss is expected
        endpoint = 'https://example.com/foo'
        response = yield app.aws.get_keys(endpoint=endpoint)

        # NOTE(review): these equality checks expect the full result dict; if
        # the decorator under test returns only `response_body`, they cannot
        # hold — confirm against the decorator version being tested.
        assert response == {'state': 'success', 'cache_control': '86400', 'response_body': ['foo']}
        mock_instr.assert_called_with('JWK_CACHE_MISS', metric_name='jwk.cache', state='miss', key=endpoint)

        # second request for the same endpoint: expected to be a cache hit
        response = yield app.aws.get_keys(endpoint=endpoint)
        assert response == {'state': 'success', 'cache_control': '86400', 'response_body': ['foo']}
        mock_instr.assert_called_with('JWK_CACHE_HIT', metric_name='jwk.cache', state='hit', key=endpoint)

        # mock datetime.now to return a fixed date 60 days ahead of the real current time
        mock_datetime.now.return_value = datetime.now() + timedelta(days=60)

        # this request should result with the cache being invalidated (i.e. hit/expiry/population)
        response = yield app.aws.get_keys(endpoint=endpoint)
        assert response == {'state': 'success', 'cache_control': '86400', 'response_body': ['foo']}

        # reset the mock datetime.now behaviour to its default behaviour
        # (a side_effect takes precedence over the return_value set above)
        mock_datetime.now.side_effect = datetime.now

        # this request should once again get a cache hit
        response = yield app.aws.get_keys(endpoint=endpoint)
        assert response == {'state': 'success', 'cache_control': '86400', 'response_body': ['foo']}

        # full instrumentation history across the four requests, in order:
        # miss, hit, hit + expiry + population (third request), hit
        fields = {'key': 'https://example.com/foo', 'metric_name': 'jwk.cache'}
        assert mock_instr.call_args_list == [mock.call('JWK_CACHE_MISS', state='miss', **fields),
                                             mock.call('JWK_CACHE_HIT', state='hit', **fields),
                                             mock.call('JWK_CACHE_HIT', state='hit', **fields),
                                             mock.call('JWK_CACHE_EXPIRY', state='expired', **fields),
                                             mock.call('JWK_CACHE_POPULATION', state='populated', **fields),
                                             mock.call('JWK_CACHE_HIT', state='hit', **fields)]