AI Security Best Practices: Beyond Prompt Injection

Last year, our AI application was compromised. Not through prompt injection, but through model extraction: an attacker reconstructed a working copy of our fine-tuned model through systematic API queries in about 48 hours. After securing 20+ AI applications, I’ve learned that prompt injection is just the tip of the iceberg. Here’s a practical guide to AI security beyond prompt injection.

Figure 1: AI Security Threat Landscape

The Expanding Threat Landscape

AI applications face multiple security threats beyond prompt injection:

  • Model extraction: Attackers reconstruct models through API queries
  • Data poisoning: Malicious training data corrupts model behavior
  • Adversarial attacks: Crafted inputs cause misclassification
  • Model inversion: Reconstructing training data from model outputs
  • Membership inference: Determining if data was in training set
  • API abuse: Unauthorized access and rate limit bypass

Model Extraction Defense

Model extraction is a serious threat. Attackers can reconstruct an approximation of your model by systematically querying your API and training a surrogate on the responses. Rate limiting and query monitoring are the first line of defense:

# Defense: Rate limiting and query monitoring
import time
from collections import defaultdict

class RateLimitError(Exception):
    pass

class SecurityError(Exception):
    pass

class ModelExtractionDefense:
    def __init__(self, max_queries_per_hour=100, max_queries_per_user=1000):
        self.query_counts = defaultdict(int)       # total queries per user
        self.query_timestamps = defaultdict(list)  # timestamps within the last hour
        self.query_history = defaultdict(list)     # recent query text per user
        self.max_queries_per_hour = max_queries_per_hour
        self.max_queries_per_user = max_queries_per_user
    
    def check_rate_limit(self, user_id):
        # Reject the request if the user exceeds hourly or total limits
        now = time.time()
        hour_ago = now - 3600
        
        # Drop timestamps older than one hour
        self.query_timestamps[user_id] = [
            ts for ts in self.query_timestamps[user_id] if ts > hour_ago
        ]
        
        # Check limits
        if len(self.query_timestamps[user_id]) >= self.max_queries_per_hour:
            raise RateLimitError("Hourly query limit exceeded")
        
        if self.query_counts[user_id] >= self.max_queries_per_user:
            raise RateLimitError("Total query limit exceeded")
        
        # Record query
        self.query_timestamps[user_id].append(now)
        self.query_counts[user_id] += 1
    
    def detect_extraction_attempt(self, queries):
        # Heuristics for systematic extraction: very high query volume,
        # or an unusually diverse query set with little repetition
        if len(queries) > 1000:
            return True
        
        # Diversity check is only meaningful with a reasonable sample size
        if len(queries) >= 50:
            unique_ratio = len(set(queries)) / len(queries)
            if unique_ratio > 0.9:
                return True
        
        return False

# Usage: one shared instance so per-user state persists across requests
defense = ModelExtractionDefense()

def api_endpoint(user_id, query):
    defense.check_rate_limit(user_id)
    
    # Track recent queries and check for extraction patterns
    defense.query_history[user_id].append(query)
    if defense.detect_extraction_attempt(defense.query_history[user_id]):
        raise SecurityError("Potential model extraction attempt detected")
    
    # process_query is the application's own inference function
    return process_query(query)

Additional Model Extraction Defenses

  • Output perturbation: Add noise to outputs to prevent exact reconstruction (see the sketch after this list)
  • Query filtering: Block suspicious query patterns
  • Watermarking: Embed watermarks to detect extracted models
  • Access controls: Require authentication and authorization

Data Poisoning Prevention

Data poisoning attacks corrupt training data to manipulate model behavior:

# Defense: Data validation and sanitization
import re
from typing import List, Dict

class DataPoisoningDefense:
    def __init__(self):
        self.suspicious_patterns = [
            r'ignore\s+previous\s+instructions',
            r'forget\s+everything',
            r'new\s+instructions:',
            r'<\|system\|>',
        ]
    
    def validate_training_data(self, data: List[Dict]) -> List[Dict]:
        # Validate and sanitize training data
        cleaned_data = []
        
        for item in data:
            # Check for suspicious patterns
            if self.contains_suspicious_pattern(item):
                continue  # Skip poisoned data
            
            # Validate data format
            if not self.is_valid_format(item):
                continue
            
            # Sanitize content
            item['text'] = self.sanitize_text(item['text'])
            cleaned_data.append(item)
        
        return cleaned_data
    
    def contains_suspicious_pattern(self, item: Dict) -> bool:
        # Check for known poisoning patterns
        text = item.get('text', '').lower()
        for pattern in self.suspicious_patterns:
            if re.search(pattern, text, re.IGNORECASE):
                return True
        return False
    
    def sanitize_text(self, text: str) -> str:
        # Sanitize text content
        # Remove control characters (C0 and C1 ranges)
        text = re.sub(r'[\x00-\x1f\x7f-\x9f]', '', text)
        
        # Normalize whitespace
        text = re.sub(r'\s+', ' ', text)
        
        # Remove potential injection patterns
        text = re.sub(r'<\|[^|]+\|>', '', text)
        
        return text.strip()
    
    def is_valid_format(self, item: Dict) -> bool:
        # Validate data format
        required_fields = ['text', 'label']  # Adjust based on your format
        return all(field in item for field in required_fields)

# Usage
def prepare_training_data(raw_data):
    defense = DataPoisoningDefense()
    cleaned_data = defense.validate_training_data(raw_data)
    
    # Additional check: if more than 10% of the data was filtered out,
    # treat the whole dataset as suspect (SecurityError as defined earlier)
    if len(cleaned_data) < len(raw_data) * 0.9:
        raise SecurityError("Too much data filtered - potential poisoning")
    
    return cleaned_data

Figure 2: Multi-Layer Defense Strategy

Adversarial Attack Mitigation

Adversarial attacks use crafted inputs to cause misclassification:

# Defense: Input validation and adversarial training
import re
import numpy as np
from typing import Tuple

class AdversarialDefense:
    def __init__(self):
        self.max_input_length = 1000
        self.allowed_chars = set('abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789 .,!?;:')
    
    def validate_input(self, text: str) -> Tuple[bool, str]:
        # Validate input for adversarial patterns
        # Length check
        if len(text) > self.max_input_length:
            return False, "Input too long"
        
        # Character validation
        if not all(c in self.allowed_chars for c in text):
            return False, "Invalid characters detected"
        
        # Check for adversarial patterns
        if self.detect_adversarial_pattern(text):
            return False, "Adversarial pattern detected"
        
        return True, "Valid"
    
    def detect_adversarial_pattern(self, text: str) -> bool:
        # Detect potential adversarial inputs
        # Check for unusual character sequences
        if self.has_unusual_sequences(text):
            return True
        
        # Check for high entropy (random-looking text)
        if self.has_high_entropy(text):
            return True
        
        return False
    
    def has_unusual_sequences(self, text: str) -> bool:
        # Check for the same character repeated 10 or more times
        if re.search(r'(.)\1{9,}', text):
            return True
        
        # Check for a two-character pattern repeated several times (e.g. "ababababab")
        if re.search(r'(.)(.)(?:\1\2){4,}', text):
            return True
        
        return False
    
    def has_high_entropy(self, text: str) -> bool:
        # Check for high entropy (random text)
        if len(text) < 20:
            return False
        
        # Simple entropy check
        char_counts = {}
        for char in text:
            char_counts[char] = char_counts.get(char, 0) + 1
        
        # Calculate entropy
        entropy = -sum((count/len(text)) * np.log2(count/len(text)) 
                      for count in char_counts.values())
        
        # High entropy suggests random/adversarial input
        return entropy > 4.5

# Usage
def secure_inference(text: str):
    defense = AdversarialDefense()
    is_valid, message = defense.validate_input(text)
    
    if not is_valid:
        raise SecurityError(f"Invalid input: {message}")
    
    # Proceed with inference (`model` is your loaded model object)
    return model.predict(text)

API Security Best Practices

Secure your AI API endpoints:

from flask import Flask, request, jsonify
from functools import wraps
import jwt
import os
import re
import redis

app = Flask(__name__)

# Configuration (adjust to your deployment): JWT secret from an
# environment variable, local Redis instance for rate limiting
SECRET_KEY = os.environ["JWT_SECRET_KEY"]
redis_client = redis.Redis()

# API Security Middleware
def require_auth(f):
    @wraps(f)
    def decorated(*args, **kwargs):
        token = request.headers.get('Authorization')
        if not token:
            return jsonify({'error': 'No token provided'}), 401
        
        try:
            # Strip an optional "Bearer " prefix, then verify the JWT
            token = token.removeprefix('Bearer ')
            payload = jwt.decode(token, SECRET_KEY, algorithms=['HS256'])
            request.user_id = payload['user_id']
        except jwt.InvalidTokenError:
            return jsonify({'error': 'Invalid token'}), 401
        
        return f(*args, **kwargs)
    return decorated

def rate_limit(max_per_minute=60):
    def decorator(f):
        @wraps(f)
        def decorated(*args, **kwargs):
            user_id = getattr(request, 'user_id', request.remote_addr)
            key = f"rate_limit:{user_id}"
            
            # Check rate limit (using Redis or similar)
            current = redis_client.incr(key)
            if current == 1:
                redis_client.expire(key, 60)
            
            if current > max_per_minute:
                return jsonify({'error': 'Rate limit exceeded'}), 429
            
            return f(*args, **kwargs)
        return decorated
    return decorator

def input_sanitization(f):
    @wraps(f)
    def decorated(*args, **kwargs):
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({'error': 'Invalid input'}), 400
        
        # Sanitize in place; Flask caches the parsed JSON,
        # so the view function sees the sanitized prompt
        if 'prompt' in data:
            data['prompt'] = sanitize_prompt(data['prompt'])
        
        # Validate input
        if not validate_input(data):
            return jsonify({'error': 'Invalid input'}), 400
        
        return f(*args, **kwargs)
    return decorated

@app.route('/api/generate', methods=['POST'])
@require_auth
@rate_limit(max_per_minute=30)
@input_sanitization
def generate():
    data = request.get_json()
    prompt = data['prompt']
    
    # Secure inference
    result = secure_inference(prompt)
    
    return jsonify({'result': result})

def sanitize_prompt(prompt: str) -> str:
    # Sanitize user prompt
    # Remove control characters (C0 and C1 ranges)
    prompt = re.sub(r'[\x00-\x1f\x7f-\x9f]', '', prompt)
    
    # Limit length
    prompt = prompt[:1000]
    
    # Remove potential injection patterns
    prompt = re.sub(r'<\|[^|]+\|>', '', prompt)
    
    return prompt.strip()

def validate_input(data: dict) -> bool:
    # Validate input data
    if 'prompt' not in data:
        return False
    
    if not isinstance(data['prompt'], str):
        return False
    
    if len(data['prompt']) == 0 or len(data['prompt']) > 1000:
        return False
    
    return True

Model Inversion and Membership Inference Defense

Protect against model inversion and membership inference with differential privacy and output perturbation:

# Defense: Differential privacy and output perturbation
import numpy as np

class PrivacyDefense:
    def __init__(self, epsilon=1.0):
        self.epsilon = epsilon  # Privacy budget
    
    def add_noise(self, output: np.ndarray) -> np.ndarray:
        # Add calibrated noise for differential privacy
        sensitivity = 1.0  # Adjust based on your model
        scale = sensitivity / self.epsilon
        
        # Add Laplace noise
        noise = np.random.laplace(0, scale, output.shape)
        noisy_output = output + noise
        
        return noisy_output
    
    def clip_output(self, output: np.ndarray, clip_value: float = 1.0) -> np.ndarray:
        # Clip output values to limit information leakage
        return np.clip(output, -clip_value, clip_value)
    
    def secure_inference(self, input_text: str):
        # Perform inference with privacy protection.
        # Assumes `model` exposes predict(), get_embeddings() and
        # generate_from_embeddings(); adapt to your model's actual interface.
        raw_output = model.predict(input_text)
        
        if isinstance(raw_output, str):
            # For text outputs, perturb the embeddings/logits before regenerating
            embeddings = model.get_embeddings(input_text)
            noisy_embeddings = self.add_noise(embeddings)
            output = model.generate_from_embeddings(noisy_embeddings)
        else:
            # For numeric outputs, clip then add noise
            clipped = self.clip_output(raw_output)
            output = self.add_noise(clipped)
        
        return output

# Usage
def privacy_preserving_inference(text: str):
    defense = PrivacyDefense(epsilon=1.0)
    return defense.secure_inference(text)

Security Monitoring and Logging

Monitor your AI application for security threats:

import logging
from datetime import datetime
from typing import Dict, Any

class SecurityMonitor:
    def __init__(self):
        self.logger = logging.getLogger('security')
        self.suspicious_activities = []
    
    def log_query(self, user_id: str, query: str, response: str, metadata: Dict[str, Any]):
        # Log all queries for security analysis
        log_entry = {
            'timestamp': datetime.utcnow().isoformat(),
            'user_id': user_id,
            'query': query[:100],  # Truncate for privacy
            'response_length': len(response),
            'metadata': metadata
        }
        
        # Check for suspicious patterns
        if self.is_suspicious(query, metadata):
            self.logger.warning(f"Suspicious activity detected: {log_entry}")
            self.suspicious_activities.append(log_entry)
        
        # Store in secure log
        self.store_log(log_entry)
    
    def is_suspicious(self, query: str, metadata: Dict[str, Any]) -> bool:
        # Detect suspicious query patterns
        # High query rate
        if metadata.get('queries_per_hour', 0) > 100:
            return True
        
        # Unusual query patterns
        if self.has_extraction_pattern(query):
            return True
        
        # Unusual response patterns
        if metadata.get('response_time', 0) > 10:  # Very slow
            return True
        
        return False
    
    def has_extraction_pattern(self, query: str) -> bool:
        # Detect model extraction patterns
        # Systematic queries
        extraction_keywords = [
            'test', 'example', 'sample', 'generate',
            'all possible', 'every combination'
        ]
        
        query_lower = query.lower()
        if sum(1 for kw in extraction_keywords if kw in query_lower) > 3:
            return True
        
        return False
    
    def store_log(self, entry: Dict):
        # Store log entry securely
        # Store in secure database or SIEM
        # Encrypt sensitive fields
        pass

# Usage
monitor = SecurityMonitor()

def secure_api_call(user_id: str, query: str):
    # Process the query (process_query, get_user_query_count and
    # get_response_time are the application's own helpers)
    response = process_query(query)
    
    # Log for security monitoring
    metadata = {
        'queries_per_hour': get_user_query_count(user_id),
        'response_time': get_response_time(),
    }
    monitor.log_query(user_id, query, response, metadata)
    
    return response

Figure 3: Security Monitoring Architecture

Best Practices Checklist

Comprehensive security checklist for AI applications:

  • Authentication & Authorization: Require valid tokens for all API calls
  • Rate Limiting: Prevent abuse and model extraction attempts
  • Input Validation: Sanitize and validate all user inputs
  • Output Filtering: Filter sensitive information from outputs (see the sketch after this checklist)
  • Logging & Monitoring: Log all queries and monitor for anomalies
  • Data Validation: Validate training data for poisoning
  • Model Watermarking: Embed watermarks to detect extraction
  • Privacy Protection: Use differential privacy for sensitive data
  • Regular Audits: Regularly audit security measures
  • Incident Response: Have a plan for security incidents

🎯 Key Takeaway

AI security goes far beyond prompt injection. Protect against model extraction with rate limiting and monitoring. Prevent data poisoning with validation. Mitigate adversarial attacks with input sanitization. Use differential privacy for sensitive data. Implement comprehensive logging and monitoring. Security is a multi-layer defense—don't rely on a single measure.

Common Mistakes

What I learned the hard way:

  • Focusing only on prompt injection: Many other threats exist
  • No rate limiting: Enabled model extraction attacks
  • Weak input validation: Allowed adversarial inputs
  • No monitoring: Didn't detect attacks until too late
  • Ignoring data poisoning: Corrupted training data
  • No privacy protection: Leaked training data through outputs

Bottom Line

AI security requires a comprehensive, multi-layer approach. Beyond prompt injection, protect against model extraction, data poisoning, adversarial attacks, and privacy violations. Implement authentication, rate limiting, input validation, output filtering, and comprehensive monitoring. Security is not optional—it's essential for production AI applications.

