Built-in Middleware: CostTracker, LatencyMiddleware, TopologyLogMiddleware
Demonstrates the built-in middleware classes for production observability and the error boundary mechanism that prevents middleware failures from crashing the pipeline.
Key concepts:
- CostTracker: token usage accumulation via after_model
- LatencyMiddleware: per-agent timing via TraceContext
- TopologyLogMiddleware: structured logging for topology events
- Error boundary: middleware exceptions are caught, logged, and reported
- on_middleware_error: notification hook for other middleware
- Custom middleware with typed MiddlewareSchema
Tip
What you’ll learn: How to compose agents into a sequential pipeline and attach built-in middleware to it.
Source: 65_builtin_middleware.py
from adk_fluent._middleware import M, MComposite
from adk_fluent.middleware import (
CostTracker,
LatencyMiddleware,
Middleware,
TopologyLogMiddleware,
TraceContext,
)
# --- 1. CostTracker: token usage accumulation ---
tracker = CostTracker()

# A freshly constructed tracker has all three counters at zero.
assert (tracker.total_input_tokens, tracker.total_output_tokens, tracker.calls) == (0, 0, 0)

# It exposes the after_model hook and passes the Middleware protocol check.
assert hasattr(tracker, "after_model")
assert isinstance(tracker, Middleware)

# M.cost() yields an MComposite whose first stack entry is a CostTracker.
cost_chain = M.cost()
assert isinstance(cost_chain, MComposite)
assert isinstance(cost_chain.to_stack()[0], CostTracker)
# --- 2. LatencyMiddleware: per-agent timing ---
latency = LatencyMiddleware()

# The latencies mapping starts out as an empty dict.
assert isinstance(latency.latencies, dict) and not latency.latencies

# Both agent lifecycle hooks are present.
assert all(hasattr(latency, hook) for hook in ("before_agent", "after_agent"))

# Conforms to the Middleware protocol.
assert isinstance(latency, Middleware)

# M.latency() places a LatencyMiddleware first in its stack.
latency_chain = M.latency()
assert isinstance(latency_chain.to_stack()[0], LatencyMiddleware)
# --- 3. TopologyLogMiddleware: structured topology logging ---
topo = TopologyLogMiddleware()

# The event log starts out as an empty list.
assert isinstance(topo.log, list) and not topo.log

# All six topology hooks are present.
assert all(
    hasattr(topo, hook)
    for hook in (
        "on_loop_iteration",
        "on_fanout_start",
        "on_fanout_complete",
        "on_route_selected",
        "on_fallback_attempt",
        "on_timeout",
    )
)

# Conforms to the Middleware protocol.
assert isinstance(topo, Middleware)

# M.topology_log() places a TopologyLogMiddleware first in its stack.
topo_chain = M.topology_log()
assert isinstance(topo_chain.to_stack()[0], TopologyLogMiddleware)
# --- 4. Test TopologyLogMiddleware event capture ---
import asyncio
# Shared context object handed to every topology hook call in this section.
ctx = TraceContext()
async def _run_topo_logging() -> None:
    """Await five topology hooks on the module-level ``topo`` middleware.

    The calls run in a fixed order (two loop iterations, one route
    selection, then a fan-out start/complete pair) and all share the
    module-level ``ctx``.
    """
    for iteration in (1, 2):
        await topo.on_loop_iteration(ctx, "review_loop", iteration)

    await topo.on_route_selected(ctx, "intent_router", "support")

    # Each fan-out hook receives its own list object.
    branch_names = ("risk", "fraud")
    await topo.on_fanout_start(ctx, "analysis", list(branch_names))
    await topo.on_fanout_complete(ctx, "analysis", list(branch_names))
asyncio.run(_run_topo_logging())

# Events are captured in the log: five hook calls, five entries.
assert len(topo.log) == 5

# Entry 0 is the first loop iteration.
assert (topo.log[0]["event"], topo.log[0]["loop"], topo.log[0]["iteration"]) == (
    "loop_iteration",
    "review_loop",
    1,
)

# Entry 2 is the route selection.
assert (topo.log[2]["event"], topo.log[2]["selected"]) == ("route_selected", "support")

# Entry 3 is the fan-out start with both branch names.
assert topo.log[3]["event"] == "fanout_start"
assert topo.log[3]["branches"] == ["risk", "fraud"]
# --- 5. Error boundary: Middleware protocol includes on_middleware_error ---
# The notification hook is looked up on the Middleware protocol class itself, not on an instance.
assert hasattr(Middleware, "on_middleware_error")
class ErrorAwareMiddleware:
    """Middleware that tracks errors from other middleware.

    Every ``on_middleware_error`` notification is appended to ``self.errors``
    as a dict with the keys ``"hook"``, ``"error"`` and ``"middleware"``.
    """

    def __init__(self) -> None:
        # One record per notification, in arrival order.
        self.errors: list[dict] = []

    async def on_middleware_error(self, ctx, hook_name, error, middleware) -> None:
        """Record one error notification.

        Args:
            ctx: Context passed with the notification; not used here.
            hook_name: Stored unchanged under ``"hook"``.
            error: Stored as ``str(error)`` under ``"error"``.
            middleware: Only its class name is kept, under ``"middleware"``.
        """
        self.errors.append(
            {
                "hook": hook_name,
                "error": str(error),
                "middleware": type(middleware).__name__,
            }
        )
error_tracker = ErrorAwareMiddleware()
# An instance whose only hook is on_middleware_error is asserted to pass the Middleware isinstance check.
assert isinstance(error_tracker, Middleware)
# --- 6. Composing built-in middleware ---
# Production observability stack: five factories chained with `|`.
production_stack = (
    M.retry(3)
    | M.log()
    | M.cost()
    | M.latency()
    | M.topology_log()
)
assert len(production_stack) == 5

# Scoped cost tracking: only track LLM costs for the expensive agent
from adk_fluent import Agent, Pipeline

writer = Agent("writer").model("gemini-2.5-flash").instruct("Write content.")
reviewer = Agent("reviewer").model("gemini-2.5-flash").instruct("Review content.")

# Global: retry + logging for everything
pipeline = (writer >> reviewer).middleware(M.retry(3) | M.log())
assert len(pipeline._middlewares) == 2

# Scoped: cost tracking only for writer
pipeline.middleware(M.scope("writer", M.cost()))

# Conditional: latency tracking only in stream mode
pipeline.middleware(M.when("stream", M.latency()))

# Two global entries plus the scoped and the conditional one.
assert len(pipeline._middlewares) == 4
# --- 7. Custom middleware with MiddlewareSchema ---
from typing import Annotated
from adk_fluent._middleware_schema import MiddlewareSchema
from adk_fluent._schema_base import Reads, Writes
class ComplianceState(MiddlewareSchema):
    """Declares state dependencies for HIPAA compliance middleware."""

    # Declared as a read dependency.
    patient_id: Annotated[str, Reads()]
    # Declared as a write dependency in the "temp" scope.
    audit_log: Annotated[str, Writes(scope="temp")]
class ComplianceMiddleware:
    """HIPAA compliance middleware for healthcare pipelines.

    Both hooks are no-op placeholders in this example.
    """

    # NOTE(review): presumably the agent name this middleware targets; confirm against the library.
    agents = "patient_lookup"
    # State dependencies declared for this middleware.
    schema = ComplianceState

    async def before_agent(self, ctx, agent_name):
        """Do nothing. In production: verify patient consent, log access."""
        return None

    async def after_agent(self, ctx, agent_name):
        """Do nothing. In production: write audit entry."""
        return None
compliance = ComplianceMiddleware()

# The schema reports its read and write keys; the write key carries the
# "temp" scope as a prefix.
assert (compliance.schema.reads_keys(), compliance.schema.writes_keys()) == (
    frozenset({"patient_id"}),
    frozenset({"temp:audit_log"}),
)

# Works with M.scope() and M.when(); only M.scope() is exercised here.
scoped_compliance = M.scope("patient_lookup", compliance)
assert scoped_compliance.to_stack()[0].schema is ComplianceState
# --- 8. Repr for built-in middleware ---
# The repr of a fresh instance contains the class name.
assert all(
    label in repr(factory())
    for label, factory in (
        ("CostTracker", CostTracker),
        ("LatencyMiddleware", LatencyMiddleware),
    )
)
# --- 9. Full production example ---
# E-commerce order processing with comprehensive observability
order_agent = (
    Agent("order_processor")
    .model("gemini-2.5-flash")
    .instruct("Process incoming orders.")
)
fraud_agent = (
    Agent("fraud_detector")
    .model("gemini-2.5-flash")
    .instruct("Detect fraudulent orders.")
)

# Global retry + structured logging, then topology logging for loop/fanout visibility
production_pipeline = (order_agent >> fraud_agent).middleware(
    M.retry(3) | M.log() | M.topology_log()
)

# Cost tracking only for the LLM-heavy fraud detection
production_pipeline.middleware(M.scope("fraud_detector", M.cost()))

# Three global entries plus the scoped cost tracker.
assert len(production_pipeline._middlewares) == 4
# --- 10. Expanded built-in middleware classes ---
from adk_fluent.middleware import (
CircuitBreakerMiddleware,
DedupMiddleware,
FallbackModelMiddleware,
MetricsMiddleware,
ModelCacheMiddleware,
RetryMiddleware,
TimeoutMiddleware,
TraceMiddleware,
_SampledMiddleware,
)
# CircuitBreakerMiddleware: constructor arguments are kept on private
# attributes and the failure map starts empty.
circuit_breaker = CircuitBreakerMiddleware(threshold=5, reset_after=60)
assert (circuit_breaker._threshold, circuit_breaker._reset_after) == (5, 60)
assert circuit_breaker._failures == {}
assert all(hasattr(circuit_breaker, hook) for hook in ("before_model", "after_model"))
assert isinstance(circuit_breaker, Middleware)

# TimeoutMiddleware: stores the limit and starts with no deadlines.
timeout_mw = TimeoutMiddleware(seconds=30)
assert timeout_mw._seconds == 30
assert timeout_mw._deadlines == {}
assert all(hasattr(timeout_mw, hook) for hook in ("before_agent", "before_model"))
assert isinstance(timeout_mw, Middleware)

# ModelCacheMiddleware: stores the TTL and starts with an empty cache.
cache_mw = ModelCacheMiddleware(ttl=300, key_fn=None)
assert cache_mw._ttl == 300
assert cache_mw._cache == {}
assert all(hasattr(cache_mw, hook) for hook in ("before_model", "after_model"))
assert isinstance(cache_mw, Middleware)

# FallbackModelMiddleware: stores the fallback model name and exposes on_model_error.
fallback_mw = FallbackModelMiddleware(fallback_model="gemini-2.0-flash")
assert fallback_mw._fallback == "gemini-2.0-flash"
assert hasattr(fallback_mw, "on_model_error")
assert isinstance(fallback_mw, Middleware)

# DedupMiddleware: stores the window size and starts with no recent entries.
dedup_mw = DedupMiddleware(window=10)
assert dedup_mw._window == 10
assert dedup_mw._recent == []
assert hasattr(dedup_mw, "before_model")
assert isinstance(dedup_mw, Middleware)

# _SampledMiddleware: wraps the first entry of the M.log() stack and keeps
# both the rate and the exact inner object.
inner_log = M.log().to_stack()[0]
sampled_mw = _SampledMiddleware(rate=0.5, inner=inner_log)
assert sampled_mw._rate == 0.5 and sampled_mw._inner is inner_log

# TraceMiddleware: constructed without an exporter.
trace_mw = TraceMiddleware(exporter=None)
assert all(hasattr(trace_mw, hook) for hook in ("before_agent", "after_agent"))
assert isinstance(trace_mw, Middleware)

# MetricsMiddleware: constructed without a collector; counts start empty.
metrics_mw = MetricsMiddleware(collector=None)
assert metrics_mw._collector is None
assert metrics_mw._counts == {}
assert all(hasattr(metrics_mw, hook) for hook in ("after_agent", "on_model_error"))
assert isinstance(metrics_mw, Middleware)
# --- 11. Production resilience stack with new middleware ---
# High-availability API gateway with full observability + error handling
api_agent = (
    Agent("api_gateway")
    .model("gemini-2.5-flash")
    .instruct("Route API requests.")
)
resilient_pipeline = api_agent.middleware(
    # Retry with circuit breaker to prevent cascade failures
    M.retry(3)
    | M.circuit_breaker(threshold=5, reset_after=60)
    # Timeout to prevent hanging requests
    | M.timeout(seconds=30)
    # Cache to reduce redundant LLM calls
    | M.cache(ttl=60)
    # Fallback model if primary fails
    | M.fallback_model("gemini-2.0-flash")
    # Dedup to suppress duplicate requests
    | M.dedup(window=10)
    # Observability
    | M.log()
    | M.cost()
    | M.latency()
)

# Nine factories were chained, so nine entries are registered.
assert len(resilient_pipeline._middlewares) == 9

# Verify flattened stack has correct types: the six resilience entries
# come first, in the order they were chained.
flat = list(resilient_pipeline._middlewares)
assert all(
    isinstance(entry, expected)
    for entry, expected in zip(
        flat,
        (
            RetryMiddleware,
            CircuitBreakerMiddleware,
            TimeoutMiddleware,
            ModelCacheMiddleware,
            FallbackModelMiddleware,
            DedupMiddleware,
        ),
    )
)

print("All built-in middleware assertions passed!")
graph TD
n1[["writer_then_reviewer (sequence)"]]
n2["writer"]
n3["reviewer"]
n2 --> n3