1use std::sync::Arc;
17use std::time::Duration;
18
19use async_trait::async_trait;
20use dashmap::DashMap;
21
22use crate::error::ToolError;
23
24use super::ToolFunction;
25
/// Declarative execution policy that [`PolicyTool`] applies around an inner tool.
///
/// Every knob is off by default, so `ToolPolicy::default()` (or [`ToolPolicy::new`])
/// is a no-op policy; see [`ToolPolicy::is_noop`].
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ToolPolicy {
    /// Upper bound on how long a single call may run. Exceeding it makes the
    /// wrapped call fail with `ToolError::Timeout`; `None` means unbounded.
    pub timeout: Option<Duration>,
    /// Cache successful results keyed by the tool name plus canonicalized arguments.
    pub cache: bool,
    /// Mark the tool as requiring confirmation (surfaced via `requires_confirmation`).
    /// Enforcement is up to the caller — presumably the agent loop; confirm there.
    pub confirm: bool,
    /// Optional message to show when asking for confirmation.
    pub confirm_message: Option<String>,
}

impl ToolPolicy {
    /// Creates an empty (no-op) policy; equivalent to [`ToolPolicy::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns `true` when the policy enables neither a timeout, caching nor
    /// confirmation, i.e. wrapping a tool with it would change nothing.
    ///
    /// `confirm_message` alone does not count: it is meaningless without `confirm`.
    #[must_use]
    pub fn is_noop(&self) -> bool {
        self.timeout.is_none() && !self.cache && !self.confirm
    }

    /// Sets the per-call timeout, replacing any previous value.
    #[must_use]
    pub fn with_timeout(mut self, d: Duration) -> Self {
        self.timeout = Some(d);
        self
    }

    /// Enables result caching.
    #[must_use]
    pub fn with_cache(mut self) -> Self {
        self.cache = true;
        self
    }

    /// Requires confirmation, setting (or clearing, when `None`) the confirmation message.
    #[must_use]
    pub fn with_confirm(mut self, message: Option<String>) -> Self {
        self.confirm = true;
        self.confirm_message = message;
        self
    }

    /// Overlays `other` onto `self` and returns the combined policy.
    ///
    /// - `timeout`: `other`'s value wins when it is set; otherwise ours is kept.
    /// - `cache`: enabled if either side enables it.
    /// - `confirm`: sticky — once either side requires confirmation, the result does.
    ///   `other`'s message replaces ours only when `other` both requires
    ///   confirmation and carries a message.
    #[must_use]
    pub fn merge(mut self, other: &ToolPolicy) -> Self {
        if let Some(timeout) = other.timeout {
            self.timeout = Some(timeout);
        }
        self.cache |= other.cache;
        if other.confirm {
            self.confirm = true;
            if let Some(message) = &other.confirm_message {
                self.confirm_message = Some(message.clone());
            }
        }
        self
    }
}
86
87pub struct PolicyTool {
89 inner: Arc<dyn ToolFunction>,
90 policy: ToolPolicy,
91 cache: Arc<DashMap<String, serde_json::Value>>,
92}
93
94impl PolicyTool {
95 pub fn new(inner: Arc<dyn ToolFunction>, policy: ToolPolicy) -> Self {
97 Self {
98 inner,
99 policy,
100 cache: Arc::new(DashMap::new()),
101 }
102 }
103
104 pub fn wrap(inner: Arc<dyn ToolFunction>, policy: ToolPolicy) -> Arc<dyn ToolFunction> {
106 if policy.is_noop() {
107 inner
108 } else {
109 Arc::new(Self::new(inner, policy))
110 }
111 }
112
113 pub fn requires_confirmation(&self) -> bool {
115 self.policy.confirm
116 }
117
118 pub fn policy(&self) -> &ToolPolicy {
120 &self.policy
121 }
122
123 fn cache_key(&self, args: &serde_json::Value) -> String {
125 format!("{}\u{1}{}", self.inner.name(), canonical_json(args))
126 }
127}
128
129fn canonical_json(value: &serde_json::Value) -> String {
131 match value {
132 serde_json::Value::Object(map) => {
133 let mut keys: Vec<&String> = map.keys().collect();
134 keys.sort();
135 let mut out = String::from("{");
136 for (i, k) in keys.iter().enumerate() {
137 if i > 0 {
138 out.push(',');
139 }
140 out.push_str(&serde_json::to_string(k).unwrap_or_default());
141 out.push(':');
142 out.push_str(&canonical_json(&map[*k]));
143 }
144 out.push('}');
145 out
146 }
147 serde_json::Value::Array(items) => {
148 let mut out = String::from("[");
149 for (i, item) in items.iter().enumerate() {
150 if i > 0 {
151 out.push(',');
152 }
153 out.push_str(&canonical_json(item));
154 }
155 out.push(']');
156 out
157 }
158 other => serde_json::to_string(other).unwrap_or_default(),
159 }
160}
161
162#[async_trait]
163impl ToolFunction for PolicyTool {
164 fn name(&self) -> &str {
165 self.inner.name()
166 }
167
168 fn description(&self) -> &str {
169 self.inner.description()
170 }
171
172 fn parameters(&self) -> Option<serde_json::Value> {
173 self.inner.parameters()
174 }
175
176 fn requires_confirmation(&self) -> bool {
177 self.policy.confirm || self.inner.requires_confirmation()
181 }
182
183 fn confirmation_message(&self) -> Option<&str> {
184 self.policy
185 .confirm_message
186 .as_deref()
187 .or_else(|| self.inner.confirmation_message())
188 }
189
190 async fn call(&self, args: serde_json::Value) -> Result<serde_json::Value, ToolError> {
191 let key = if self.policy.cache {
193 let key = self.cache_key(&args);
194 if let Some(hit) = self.cache.get(&key) {
195 return Ok(hit.clone());
196 }
197 Some(key)
198 } else {
199 None
200 };
201
202 let result = if let Some(timeout) = self.policy.timeout {
204 match tokio::time::timeout(timeout, self.inner.call(args)).await {
205 Ok(r) => r,
206 Err(_elapsed) => Err(ToolError::Timeout(timeout)),
207 }
208 } else {
209 self.inner.call(args).await
210 };
211
212 if let (Some(key), Ok(value)) = (key, &result) {
214 self.cache.insert(key, value.clone());
215 }
216
217 result
218 }
219}
220
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tool::SimpleTool;
    use serde_json::json;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// True when both handles point at the same allocation (vtable metadata ignored).
    fn same_allocation(a: &Arc<dyn ToolFunction>, b: &Arc<dyn ToolFunction>) -> bool {
        Arc::as_ptr(a).cast::<()>() == Arc::as_ptr(b).cast::<()>()
    }

    #[tokio::test]
    async fn timeout_policy_returns_timeout_error() {
        let slow: Arc<dyn ToolFunction> = Arc::new(SimpleTool::new(
            "slow",
            "sleeps too long",
            None,
            |_| async move {
                tokio::time::sleep(Duration::from_secs(3600)).await;
                Ok(json!({"ok": true}))
            },
        ));
        let tool = PolicyTool::new(
            slow,
            ToolPolicy::new().with_timeout(Duration::from_millis(50)),
        );

        match tool.call(json!({})).await {
            Err(ToolError::Timeout(d)) => assert_eq!(d, Duration::from_millis(50)),
            other => panic!("expected Timeout, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn under_timeout_succeeds() {
        let fast: Arc<dyn ToolFunction> = Arc::new(SimpleTool::new(
            "fast",
            "returns quickly",
            None,
            |_| async move { Ok(json!({"ok": true})) },
        ));
        let tool = PolicyTool::new(fast, ToolPolicy::new().with_timeout(Duration::from_secs(5)));
        let out = tool.call(json!({})).await.unwrap();
        assert_eq!(out["ok"], true);
    }

    #[tokio::test]
    async fn cache_returns_same_value_and_runs_once() {
        let counter = Arc::new(AtomicU32::new(0));
        let c = counter.clone();
        let counting: Arc<dyn ToolFunction> = Arc::new(SimpleTool::new(
            "count",
            "increments a counter",
            None,
            move |_| {
                let c = c.clone();
                async move {
                    let n = c.fetch_add(1, Ordering::SeqCst) + 1;
                    Ok(json!({"n": n}))
                }
            },
        ));
        let tool = PolicyTool::new(counting, ToolPolicy::new().with_cache());

        let first = tool.call(json!({"x": 1})).await.unwrap();
        let second = tool.call(json!({"x": 1})).await.unwrap();
        assert_eq!(first, second);
        assert_eq!(first["n"], 1);
        assert_eq!(counter.load(Ordering::SeqCst), 1);

        // Different arguments miss the cache and run the tool again.
        let third = tool.call(json!({"x": 2})).await.unwrap();
        assert_eq!(third["n"], 2);
        assert_eq!(counter.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn cache_key_is_order_independent() {
        let counter = Arc::new(AtomicU32::new(0));
        let c = counter.clone();
        let counting: Arc<dyn ToolFunction> = Arc::new(SimpleTool::new(
            "count2",
            "increments a counter",
            None,
            move |_| {
                let c = c.clone();
                async move {
                    c.fetch_add(1, Ordering::SeqCst);
                    Ok(json!({"ok": true}))
                }
            },
        ));
        let tool = PolicyTool::new(counting, ToolPolicy::new().with_cache());

        tool.call(json!({"a": 1, "b": 2})).await.unwrap();
        tool.call(json!({"b": 2, "a": 1})).await.unwrap();
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn errors_are_not_cached() {
        let counter = Arc::new(AtomicU32::new(0));
        let c = counter.clone();
        let failing: Arc<dyn ToolFunction> =
            Arc::new(SimpleTool::new("fail", "always fails", None, move |_| {
                let c = c.clone();
                async move {
                    c.fetch_add(1, Ordering::SeqCst);
                    Err(ToolError::ExecutionFailed("boom".into()))
                }
            }));
        let tool = PolicyTool::new(failing, ToolPolicy::new().with_cache());

        assert!(tool.call(json!({})).await.is_err());
        assert!(tool.call(json!({})).await.is_err());
        assert_eq!(counter.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn wrap_skips_noop_policy() {
        let inner: Arc<dyn ToolFunction> = Arc::new(SimpleTool::new(
            "plain",
            "plain tool",
            None,
            |_| async move { Ok(json!({})) },
        ));

        // A no-op policy must hand back the very same tool, not a wrapper.
        let wrapped = PolicyTool::wrap(inner.clone(), ToolPolicy::new());
        assert!(same_allocation(&wrapped, &inner));
        assert_eq!(wrapped.name(), "plain");

        // A non-trivial policy yields a distinct wrapper that still forwards metadata.
        let confirmed = PolicyTool::wrap(inner.clone(), ToolPolicy::new().with_confirm(None));
        assert!(!same_allocation(&confirmed, &inner));
        assert_eq!(confirmed.name(), "plain");
        assert!(confirmed.requires_confirmation());
    }

    #[test]
    fn confirm_flag_is_recorded() {
        let inner: Arc<dyn ToolFunction> = Arc::new(SimpleTool::new(
            "danger",
            "dangerous",
            None,
            |_| async move { Ok(json!({})) },
        ));
        let tool = PolicyTool::new(
            inner,
            ToolPolicy::new().with_confirm(Some("are you sure?".into())),
        );
        assert!(tool.requires_confirmation());
        assert_eq!(
            tool.policy().confirm_message.as_deref(),
            Some("are you sure?")
        );
    }

    #[test]
    fn canonical_json_sorts_nested_keys() {
        let value = json!({"b": {"d": 1, "c": 2}, "a": [3, {"z": 0, "y": 1}]});
        assert_eq!(
            canonical_json(&value),
            r#"{"a":[3,{"y":1,"z":0}],"b":{"c":2,"d":1}}"#
        );
        assert_eq!(canonical_json(&json!([])), "[]");
        assert_eq!(canonical_json(&json!({})), "{}");
    }

    #[test]
    fn merge_overrides_timeout_and_keeps_confirm_message() {
        let base = ToolPolicy::new()
            .with_timeout(Duration::from_secs(1))
            .with_confirm(Some("base".into()));
        let overlay = ToolPolicy::new()
            .with_timeout(Duration::from_secs(2))
            .with_cache()
            .with_confirm(None);
        let merged = base.merge(&overlay);
        assert_eq!(merged.timeout, Some(Duration::from_secs(2)));
        assert!(merged.cache);
        assert!(merged.confirm);
        assert_eq!(merged.confirm_message.as_deref(), Some("base"));
    }
}