Building NeuroStash - II

Farhan Khoja
5 min read

When I started building NeuroStash, my AI-powered knowledge management platform, I knew authentication would be critical. Not just "password and done" authentication, but layered security that could handle both JWT tokens and API keys.

Here’s how I architected a robust authentication system using Python, FastAPI, AWS KMS, and PostgreSQL.

The Authentication Challenge

Most developers face a common dilemma with authentication:

  • Too simple: Basic username/password systems are vulnerable and don't scale

  • Too complex: Over-engineered solutions become maintenance nightmares

I needed something that was both secure and maintainable, supporting:

  • JWT tokens for session-based authentication

  • API keys for programmatic access

  • Symmetric keys, stored in the database, that sign the JWTs and API keys

  • Encryption of those signing keys with AWS KMS before they are stored, as a second layer of security: if the database is breached, the stored keys cannot be used without access to the KMS key

  • Role-based access control (RBAC)

Architecture Overview

AWS integration with smart resource management

The foundation starts with AWS KMS integration. I used Python's @property decorator to create AWS clients lazily:

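class AwsClientManager:
    def __init__(self, session, kms_key_id: str):
        self.session = session  # a boto3.session.Session
        self.kms_key_id = kms_key_id
        self._kms_client = None  # created on first access to .kms

    @property
    def kms(self):
        if self._kms_client is None:
            self._kms_client = self.session.client("kms")
        return self._kms_client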

Why this pattern works:

  • KMS clients are only created when needed

  • Reduces startup time and memory usage

  • Makes testing easier (you can mock the property)

  • Handles AWS credential rotation gracefully
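Because the client is only built inside the property, a test can pass in a fake session and check when, and how often, a client is requested. The sketch below copies the class in compact form so it runs on its own; `FakeSession` is a hypothetical test double, not part of boto3.

```python
class AwsClientManager:
    """Compact copy of the manager above: the KMS client is created on first use."""

    def __init__(self, session):
        self.session = session
        self._kms_client = None

    @property
    def kms(self):
        if self._kms_client is None:
            self._kms_client = self.session.client("kms")
        return self._kms_client


class FakeSession:
    """Hypothetical stand-in for boto3.session.Session that records client() calls."""

    def __init__(self):
        self.calls = []

    def client(self, service_name):
        self.calls.append(service_name)
        return object()  # a placeholder is enough; no AWS call is made


session = FakeSession()
manager = AwsClientManager(session)
print(session.calls)    # []: nothing is created at construction time
first = manager.kms
second = manager.kms
print(first is second)  # True: the client is cached after the first access
print(session.calls)    # ['kms']: client() ran exactly once
```

In production the same class would receive a real `boto3.session.Session`, whose `client("kms")` call returns the KMS client.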

The encrypt/decrypt methods wrap the KMS calls: they return None when KMS is not configured, and they log and re-raise any botocore ClientError:

import logging
from typing import Optional

from botocore.exceptions import ClientError

logger = logging.getLogger(__name__)


class AwsClientManager:
    # ... __init__ and the lazy kms property shown above ...

    def encrypt_key(self, key_blob: bytes) -> Optional[bytes]:
        if not self.kms or not self.kms_key_id:
            logger.error("kms client or kms key id not configured for encryption")
            return None
        try:
            # Only the returned ciphertext is ever written to the database
            response = self.kms.encrypt(KeyId=self.kms_key_id, Plaintext=key_blob)
            return response.get("CiphertextBlob")
        except ClientError as e:
            logger.error(f"error encrypting key with kms key id: {e}")
            raise

    def decrypt_key(self, ciphertext_blob: bytes) -> Optional[bytes]:
        if not self.kms or not self.kms_key_id:
            logger.error("kms client or kms key id not configured for decryption")
            return None
        try:
            response = self.kms.decrypt(
                CiphertextBlob=ciphertext_blob, KeyId=self.kms_key_id
            )
            return response.get("Plaintext")
        except ClientError as e:
            logger.error(f"error decrypting key: {e}")
            raise
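The security argument is that the database only ever holds ciphertext, and turning it back into a signing key requires a call to KMS with valid AWS credentials. The sketch below walks through that flow with a hypothetical in-memory `FakeKms` that mimics the keyword arguments and response keys of the boto3 KMS `encrypt`/`decrypt` calls used above. It performs no real encryption: it keeps the plaintext on its own side and returns an opaque handle, the way the real service keeps its key material to itself. The key ID `alias/neurostash-signing` is made up.

```python
import secrets


class FakeKms:
    """Hypothetical KMS stand-in. It does NOT encrypt anything; ciphertexts are
    random handles that only this object can resolve, mimicking how only the
    real service can turn a CiphertextBlob back into plaintext."""

    def __init__(self):
        self._vault = {}

    def encrypt(self, KeyId, Plaintext):  # argument names mirror boto3's KMS client
        handle = secrets.token_bytes(32)
        self._vault[(KeyId, handle)] = Plaintext
        return {"CiphertextBlob": handle}

    def decrypt(self, CiphertextBlob, KeyId):
        return {"Plaintext": self._vault[(KeyId, CiphertextBlob)]}


kms = FakeKms()
signing_key = secrets.token_bytes(32)  # plaintext key, kept only in memory

# Write path: this ciphertext is what would go into encryption_keys.symmetric_key
stored_blob = kms.encrypt(KeyId="alias/neurostash-signing", Plaintext=signing_key)["CiphertextBlob"]

# Read path: a dump of the table yields only stored_blob; recovering the key needs KMS
recovered = kms.decrypt(CiphertextBlob=stored_blob, KeyId="alias/neurostash-signing")["Plaintext"]
print(recovered == signing_key)  # True
```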

Symmetric Key Generation

To generate the keys themselves, I used Python’s secrets module, which draws on the operating system's cryptographically secure random source:

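import logging
import secrets

logger = logging.getLogger(__name__)

class KeyGenerationError(Exception):
    """Raised when a signing key cannot be generated."""

def generate_symmetric_key() -> bytes:
    try:
        return secrets.token_bytes(32)  # 32 bytes = 256-bit key
    except Exception as e:
        logger.error(f"failed to generate key: {e}")
        raise KeyGenerationError(f"failed to generate key: {e}") from e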

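Why 256-bit keys: a 32-byte key matches the 256-bit output of SHA-256, which underlies the HMAC-SHA256 used for both JWTs and API keys, and RFC 7518 requires HS256 keys to be at least that long. Keys of this size are still cheap to use.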
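The sizes can be confirmed with the standard library (the signed message below is arbitrary):

```python
import hashlib
import hmac
import secrets

key = secrets.token_bytes(32)
print(len(key) * 8)                      # 256: bits of key material
print(hashlib.sha256().digest_size * 8)  # 256: bits of SHA-256 output

tag = hmac.new(key, b"kid:payload", hashlib.sha256).digest()
print(len(tag))                          # 32: bytes in an HMAC-SHA256 tag
```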

Database Schema Design

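from datetime import datetime
from typing import Optional

from sqlalchemy import TIMESTAMP, Boolean, Identity, Index, Integer, LargeBinary, text
from sqlalchemy.orm import Mapped, mapped_column

# Base and TimestampMixin are defined elsewhere in the project

class EncryptionKey(Base, TimestampMixin):
    __tablename__ = "encryption_keys"

    id: Mapped[int] = mapped_column(Integer, Identity(), primary_key=True)
    # KMS ciphertext of the signing key; the plaintext key is never stored
    symmetric_key: Mapped[bytes] = mapped_column(LargeBinary, nullable=False)
    is_active: Mapped[bool] = mapped_column(Boolean, nullable=False)
    expired_at: Mapped[Optional[datetime]] = mapped_column(TIMESTAMP(timezone=True))

    __table_args__ = (
        # Partial index: only rows where is_active is true are indexed
        Index("idx_encryption_keys_active", "id", postgresql_where=text("is_active")),
    )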

Performance optimization: the partial index only contains rows where is_active is true, so it stays small, and looking up the active key never means wading through expired keys, however many accumulate in the table.
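SQLite, which ships with Python, also supports partial indexes, so the idea can be tried locally without PostgreSQL. This is an analog of the schema above with simplified column types, not the project's migration: it inserts 1,000 retired keys and one active key, then queries for the active one.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE encryption_keys ("
    " id INTEGER PRIMARY KEY,"
    " symmetric_key BLOB NOT NULL,"
    " is_active BOOLEAN NOT NULL)"
)
# Partial index: only rows satisfying WHERE is_active are stored in it
conn.execute(
    "CREATE INDEX idx_encryption_keys_active ON encryption_keys (id) WHERE is_active"
)

# Many retired keys, then one active key
conn.executemany(
    "INSERT INTO encryption_keys (symmetric_key, is_active) VALUES (?, 0)",
    [(b"retired",)] * 1000,
)
conn.execute(
    "INSERT INTO encryption_keys (symmetric_key, is_active) VALUES (?, 1)", (b"current",)
)

active_ids = [row[0] for row in conn.execute("SELECT id FROM encryption_keys WHERE is_active")]
print(active_ids)  # [1001]

# Show the access path SQLite chooses (the output format varies by version)
for row in conn.execute("EXPLAIN QUERY PLAN SELECT id FROM encryption_keys WHERE is_active"):
    print(row)
```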

Token Manager: The heart of authentication

The TokenManager class orchestrates all authentication operations. Rather than calling KMS on every request, it decrypts each stored key once and keeps the plaintext keys in an in-memory cache:

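from typing import Dict, Tuple

from sqlalchemy.orm import Session

def _build_active_key_tuple(self, db: Session) -> Tuple[Dict[int, KeyInfo], int]:
    # Fetch the KMS-encrypted keys from the database
    # (assumed: the active key is one EncryptionKey row, the others a list of rows)
    active_key = get_active_encryption_key(db=db)
    other_keys = get_other_encryption_keys(db=db)

    # Decrypt every key once via KMS and keep the plaintext in memory
    decrypted_key_info: Dict[int, KeyInfo] = {}
    for row in [active_key, *other_keys]:
        decrypted_key_bytes = self._aws_client_manager.decrypt_key(row.symmetric_key)
        if decrypted_key_bytes is None:
            raise RuntimeError(f"failed to decrypt encryption key {row.id}")
        decrypted_key_info[row.id] = KeyInfo(
            key=decrypted_key_bytes,
            expires_at=row.expired_at,
        )

    return decrypted_key_info, active_key.id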

This approach gives us:

  • Fast token verification: No KMS calls during verification

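  • Graceful key rotation: one key signs new tokens while older keys stay in the cache to verify tokens they already signed

  • Early failure: a key that cannot be decrypted surfaces when the cache is built, not in the middle of a request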
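The post does not show how get_keys holds on to the result of _build_active_key_tuple. A minimal sketch of one way to get "no KMS calls during verification" is a build-once cache. The KeyInfo fields match the ones used above; `KeyCache`, its `invalidate` method, and the locking are assumptions for illustration, not the project's code.

```python
import threading
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Callable, Dict, Optional, Tuple


@dataclass(frozen=True)
class KeyInfo:
    key: bytes
    expires_at: Optional[datetime] = None  # timezone-aware, or None for no expiry

    def is_expired(self) -> bool:
        return self.expires_at is not None and self.expires_at <= datetime.now(timezone.utc)


KeyMap = Tuple[Dict[int, KeyInfo], int]  # (all keys by id, active key id)


class KeyCache:
    """Hypothetical cache: builds the decrypted key map once, then serves it from memory."""

    def __init__(self, build: Callable[[], KeyMap]):
        self._build = build  # e.g. a wrapper around _build_active_key_tuple(db)
        self._keys: Optional[KeyMap] = None
        self._lock = threading.Lock()

    def get_keys(self) -> KeyMap:
        if self._keys is None:
            with self._lock:
                if self._keys is None:  # another thread may have built it meanwhile
                    self._keys = self._build()
        return self._keys

    def invalidate(self) -> None:
        """Forget the cached keys, e.g. after a rotation, so the next call rebuilds."""
        with self._lock:
            self._keys = None


# Usage: count how often the (normally KMS-backed) build step runs
builds = []

def build() -> KeyMap:
    builds.append(1)
    return {1: KeyInfo(b"\x01" * 32), 2: KeyInfo(b"\x02" * 32)}, 2

cache = KeyCache(build)
keys, active_id = cache.get_keys()
cache.get_keys()     # served from memory
print(len(builds))   # 1
cache.invalidate()
cache.get_keys()     # rebuilt once after invalidation
print(len(builds))   # 2
```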

JWT Token Implementation

Each JWT carries a kid (key ID) header that names the key used to sign it, so the verifier knows which key to check the signature with. New tokens are always signed with the currently active key:

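import os
from datetime import datetime, timedelta, timezone

from jose import jwt
from jose.constants import ALGORITHMS

def create_access_token(self, payload_data: TokenData) -> str:
    all_keys, active_key_id = self.get_keys()
    active_key_info = all_keys[active_key_id]

    now = datetime.now(timezone.utc)
    # ACCESS_TOKEN_EXPIRE_MINUTES is an assumed setting name for the token lifetime
    expire = now + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)

    to_encode = payload_data.model_dump(mode="json", exclude_unset=True)
    to_encode.update({
        "exp": expire,
        "iss": settings.JWT_ISSUER,
        "aud": settings.JWT_AUDIENCE,
        "iat": now,
        "jti": os.urandom(16).hex(),  # unique token ID
    })

    # The kid header tells verify_token which key signed this token
    headers = {"kid": active_key_id}
    return jwt.encode(to_encode, active_key_info.key.hex(),
                      algorithm=ALGORITHMS.HS256, headers=headers)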
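Under the hood, an HS256 JWT is three base64url segments: a JSON header, the JSON claims, and an HMAC-SHA256 over the first two. This standard-library sketch (not python-jose's implementation; the claims and the key are made up, and the key is used as raw bytes rather than hex) shows where kid ends up and why a verifier can read it before checking the signature.

```python
import base64
import hashlib
import hmac
import json


def b64url(data: bytes) -> str:
    """Base64url without "=" padding, as JWS (RFC 7515) requires."""
    return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii")


def sign_hs256(claims: dict, key: bytes, kid: int) -> str:
    header = {"alg": "HS256", "typ": "JWT", "kid": kid}
    signing_input = (
        b64url(json.dumps(header, separators=(",", ":")).encode())
        + "."
        + b64url(json.dumps(claims, separators=(",", ":")).encode())
    )
    signature = hmac.new(key, signing_input.encode("ascii"), hashlib.sha256).digest()
    return signing_input + "." + b64url(signature)


token = sign_hs256({"sub": "42", "iss": "example-issuer"}, key=b"k" * 32, kid=7)

# The header is plain base64url JSON, readable without the key
header_b64 = token.split(".")[0]
header = json.loads(base64.urlsafe_b64decode(header_b64 + "=" * (-len(header_b64) % 4)))
print(header)  # {'alg': 'HS256', 'typ': 'JWT', 'kid': 7}

# The signature only checks out with the key that kid points to
signing_input, signature_b64 = token.rsplit(".", 1)
recomputed = b64url(hmac.new(b"k" * 32, signing_input.encode("ascii"), hashlib.sha256).digest())
print(recomputed == signature_b64)  # True
```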

Verification reads the kid header, looks up the matching key, and uses it to check the signature and the standard claims:

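from typing import Optional

from jose import jwt
from jose.constants import ALGORITHMS

def verify_token(self, token: str) -> Optional[TokenData]:
    # The header is read unverified only to choose the key; the signature is checked below
    unverified_headers = jwt.get_unverified_headers(token)
    kid = unverified_headers.get("kid")

    all_keys, _ = self.get_keys()
    key_for_verification = all_keys.get(kid)
    if key_for_verification is None:
        raise KeyNotFoundError(f"no key found for kid {kid}")
    if key_for_verification.is_expired():
        raise RuntimeError(f"key {kid} has expired")

    # audience and issuer must be passed: a token carrying an "aud" claim
    # is rejected when no expected audience is given
    payload = jwt.decode(token, key_for_verification.key.hex(),
                         algorithms=[ALGORITHMS.HS256],
                         audience=settings.JWT_AUDIENCE,
                         issuer=settings.JWT_ISSUER)
    return TokenData(**payload)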
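Key rotation falls out of the kid lookup: tokens signed by an older key keep verifying as long as that key is still in the map, and stop as soon as it is removed. The standard-library sketch below models only that lookup (the key map, `sign`, and `verify` are illustrative; exp, aud, and iss checks are left out for brevity).

```python
import base64
import hashlib
import hmac
import json
from typing import Dict


def b64url(data: bytes) -> str:
    return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii")


def b64url_decode(segment: str) -> bytes:
    return base64.urlsafe_b64decode(segment + "=" * (-len(segment) % 4))


def sign(claims: dict, keys: Dict[int, bytes], kid: int) -> str:
    header = {"alg": "HS256", "kid": kid}
    signing_input = b64url(json.dumps(header).encode()) + "." + b64url(json.dumps(claims).encode())
    signature = hmac.new(keys[kid], signing_input.encode(), hashlib.sha256).digest()
    return signing_input + "." + b64url(signature)


def verify(token: str, keys: Dict[int, bytes]) -> dict:
    header_b64, claims_b64, signature_b64 = token.split(".")
    header = json.loads(b64url_decode(header_b64))
    if header.get("alg") != "HS256":  # never let the token choose the algorithm
        raise ValueError("unexpected algorithm")
    kid = header.get("kid")
    if kid not in keys:
        raise KeyError(f"unknown kid {kid}")
    expected = hmac.new(keys[kid], f"{header_b64}.{claims_b64}".encode(), hashlib.sha256).digest()
    if not hmac.compare_digest(expected, b64url_decode(signature_b64)):
        raise ValueError("bad signature")
    return json.loads(b64url_decode(claims_b64))


keys = {1: b"a" * 32}
old_token = sign({"sub": "42"}, keys, kid=1)

keys[2] = b"b" * 32  # rotation: new tokens are signed with key 2
new_token = sign({"sub": "42"}, keys, kid=2)

print(verify(old_token, keys))  # {'sub': '42'}
print(verify(new_token, keys))  # {'sub': '42'}

# Once key 1 is dropped from the map, tokens it signed are rejected
try:
    verify(old_token, {2: keys[2]})
    rejected = False
except KeyError:
    rejected = True
print(rejected)  # True
```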

API Key System: Stateless and Secure

Each API key carries an HMAC-SHA256 signature over its random part, so a tampered or forged key fails verification:

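import base64
import hashlib
import hmac
import secrets
from typing import Tuple

def generate_api_key(self) -> Tuple[str, bytes, bytes, int]:
    all_keys, active_key_id = self.get_keys()
    active_key_info = all_keys[active_key_id]

    # Random component: 24 bytes -> 32 URL-safe Base64 characters
    random_bytes = secrets.token_bytes(24)
    random_bytes_b64 = base64.urlsafe_b64encode(random_bytes).decode("utf-8").rstrip("=")

    # HMAC-SHA256 over "kid:random" with the active signing key
    data_to_hmac = f"{active_key_id}:{random_bytes_b64}".encode("utf-8")
    signature_bytes = hmac.new(active_key_info.key, data_to_hmac, hashlib.sha256).digest()
    signature_b64 = base64.urlsafe_b64encode(signature_bytes).decode("utf-8").rstrip("=")

    api_key = f"{random_bytes_b64}.{signature_b64}"
    api_key_bytes = api_key.encode("utf-8")  # assumption: the raw key as bytes
    return api_key, api_key_bytes, signature_bytes, active_key_id

def verify_api_key(self, api_key: str, key_hmac: bytes, kid: int) -> bool:
    parts = api_key.split(".")
    if len(parts) != 2:
        return False  # malformed key
    random_bytes_b64, signature_b64 = parts

    key_info = self.get_keys()[0].get(kid)
    if key_info is None:
        return False  # unknown signing key

    # Re-add the "=" padding stripped at generation time, then decode
    try:
        client_signature_bytes = base64.urlsafe_b64decode(
            signature_b64 + "=" * (-len(signature_b64) % 4))
    except ValueError:  # binascii.Error is a ValueError subclass
        return False

    # Recreate the HMAC with our key
    data_to_hmac = f"{kid}:{random_bytes_b64}".encode("utf-8")
    expected = hmac.new(key_info.key, data_to_hmac, hashlib.sha256).digest()

    # Constant-time comparisons; key_hmac is assumed to be the signature stored at issue time
    signature_ok = hmac.compare_digest(expected, client_signature_bytes)
    stored_ok = hmac.compare_digest(expected, key_hmac)
    return signature_ok and stored_ok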

Security Benefits:

  • Tamper-proof with HMAC-SHA256

  • Constant-time comparison prevents timing attacks

  • URL-safe Base64 encoding
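Stripped of the TokenManager plumbing, the scheme can be exercised end to end. `make_api_key` and `check_api_key` below are hypothetical, simplified stand-ins for the methods above: the signing keys are passed in as a plain dict instead of coming from the cache.

```python
import base64
import hashlib
import hmac
import secrets
from typing import Dict, Tuple


def _b64(data: bytes) -> str:
    return base64.urlsafe_b64encode(data).decode("utf-8").rstrip("=")


def make_api_key(key: bytes, kid: int) -> Tuple[str, bytes]:
    """Return (api_key, signature); the signature binds the random part to kid."""
    random_b64 = _b64(secrets.token_bytes(24))
    signature = hmac.new(key, f"{kid}:{random_b64}".encode("utf-8"), hashlib.sha256).digest()
    return f"{random_b64}.{_b64(signature)}", signature


def check_api_key(api_key: str, keys: Dict[int, bytes], kid: int) -> bool:
    parts = api_key.split(".")
    if len(parts) != 2 or kid not in keys:
        return False
    random_b64, signature_b64 = parts
    try:
        client_signature = base64.urlsafe_b64decode(signature_b64 + "=" * (-len(signature_b64) % 4))
    except ValueError:  # covers binascii.Error for malformed Base64
        return False
    expected = hmac.new(keys[kid], f"{kid}:{random_b64}".encode("utf-8"), hashlib.sha256).digest()
    return hmac.compare_digest(expected, client_signature)


keys = {3: secrets.token_bytes(32), 4: secrets.token_bytes(32)}
api_key, stored_signature = make_api_key(keys[3], kid=3)

print(check_api_key(api_key, keys, kid=3))   # True
tampered = ("A" if api_key[0] != "A" else "B") + api_key[1:]
print(check_api_key(tampered, keys, kid=3))  # False: random part altered
print(check_api_key(api_key, keys, kid=4))   # False: signed under a different key
```

Including kid in the signed data means a key cannot be replayed under a different signing key ID.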

FastAPI Integration

FastAPI's dependency injection keeps authentication out of the route handlers:

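from typing import Annotated, Optional

from fastapi import Depends, HTTPException
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from jose.exceptions import ExpiredSignatureError, JWTError

# token.credentials below implies an HTTPBearer scheme (assumed configuration)
oauth2_scheme = HTTPBearer(auto_error=False)

async def get_token_payload(
    token: Annotated[Optional[HTTPAuthorizationCredentials], Depends(oauth2_scheme)],
    token_manager: TokenDep,
) -> TokenData:
    if token is None:
        raise HTTPException(status_code=401, detail="missing bearer token")

    try:
        return token_manager.verify_token(token=token.credentials)
    except ExpiredSignatureError:
        raise HTTPException(status_code=401, detail="token has expired")
    except KeyNotFoundError as e:
        raise HTTPException(status_code=401, detail=f"key error: {e}")
    except JWTError:
        # bad signature, audience, issuer, etc.
        raise HTTPException(status_code=401, detail="invalid token")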

Routes can then simply depend on TokenPayloadDep (JWT) or ApiPayloadDep (API key):

@router.post("/protected-endpoint")
def protected_route(payload: TokenPayloadDep):
    # payload.user_id, payload.email, payload.role are available
    return {"message": f"Hello {payload.email}"}

Lessons Learned

  • The @property decorator is underrated: Perfect for lazy resource initialization

  • HMAC-based API keys scale better: Stateless verification beats database lookups

  • FastAPI dependencies simplify auth: Clean separation of concerns

  • Error handling matters: Fail gracefully without information leakage

Wrapping Up

Building authentication is challenging, but breaking it into focused components makes it manageable. The combination of AWS KMS, careful database design, and FastAPI's dependency system creates a robust foundation.

The key is balancing security with performance and maintainability. This architecture handles enterprise requirements while keeping the codebase clean and testable.

Want to see the full implementation? Check out the NeuroStash repository: https://github.com/DEVunderdog/NeuroStash
