gemini_adk_rs/tool/
policy.rs

1//! Per-tool execution policy — timeout, caching, and confirmation.
2//!
3//! A [`ToolPolicy`] describes optional runtime enforcement attached to an
4//! individual tool. [`PolicyTool`] is a [`ToolFunction`] decorator that wraps
5//! an inner tool and enforces its policy on every call:
6//!
7//! - **timeout**: the inner call is raced against [`tokio::time::timeout`];
//!   once the deadline elapses, [`ToolError::Timeout`] is returned and the inner
//!   future is dropped.
9//! - **cache**: successful results are memoized in a concurrent map keyed by
10//!   `(tool name, canonical-JSON args)`. Repeat calls with identical args return
11//!   the cached value without re-invoking the inner tool. Errors are not cached.
12//! - **confirm**: a declarative flag recorded on the policy and surfaced via
13//!   [`PolicyTool::requires_confirmation`]. The flag is never silently dropped;
14//!   full interactive confirmation wiring is handled by the session runtime.
15
16use std::sync::Arc;
17use std::time::Duration;
18
19use async_trait::async_trait;
20use dashmap::DashMap;
21
22use crate::error::ToolError;
23
24use super::ToolFunction;
25
/// Optional, per-tool runtime enforcement settings.
///
/// Every field defaults to "off" (`None` / `false`), so `ToolPolicy::default()`
/// enforces nothing.
#[derive(Debug, Clone, Default)]
pub struct ToolPolicy {
    /// Upper bound on how long a single call may run; `None` means unbounded.
    pub timeout: Option<Duration>,
    /// When `true`, successful results are memoized keyed on `(name, canonical args)`.
    pub cache: bool,
    /// When `true`, the tool is flagged as needing user confirmation before it runs.
    pub confirm: bool,
    /// Optional text to show the user when confirmation is requested.
    pub confirm_message: Option<String>,
}

impl ToolPolicy {
    /// Returns a policy with every enforcement disabled (identical to `Default`).
    pub fn new() -> Self {
        Default::default()
    }

    /// Returns `true` when none of timeout, cache, or confirm is enabled.
    ///
    /// `confirm_message` is not consulted: without `confirm` it has no effect.
    /// [`PolicyTool::wrap`] uses this to skip wrapping tools that need no enforcement.
    pub fn is_noop(&self) -> bool {
        let enforces_something = self.timeout.is_some() || self.cache || self.confirm;
        !enforces_something
    }

    /// Builder: bound each call by `d`.
    pub fn with_timeout(self, d: Duration) -> Self {
        Self {
            timeout: Some(d),
            ..self
        }
    }

    /// Builder: memoize successful results.
    pub fn with_cache(self) -> Self {
        Self {
            cache: true,
            ..self
        }
    }

    /// Builder: require confirmation, replacing any previous message with `message`.
    pub fn with_confirm(self, message: Option<String>) -> Self {
        Self {
            confirm: true,
            confirm_message: message,
            ..self
        }
    }

    /// Combine two policies, letting `other` win wherever it sets something.
    ///
    /// - `timeout`: `other`'s value if it has one, otherwise `self`'s.
    /// - `cache`: enabled if either side enables it.
    /// - `confirm`: enabled if either side enables it. When `other` enables it and
    ///   carries a message, that message replaces `self`'s. `other.confirm_message`
    ///   is ignored when `other.confirm` is `false`.
    pub fn merge(self, other: &ToolPolicy) -> Self {
        let ToolPolicy {
            timeout,
            cache,
            confirm,
            confirm_message,
        } = self;

        let (confirm, confirm_message) = if other.confirm {
            (true, other.confirm_message.clone().or(confirm_message))
        } else {
            (confirm, confirm_message)
        };

        Self {
            timeout: other.timeout.or(timeout),
            cache: cache || other.cache,
            confirm,
            confirm_message,
        }
    }
}
86
87/// A [`ToolFunction`] decorator that enforces a [`ToolPolicy`].
88pub struct PolicyTool {
89    inner: Arc<dyn ToolFunction>,
90    policy: ToolPolicy,
91    cache: Arc<DashMap<String, serde_json::Value>>,
92}
93
94impl PolicyTool {
95    /// Wrap `inner` with the given `policy`.
96    pub fn new(inner: Arc<dyn ToolFunction>, policy: ToolPolicy) -> Self {
97        Self {
98            inner,
99            policy,
100            cache: Arc::new(DashMap::new()),
101        }
102    }
103
104    /// Wrap `inner` only if the policy enforces something; otherwise return `inner`.
105    pub fn wrap(inner: Arc<dyn ToolFunction>, policy: ToolPolicy) -> Arc<dyn ToolFunction> {
106        if policy.is_noop() {
107            inner
108        } else {
109            Arc::new(Self::new(inner, policy))
110        }
111    }
112
113    /// Whether this tool requires user confirmation before execution.
114    pub fn requires_confirmation(&self) -> bool {
115        self.policy.confirm
116    }
117
118    /// The policy attached to this tool.
119    pub fn policy(&self) -> &ToolPolicy {
120        &self.policy
121    }
122
123    /// Build a stable cache key from the tool name and canonical-JSON args.
124    fn cache_key(&self, args: &serde_json::Value) -> String {
125        format!("{}\u{1}{}", self.inner.name(), canonical_json(args))
126    }
127}
128
129/// Render a JSON value canonically so equal values produce equal strings.
130fn canonical_json(value: &serde_json::Value) -> String {
131    match value {
132        serde_json::Value::Object(map) => {
133            let mut keys: Vec<&String> = map.keys().collect();
134            keys.sort();
135            let mut out = String::from("{");
136            for (i, k) in keys.iter().enumerate() {
137                if i > 0 {
138                    out.push(',');
139                }
140                out.push_str(&serde_json::to_string(k).unwrap_or_default());
141                out.push(':');
142                out.push_str(&canonical_json(&map[*k]));
143            }
144            out.push('}');
145            out
146        }
147        serde_json::Value::Array(items) => {
148            let mut out = String::from("[");
149            for (i, item) in items.iter().enumerate() {
150                if i > 0 {
151                    out.push(',');
152                }
153                out.push_str(&canonical_json(item));
154            }
155            out.push(']');
156            out
157        }
158        other => serde_json::to_string(other).unwrap_or_default(),
159    }
160}
161
162#[async_trait]
163impl ToolFunction for PolicyTool {
164    fn name(&self) -> &str {
165        self.inner.name()
166    }
167
168    fn description(&self) -> &str {
169        self.inner.description()
170    }
171
172    fn parameters(&self) -> Option<serde_json::Value> {
173        self.inner.parameters()
174    }
175
176    fn requires_confirmation(&self) -> bool {
177        // Propagate through nested wrappers so modifier order can't bypass a
178        // gate: e.g. `T::cached(T::confirm(..))` wraps a confirm PolicyTool in
179        // a cache PolicyTool whose own policy has `confirm == false`.
180        self.policy.confirm || self.inner.requires_confirmation()
181    }
182
183    fn confirmation_message(&self) -> Option<&str> {
184        self.policy
185            .confirm_message
186            .as_deref()
187            .or_else(|| self.inner.confirmation_message())
188    }
189
190    async fn call(&self, args: serde_json::Value) -> Result<serde_json::Value, ToolError> {
191        // Cache lookup (only for cacheable tools).
192        let key = if self.policy.cache {
193            let key = self.cache_key(&args);
194            if let Some(hit) = self.cache.get(&key) {
195                return Ok(hit.clone());
196            }
197            Some(key)
198        } else {
199            None
200        };
201
202        // Execute with optional timeout enforcement.
203        let result = if let Some(timeout) = self.policy.timeout {
204            match tokio::time::timeout(timeout, self.inner.call(args)).await {
205                Ok(r) => r,
206                Err(_elapsed) => Err(ToolError::Timeout(timeout)),
207            }
208        } else {
209            self.inner.call(args).await
210        };
211
212        // Memoize successful results.
213        if let (Some(key), Ok(value)) = (key, &result) {
214            self.cache.insert(key, value.clone());
215        }
216
217        result
218    }
219}
220
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tool::SimpleTool;
    use serde_json::json;
    use std::sync::atomic::{AtomicU32, Ordering};

    #[tokio::test]
    async fn timeout_policy_returns_timeout_error() {
        let slow: Arc<dyn ToolFunction> = Arc::new(SimpleTool::new(
            "slow",
            "sleeps too long",
            None,
            |_| async move {
                tokio::time::sleep(Duration::from_secs(3600)).await;
                Ok(json!({"ok": true}))
            },
        ));
        let tool = PolicyTool::new(
            slow,
            ToolPolicy::new().with_timeout(Duration::from_millis(50)),
        );

        match tool.call(json!({})).await {
            Err(ToolError::Timeout(d)) => assert_eq!(d, Duration::from_millis(50)),
            other => panic!("expected Timeout, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn under_timeout_succeeds() {
        let fast: Arc<dyn ToolFunction> = Arc::new(SimpleTool::new(
            "fast",
            "returns quickly",
            None,
            |_| async move { Ok(json!({"ok": true})) },
        ));
        let tool = PolicyTool::new(fast, ToolPolicy::new().with_timeout(Duration::from_secs(5)));
        let out = tool.call(json!({})).await.unwrap();
        assert_eq!(out["ok"], true);
    }

    #[tokio::test]
    async fn cache_returns_same_value_and_runs_once() {
        let counter = Arc::new(AtomicU32::new(0));
        let c = counter.clone();
        let counting: Arc<dyn ToolFunction> = Arc::new(SimpleTool::new(
            "count",
            "increments a counter",
            None,
            move |_| {
                let c = c.clone();
                async move {
                    let n = c.fetch_add(1, Ordering::SeqCst) + 1;
                    Ok(json!({"n": n}))
                }
            },
        ));
        let tool = PolicyTool::new(counting, ToolPolicy::new().with_cache());

        let first = tool.call(json!({"x": 1})).await.unwrap();
        let second = tool.call(json!({"x": 1})).await.unwrap();
        assert_eq!(first, second);
        assert_eq!(first["n"], 1);
        assert_eq!(counter.load(Ordering::SeqCst), 1);

        // Different args -> cache miss, counter advances.
        let third = tool.call(json!({"x": 2})).await.unwrap();
        assert_eq!(third["n"], 2);
        assert_eq!(counter.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn cache_key_is_order_independent() {
        let counter = Arc::new(AtomicU32::new(0));
        let c = counter.clone();
        let counting: Arc<dyn ToolFunction> = Arc::new(SimpleTool::new(
            "count2",
            "increments a counter",
            None,
            move |_| {
                let c = c.clone();
                async move {
                    c.fetch_add(1, Ordering::SeqCst);
                    Ok(json!({"ok": true}))
                }
            },
        ));
        let tool = PolicyTool::new(counting, ToolPolicy::new().with_cache());

        tool.call(json!({"a": 1, "b": 2})).await.unwrap();
        // Same logical args, different key order -> should hit cache.
        tool.call(json!({"b": 2, "a": 1})).await.unwrap();
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn errors_are_not_cached() {
        let counter = Arc::new(AtomicU32::new(0));
        let c = counter.clone();
        let failing: Arc<dyn ToolFunction> =
            Arc::new(SimpleTool::new("fail", "always fails", None, move |_| {
                let c = c.clone();
                async move {
                    c.fetch_add(1, Ordering::SeqCst);
                    Err(ToolError::ExecutionFailed("boom".into()))
                }
            }));
        let tool = PolicyTool::new(failing, ToolPolicy::new().with_cache());

        assert!(tool.call(json!({})).await.is_err());
        assert!(tool.call(json!({})).await.is_err());
        assert_eq!(counter.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn wrap_skips_noop_policy() {
        let inner: Arc<dyn ToolFunction> = Arc::new(SimpleTool::new(
            "plain",
            "plain tool",
            None,
            |_| async move { Ok(json!({})) },
        ));
        // A no-op policy must hand back the very same Arc, not a new wrapper;
        // comparing names alone cannot distinguish the two outcomes.
        let wrapped = PolicyTool::wrap(inner.clone(), ToolPolicy::new());
        assert!(Arc::ptr_eq(&wrapped, &inner));
        assert_eq!(wrapped.name(), "plain");

        // A confirm-only policy still wraps, so the flag is preserved.
        let confirmed = PolicyTool::wrap(inner.clone(), ToolPolicy::new().with_confirm(None));
        assert!(!Arc::ptr_eq(&confirmed, &inner));
        assert_eq!(confirmed.name(), "plain");
        assert!(confirmed.requires_confirmation());
    }

    #[test]
    fn confirm_flag_is_recorded() {
        let inner: Arc<dyn ToolFunction> = Arc::new(SimpleTool::new(
            "danger",
            "dangerous",
            None,
            |_| async move { Ok(json!({})) },
        ));
        let tool = PolicyTool::new(
            inner,
            ToolPolicy::new().with_confirm(Some("are you sure?".into())),
        );
        assert!(tool.requires_confirmation());
        assert_eq!(
            tool.policy().confirm_message.as_deref(),
            Some("are you sure?")
        );
    }

    #[test]
    fn confirmation_propagates_through_nested_wrappers() {
        let inner: Arc<dyn ToolFunction> = Arc::new(SimpleTool::new(
            "danger",
            "dangerous",
            None,
            |_| async move { Ok(json!({})) },
        ));
        let gated = PolicyTool::wrap(
            inner,
            ToolPolicy::new().with_confirm(Some("are you sure?".into())),
        );
        // The outer wrapper only caches; its own policy has `confirm == false`,
        // so the gate and message must come from the nested wrapper.
        let outer = PolicyTool::wrap(gated, ToolPolicy::new().with_cache());
        assert!(outer.requires_confirmation());
        assert_eq!(outer.confirmation_message(), Some("are you sure?"));
    }

    #[test]
    fn merge_prefers_other_where_set() {
        let base = ToolPolicy::new()
            .with_timeout(Duration::from_secs(1))
            .with_confirm(Some("base".into()));
        // `other` sets a timeout and confirm (without a message), nothing else.
        let other = ToolPolicy::new()
            .with_timeout(Duration::from_secs(9))
            .with_confirm(None);
        let merged = base.clone().merge(&other);
        assert_eq!(merged.timeout, Some(Duration::from_secs(9)));
        assert!(merged.confirm);
        // `other` carries no message, so the base message survives.
        assert_eq!(merged.confirm_message.as_deref(), Some("base"));
        assert!(!merged.cache);

        // Unset fields in `other` leave `self` alone; `cache` is OR-ed in.
        let merged = base.merge(&ToolPolicy::new().with_cache());
        assert_eq!(merged.timeout, Some(Duration::from_secs(1)));
        assert!(merged.cache);
        assert_eq!(merged.confirm_message.as_deref(), Some("base"));
    }

    #[test]
    fn canonical_json_sorts_nested_keys() {
        let value = json!({"b": {"d": 1, "c": [2, {"f": 3, "e": 4}]}, "a": null});
        assert_eq!(
            canonical_json(&value),
            r#"{"a":null,"b":{"c":[2,{"e":4,"f":3}],"d":1}}"#
        );
        assert_eq!(canonical_json(&json!([])), "[]");
        assert_eq!(canonical_json(&json!({})), "{}");
        assert_eq!(canonical_json(&json!("x\"y")), r#""x\"y""#);
    }
}