MiddlewareSchema: Typed Middleware State Declarations

Demonstrates MiddlewareSchema for declaring middleware state dependencies, enabling the contract checker to validate middleware reads/writes at compile time.

Key concepts:

  • MiddlewareSchema: base class for typed middleware declarations

  • Reads(scope=…): field read from state before execution

  • Writes(scope=…): field written to state after execution

  • reads_keys() / writes_keys(): introspect declared dependencies

  • schema attribute: bind a MiddlewareSchema to a middleware class

  • agents attribute: scope middleware to specific pipeline agents

  • Contract checker Pass 14: validates scoped middleware at build time

  • M.when(PredicateSchema, mw): state-aware conditional middleware

Tip

What you’ll learn How to compose agents into a sequential pipeline.

Source: 64_middleware_schema.py

from typing import Annotated

from adk_fluent._middleware_schema import MiddlewareSchema
from adk_fluent._schema_base import Reads, Writes

# --- 1. Declaring a MiddlewareSchema ---


class BudgetState(MiddlewareSchema):
    """Declares that a middleware reads a token budget and writes usage."""

    token_budget: Annotated[int, Reads(scope="app")]
    tokens_used: Annotated[int, Writes(scope="temp")]


# Introspect reads and writes
assert BudgetState.reads_keys() == frozenset({"app:token_budget"})
assert BudgetState.writes_keys() == frozenset({"temp:tokens_used"})

# --- 2. Mixed reads and writes ---


class EnrichmentState(MiddlewareSchema):
    """Reads a config key, writes an enriched result."""

    api_key: Annotated[str, Reads(scope="app")]
    enriched_data: Annotated[str, Writes()]  # default scope is "session"


assert EnrichmentState.reads_keys() == frozenset({"app:api_key"})
assert EnrichmentState.writes_keys() == frozenset({"enriched_data"})

# --- 3. Session-scoped reads (default scope) ---


class AuditState(MiddlewareSchema):
    user_id: Annotated[str, Reads()]  # scope defaults to "session"
    request_context: Annotated[str, Reads()]


assert AuditState.reads_keys() == frozenset({"user_id", "request_context"})
assert AuditState.writes_keys() == frozenset()  # no writes

# --- 4. Empty schema ---


class NoOpState(MiddlewareSchema):
    pass


assert NoOpState.reads_keys() == frozenset()
assert NoOpState.writes_keys() == frozenset()

# --- 5. Binding schema to middleware class ---


class BudgetEnforcer:
    """Middleware that enforces token budgets.

    The `schema` attribute declares state dependencies.
    The `agents` attribute scopes it to specific agents.
    """

    agents = "writer"
    schema = BudgetState

    async def before_agent(self, ctx, agent_name):
        # In production: read token_budget from state, check remaining
        pass

    async def after_model(self, ctx, response):
        # In production: update tokens_used in state
        pass


# Schema is accessible on the middleware instance
enforcer = BudgetEnforcer()
assert enforcer.schema is BudgetState
assert enforcer.agents == "writer"
assert enforcer.schema.reads_keys() == frozenset({"app:token_budget"})

# --- 6. Schema survives M.scope() and M.when() wrapping ---
from adk_fluent._middleware import M

scoped = M.scope("writer", enforcer)
wrapped = scoped.to_stack()[0]
assert wrapped.schema is BudgetState  # forwarded from inner via __getattr__

conditional = M.when("pipeline", enforcer)
cond_wrapped = conditional.to_stack()[0]
assert cond_wrapped.schema is BudgetState  # forwarded by _ConditionalMiddleware

# --- 7. Contract checker Pass 14 validation ---
from adk_fluent.testing.contracts import check_contracts

# Helper to create minimal IR nodes for contract checking
from types import SimpleNamespace

from adk_fluent._ir_generated import SequenceNode


def agent_node(name, output_key=None):
    return SimpleNamespace(
        name=name,
        output_key=output_key,
        tool_schema=None,
        callback_schema=None,
        prompt_schema=None,
        writes_keys=frozenset(),
        reads_keys=frozenset(),
        include_contents="default",
        instruction="",
        context_spec=None,
        produces_type=None,
        consumes_type=None,
        rules=(),
        predicate=None,
    )


# 7a. Satisfied reads: middleware reads "result", producer writes "result"
class NeedsResult(MiddlewareSchema):
    result: Annotated[str, Reads()]


class ReaderMW:
    agents = "reviewer"
    schema = NeedsResult


producer = agent_node("writer", output_key="result")
consumer = agent_node("reviewer")
seq = SequenceNode(
    name="test_pipeline",
    children=(producer, consumer),
    middlewares=(ReaderMW(),),
)

issues = check_contracts(seq)
mw_issues = [i for i in issues if isinstance(i, dict) and "MiddlewareSchema" in i.get("message", "")]
assert len(mw_issues) == 0  # reads satisfied -- no warnings


# 7b. Unsatisfied reads: middleware reads "missing_key", nobody produces it
class NeedsMissing(MiddlewareSchema):
    missing_key: Annotated[str, Reads()]


class MissingMW:
    agents = "reviewer"
    schema = NeedsMissing


seq_missing = SequenceNode(
    name="test_pipeline",
    children=(producer, consumer),
    middlewares=(MissingMW(),),
)

issues_missing = check_contracts(seq_missing)
mw_issues_missing = [i for i in issues_missing if isinstance(i, dict) and "MiddlewareSchema" in i.get("message", "")]
assert len(mw_issues_missing) == 1
assert "missing_key" in mw_issues_missing[0]["message"]


# 7c. Unscoped middleware: skipped by contract checker
class GlobalMW:
    schema = NeedsMissing  # has schema but no agents scope


seq_global = SequenceNode(
    name="test_pipeline",
    children=(producer, consumer),
    middlewares=(GlobalMW(),),
)

issues_global = check_contracts(seq_global)
mw_issues_global = [i for i in issues_global if isinstance(i, dict) and "MiddlewareSchema" in i.get("message", "")]
assert len(mw_issues_global) == 0  # unscoped middleware not validated


# 7d. Middleware writes promoted to downstream
class WriterSchema(MiddlewareSchema):
    enriched: Annotated[str, Writes()]


class EnricherMW:
    agents = "enricher"
    schema = WriterSchema


class ReaderSchema(MiddlewareSchema):
    enriched: Annotated[str, Reads()]


class DownstreamReaderMW:
    agents = "consumer"
    schema = ReaderSchema


enricher = agent_node("enricher")
downstream = agent_node("consumer")
seq_writes = SequenceNode(
    name="test_pipeline",
    children=(enricher, downstream),
    middlewares=(EnricherMW(), DownstreamReaderMW()),
)

issues_writes = check_contracts(seq_writes)
mw_issues_writes = [i for i in issues_writes if isinstance(i, dict) and "MiddlewareSchema" in i.get("message", "")]
assert len(mw_issues_writes) == 0  # writes promoted -- reads satisfied

# --- 8. Repr ---
m = BudgetState()
r = repr(m)
assert "BudgetState" in r
assert "token_budget" in r
assert "tokens_used" in r

# --- 9. PredicateSchema with M.when() ---
from adk_fluent._predicate_schema import PredicateSchema


class PremiumOnly(PredicateSchema):
    """Only fire middleware for premium users."""

    user_tier: Annotated[str, Reads(scope="user")]

    @staticmethod
    def evaluate(user_tier):
        return user_tier == "premium"


# Reads keys
assert PremiumOnly.reads_keys() == frozenset({"user:user_tier"})

# Can be used with M.when() for state-aware conditional middleware
premium_mw = M.when(PremiumOnly, M.scope("writer", M.cost()))
assert len(premium_mw) == 1
# Condition is deferred to invocation time -- wraps in _ConditionalMiddleware
inner = premium_mw.to_stack()[0]
assert callable(getattr(inner, "after_model", None))

print("All MiddlewareSchema and contract checking assertions passed!")