Last year, our AI application was compromised. Not through prompt injection, but through model extraction: an attacker reconstructed a working copy of our fine-tuned model through 48 hours of systematic API queries. After securing 20+ AI applications, I’ve learned that prompt injection is just the tip of the iceberg. Here’s a guide to AI security beyond prompt injection.

The Expanding Threat Landscape
AI applications face multiple security threats beyond prompt injection:
- Model extraction: Attackers reconstruct models through API queries
- Data poisoning: Malicious training data corrupts model behavior
- Adversarial attacks: Crafted inputs cause misclassification
- Model inversion: Reconstructing training data from model outputs
- Membership inference: Determining if data was in training set
- API abuse: Unauthorized access and rate limit bypass
Model Extraction Defense
Model extraction is a serious threat: attackers can reconstruct a working copy of your model by systematically querying your API and training a surrogate on the responses. Rate limiting and query-pattern monitoring are the first line of defense:
# Defense: rate limiting and query monitoring
import time
from collections import defaultdict

class RateLimitError(Exception):
    pass

class SecurityError(Exception):
    pass

class ModelExtractionDefense:
    def __init__(self, max_queries_per_hour=100, max_queries_per_user=1000):
        self.query_counts = defaultdict(int)       # total queries per user
        self.query_timestamps = defaultdict(list)  # recent timestamps per user
        self.max_queries_per_hour = max_queries_per_hour
        self.max_queries_per_user = max_queries_per_user

    def check_rate_limit(self, user_id):
        """Raise if the user exceeds hourly or total query limits."""
        now = time.time()
        hour_ago = now - 3600
        # Drop timestamps older than one hour
        self.query_timestamps[user_id] = [
            ts for ts in self.query_timestamps[user_id] if ts > hour_ago
        ]
        # Enforce limits
        if len(self.query_timestamps[user_id]) >= self.max_queries_per_hour:
            raise RateLimitError("Hourly query limit exceeded")
        if self.query_counts[user_id] >= self.max_queries_per_user:
            raise RateLimitError("Total query limit exceeded")
        # Record the query
        self.query_timestamps[user_id].append(now)
        self.query_counts[user_id] += 1

    def detect_extraction_attempt(self, queries):
        """Heuristic check over a user's accumulated query history."""
        # Too few queries to judge
        if len(queries) < 100:
            return False
        # Very large query volumes are suspicious on their own
        if len(queries) > 1000:
            return True
        # A high ratio of unique queries suggests systematic probing
        unique_patterns = len(set(queries))
        if unique_patterns / len(queries) > 0.9:
            return True
        return False

# Usage: keep one shared defense instance so counts persist across requests
defense = ModelExtractionDefense()
query_history = defaultdict(list)

def api_endpoint(user_id, query):
    defense.check_rate_limit(user_id)
    # Check the user's accumulated history for extraction patterns
    query_history[user_id].append(query)
    if defense.detect_extraction_attempt(query_history[user_id]):
        raise SecurityError("Potential model extraction attempt detected")
    # Process query
    return process_query(query)
Additional Model Extraction Defenses
- Output perturbation: Add noise to outputs to prevent exact reconstruction (a minimal sketch follows this list)
- Query filtering: Block suspicious query patterns
- Watermarking: Embed watermarks to detect extracted models
- Access controls: Require authentication and authorization
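As a concrete illustration of output perturbation, here is a minimal sketch that adds Gaussian noise to a logits vector before returning class probabilities. The noise scale is an assumption you would tune against your own accuracy budget, and perturb_logits is a hypothetical helper, not part of any specific library:

# Sketch: output perturbation (illustrative; tune noise_scale per model)
import numpy as np

def perturb_logits(logits: np.ndarray, noise_scale: float = 0.05) -> np.ndarray:
    """Add small Gaussian noise to logits so repeated identical queries
    never return bit-identical outputs an attacker can fit a surrogate to."""
    noisy = logits + np.random.normal(0.0, noise_scale, logits.shape)
    # Return perturbed probabilities rather than raw logits
    exp = np.exp(noisy - noisy.max())
    return exp / exp.sum()

The trade-off is accuracy versus protection: more noise makes a reconstructed surrogate less faithful, but also degrades the scores your legitimate users see.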
Data Poisoning Prevention
Data poisoning attacks corrupt training data to manipulate model behavior:
# Defense: data validation and sanitization
import re
from typing import Dict, List

class DataPoisoningDefense:
    def __init__(self):
        # Prompt-injection-style patterns that have no business appearing
        # in legitimate training examples
        self.suspicious_patterns = [
            r'ignore\s+previous\s+instructions',
            r'forget\s+everything',
            r'new\s+instructions:',
            r'<\|system\|>',
        ]

    def validate_training_data(self, data: List[Dict]) -> List[Dict]:
        """Validate and sanitize training data, dropping suspect records."""
        cleaned_data = []
        for item in data:
            # Skip records containing known poisoning patterns
            if self.contains_suspicious_pattern(item):
                continue
            # Skip records that don't match the expected schema
            if not self.is_valid_format(item):
                continue
            # Sanitize the text content
            item['text'] = self.sanitize_text(item['text'])
            cleaned_data.append(item)
        return cleaned_data

    def contains_suspicious_pattern(self, item: Dict) -> bool:
        """Check for known poisoning patterns."""
        text = item.get('text', '').lower()
        return any(re.search(p, text, re.IGNORECASE) for p in self.suspicious_patterns)

    def sanitize_text(self, text: str) -> str:
        """Sanitize text content."""
        # Remove control characters (keep tabs and newlines for the whitespace pass)
        text = re.sub(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]', '', text)
        # Normalize whitespace
        text = re.sub(r'\s+', ' ', text)
        # Remove special-token-style injection patterns
        text = re.sub(r'<\|[^|]+\|>', '', text)
        return text.strip()

    def is_valid_format(self, item: Dict) -> bool:
        """Validate data format."""
        required_fields = ['text', 'label']  # Adjust based on your schema
        return all(field in item for field in required_fields)

# Usage
def prepare_training_data(raw_data):
    defense = DataPoisoningDefense()
    cleaned_data = defense.validate_training_data(raw_data)
    # If more than 10% of records were filtered out, assume something is wrong
    if len(cleaned_data) < len(raw_data) * 0.9:
        raise SecurityError("Too much data filtered - potential poisoning")
    return cleaned_data

Adversarial Attack Mitigation
Adversarial attacks use crafted inputs to cause misclassification:
# Defense: input validation (complemented by adversarial training)
import re
from typing import Tuple

import numpy as np

class AdversarialDefense:
    def __init__(self):
        self.max_input_length = 1000
        # Restrictive character allowlist; widen for your locale and domain
        self.allowed_chars = set(
            'abcdefghijklmnopqrstuvwxyz'
            'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
            '0123456789 .,!?;:'
        )

    def validate_input(self, text: str) -> Tuple[bool, str]:
        """Validate input for adversarial patterns."""
        # Length check
        if len(text) > self.max_input_length:
            return False, "Input too long"
        # Character validation
        if not all(c in self.allowed_chars for c in text):
            return False, "Invalid characters detected"
        # Heuristic adversarial-pattern checks
        if self.detect_adversarial_pattern(text):
            return False, "Adversarial pattern detected"
        return True, "Valid"

    def detect_adversarial_pattern(self, text: str) -> bool:
        """Detect potentially adversarial inputs."""
        return self.has_unusual_sequences(text) or self.has_high_entropy(text)

    def has_unusual_sequences(self, text: str) -> bool:
        """Check for unusual character sequences."""
        # Same character repeated 10+ times in a row
        if re.search(r'(.)\1{9,}', text):
            return True
        # Two characters alternating many times (e.g. "abababababab")
        if re.search(r'(.)(.)(?:\1\2){5,}', text):
            return True
        return False

    def has_high_entropy(self, text: str) -> bool:
        """Check for high character entropy (random-looking text)."""
        if len(text) < 20:
            return False
        # Character frequency counts
        char_counts = {}
        for char in text:
            char_counts[char] = char_counts.get(char, 0) + 1
        # Shannon entropy in bits per character
        entropy = -sum((count / len(text)) * np.log2(count / len(text))
                       for count in char_counts.values())
        # Very high entropy suggests random or obfuscated input
        return entropy > 4.5

# Usage
def secure_inference(text: str):
    defense = AdversarialDefense()
    is_valid, message = defense.validate_input(text)
    if not is_valid:
        raise SecurityError(f"Invalid input: {message}")
    # Proceed with inference
    return model.predict(text)
API Security Best Practices
Secure your AI API endpoints:
import re
from functools import wraps

import jwt
from flask import Flask, request, jsonify

app = Flask(__name__)

SECRET_KEY = "change-me"  # Load from a secrets manager in production
# redis_client = redis.Redis(...)  # Shared store for rate-limit counters

# API security middleware
def require_auth(f):
    @wraps(f)
    def decorated(*args, **kwargs):
        token = request.headers.get('Authorization', '')
        token = token.removeprefix('Bearer ').strip()
        if not token:
            return jsonify({'error': 'No token provided'}), 401
        try:
            # Verify the JWT signature and expiry
            payload = jwt.decode(token, SECRET_KEY, algorithms=['HS256'])
            request.user_id = payload['user_id']
        except jwt.InvalidTokenError:
            return jsonify({'error': 'Invalid token'}), 401
        return f(*args, **kwargs)
    return decorated

def rate_limit(max_per_minute=60):
    def decorator(f):
        @wraps(f)
        def decorated(*args, **kwargs):
            user_id = getattr(request, 'user_id', request.remote_addr)
            key = f"rate_limit:{user_id}"
            # Fixed-window counter in Redis (or a similar shared store)
            current = redis_client.incr(key)
            if current == 1:
                redis_client.expire(key, 60)
            if current > max_per_minute:
                return jsonify({'error': 'Rate limit exceeded'}), 429
            return f(*args, **kwargs)
        return decorated
    return decorator

def input_sanitization(f):
    @wraps(f)
    def decorated(*args, **kwargs):
        data = request.get_json(silent=True) or {}
        # Validate first, then sanitize
        if not validate_input(data):
            return jsonify({'error': 'Invalid input'}), 400
        data['prompt'] = sanitize_prompt(data['prompt'])
        # Stash the sanitized payload so the view uses it, not the raw body
        request.sanitized_data = data
        return f(*args, **kwargs)
    return decorated

@app.route('/api/generate', methods=['POST'])
@require_auth
@rate_limit(max_per_minute=30)
@input_sanitization
def generate():
    data = request.sanitized_data
    prompt = data['prompt']
    # Secure inference (validation and adversarial checks from above)
    result = secure_inference(prompt)
    return jsonify({'result': result})

def sanitize_prompt(prompt: str) -> str:
    """Sanitize a user prompt."""
    # Remove control characters
    prompt = re.sub(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]', '', prompt)
    # Limit length
    prompt = prompt[:1000]
    # Remove special-token-style injection patterns
    prompt = re.sub(r'<\|[^|]+\|>', '', prompt)
    return prompt.strip()

def validate_input(data: dict) -> bool:
    """Validate the request payload."""
    if 'prompt' not in data:
        return False
    if not isinstance(data['prompt'], str):
        return False
    if len(data['prompt']) == 0 or len(data['prompt']) > 1000:
        return False
    return True
Model Inversion and Membership Inference Defense
Protect against privacy attacks:
# Defense: differential privacy and output perturbation
import numpy as np

class PrivacyDefense:
    def __init__(self, epsilon=1.0):
        self.epsilon = epsilon  # Privacy budget: smaller epsilon = more noise

    def add_noise(self, output: np.ndarray) -> np.ndarray:
        """Add calibrated Laplace noise for differential privacy."""
        sensitivity = 1.0  # Adjust to the true sensitivity of your outputs
        scale = sensitivity / self.epsilon
        noise = np.random.laplace(0, scale, output.shape)
        return output + noise

    def clip_output(self, output: np.ndarray, clip_value: float = 1.0) -> np.ndarray:
        """Clip output values to bound sensitivity and limit leakage."""
        return np.clip(output, -clip_value, clip_value)

    def secure_inference(self, input_text: str):
        """Perform inference with privacy protection on the outputs."""
        raw_output = model.predict(input_text)
        if isinstance(raw_output, str):
            # For text outputs, perturb the embeddings/logits and regenerate
            embeddings = model.get_embeddings(input_text)
            noisy_embeddings = self.add_noise(embeddings)
            output = model.generate_from_embeddings(noisy_embeddings)
        else:
            # For numeric outputs, clip first (bounding sensitivity), then add noise
            clipped = self.clip_output(raw_output)
            output = self.add_noise(clipped)
        return output

# Usage
def privacy_preserving_inference(text: str):
    defense = PrivacyDefense(epsilon=1.0)
    return defense.secure_inference(text)
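Membership inference attacks typically exploit the over-confident scores a model produces on its own training examples. A minimal, illustrative mitigation is to cap the confidence your API exposes; the 0.95 threshold and the cap_confidence helper below are assumptions to tune, not part of any specific library:

# Sketch: cap exposed confidence to blunt membership inference (illustrative)
import numpy as np

def cap_confidence(probs: np.ndarray, max_confidence: float = 0.95) -> np.ndarray:
    """Cap the top class probability and renormalize, so the API never
    exposes the near-certain scores membership inference relies on."""
    probs = np.asarray(probs, dtype=float).copy()
    top = int(probs.argmax())
    if probs[top] > max_confidence and len(probs) > 1:
        excess = probs[top] - max_confidence
        probs[top] = max_confidence
        # Spread the excess probability mass across the remaining classes
        others = np.arange(len(probs)) != top
        probs[others] += excess / others.sum()
    return probs

Combined with the differential-privacy noise above, this limits how much a per-example confidence score reveals about whether that example was in the training set.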
Security Monitoring and Logging
Monitor your AI application for security threats:
import logging
from datetime import datetime
from typing import Any, Dict

class SecurityMonitor:
    def __init__(self):
        self.logger = logging.getLogger('security')
        self.suspicious_activities = []

    def log_query(self, user_id: str, query: str, response: str, metadata: Dict[str, Any]):
        """Log every query for security analysis."""
        log_entry = {
            'timestamp': datetime.utcnow().isoformat(),
            'user_id': user_id,
            'query': query[:100],  # Truncate for privacy
            'response_length': len(response),
            'metadata': metadata,
        }
        # Flag suspicious patterns
        if self.is_suspicious(query, metadata):
            self.logger.warning(f"Suspicious activity detected: {log_entry}")
            self.suspicious_activities.append(log_entry)
        # Store in the secure log
        self.store_log(log_entry)

    def is_suspicious(self, query: str, metadata: Dict[str, Any]) -> bool:
        """Detect suspicious query patterns."""
        # High query rate
        if metadata.get('queries_per_hour', 0) > 100:
            return True
        # Query text that looks like systematic probing
        if self.has_extraction_pattern(query):
            return True
        # Unusually slow responses can indicate resource-exhaustion probing
        if metadata.get('response_time', 0) > 10:
            return True
        return False

    def has_extraction_pattern(self, query: str) -> bool:
        """Detect model extraction patterns in the query text."""
        extraction_keywords = [
            'test', 'example', 'sample', 'generate',
            'all possible', 'every combination',
        ]
        query_lower = query.lower()
        return sum(1 for kw in extraction_keywords if kw in query_lower) > 3

    def store_log(self, entry: Dict):
        """Store the log entry securely."""
        # Write to a secure database or SIEM; encrypt sensitive fields
        pass

# Usage
monitor = SecurityMonitor()

def secure_api_call(user_id: str, query: str):
    # Process query
    response = process_query(query)
    # Log for security monitoring
    metadata = {
        'queries_per_hour': get_user_query_count(user_id),
        'response_time': get_response_time(),
    }
    monitor.log_query(user_id, query, response, metadata)
    return response

Best Practices Checklist
Comprehensive security checklist for AI applications:
- Authentication & Authorization: Require valid tokens for all API calls
- Rate Limiting: Prevent abuse and model extraction attempts
- Input Validation: Sanitize and validate all user inputs
- Output Filtering: Filter sensitive information from outputs (see the sketch after this checklist)
- Logging & Monitoring: Log all queries and monitor for anomalies
- Data Validation: Validate training data for poisoning
- Model Watermarking: Embed watermarks to detect extraction
- Privacy Protection: Use differential privacy for sensitive data
- Regular Audits: Regularly audit security measures
- Incident Response: Have a plan for security incidents
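For the output-filtering item, a minimal sketch might look like the following: simple regex-based redaction applied to model output before it leaves the API. The patterns and the filter_output helper are illustrative placeholders, not an exhaustive PII filter:

# Sketch: regex-based output filtering (illustrative patterns, not exhaustive)
import re

PII_PATTERNS = {
    'email': re.compile(r'[\w.+-]+@[\w-]+\.[\w.-]+'),
    'ssn': re.compile(r'\b\d{3}-\d{2}-\d{4}\b'),
    'api_key': re.compile(r'\bsk-[A-Za-z0-9]{20,}\b'),
}

def filter_output(text: str) -> str:
    """Redact obvious sensitive values from model output before returning it."""
    for label, pattern in PII_PATTERNS.items():
        text = pattern.sub(f'[REDACTED {label.upper()}]', text)
    return text

In production you would back this with a proper PII/secret detection service, but even a small allow-nothing-obvious filter catches the most embarrassing leaks.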
🎯 Key Takeaway
AI security goes far beyond prompt injection. Protect against model extraction with rate limiting and monitoring. Prevent data poisoning with validation. Mitigate adversarial attacks with input sanitization. Use differential privacy for sensitive data. Implement comprehensive logging and monitoring. Security is a multi-layer defense—don't rely on a single measure.
Common Mistakes
What I learned the hard way:
- Focusing only on prompt injection: Many other threats exist
- No rate limiting: Enabled model extraction attacks
- Weak input validation: Allowed adversarial inputs
- No monitoring: Didn't detect attacks until too late
- Ignoring data poisoning: Corrupted training data
- No privacy protection: Leaked training data through outputs
Bottom Line
AI security requires a comprehensive, multi-layer approach. Beyond prompt injection, protect against model extraction, data poisoning, adversarial attacks, and privacy violations. Implement authentication, rate limiting, input validation, output filtering, and comprehensive monitoring. Security is not optional—it's essential for production AI applications.