gemini_adk_rs/
confirmation.rs

1//! Tool confirmation — user confirmation for sensitive tool calls.
2
3use std::future::Future;
4use std::sync::Arc;
5
6use async_trait::async_trait;
7use serde::{Deserialize, Serialize};
8
9/// Represents a user's confirmation decision for a tool call.
10#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct ToolConfirmation {
12    /// Optional hint text explaining what needs confirmation.
13    pub hint: Option<String>,
14    /// Whether the user confirmed the action.
15    pub confirmed: bool,
16    /// Optional payload with additional context.
17    pub payload: Option<serde_json::Value>,
18}
19
20impl ToolConfirmation {
21    /// Create a confirmed result.
22    pub fn confirmed() -> Self {
23        Self {
24            hint: None,
25            confirmed: true,
26            payload: None,
27        }
28    }
29
30    /// Create a denied result with a hint explaining why.
31    pub fn denied(hint: impl Into<String>) -> Self {
32        Self {
33            hint: Some(hint.into()),
34            confirmed: false,
35            payload: None,
36        }
37    }
38
39    /// Attach a payload to this confirmation.
40    pub fn with_payload(mut self, payload: serde_json::Value) -> Self {
41        self.payload = Some(payload);
42        self
43    }
44}
45
46/// A request for confirmation of a sensitive tool call, handed to a
47/// [`ConfirmationProvider`] before the tool executes.
48#[derive(Debug, Clone)]
49pub struct ConfirmationRequest {
50    /// The tool about to run.
51    pub tool_name: String,
52    /// The arguments the model supplied.
53    pub args: serde_json::Value,
54    /// Optional hint describing what needs confirming (from the tool's policy).
55    pub message: Option<String>,
56}
57
58/// Decides whether a confirmation-gated tool call may proceed.
59///
60/// Wire one into a [`ToolDispatcher`](crate::tool::ToolDispatcher) via
61/// [`with_confirmation_provider`](crate::tool::ToolDispatcher::with_confirmation_provider).
62/// When a tool reports [`ToolFunction::requires_confirmation`](crate::tool::ToolFunction::requires_confirmation)
63/// (e.g. one built with `T::confirm(..)`), the dispatcher consults the provider
64/// before executing and returns an error if it is denied. Enforcement is
65/// opt-in: with no provider configured, confirmation-gated tools run normally.
66#[async_trait]
67pub trait ConfirmationProvider: Send + Sync {
68    /// Resolve a confirmation decision for the given request.
69    async fn confirm(&self, request: ConfirmationRequest) -> ToolConfirmation;
70}
71
72/// Blanket impl so a plain async closure can act as a [`ConfirmationProvider`]:
73///
74/// ```rust,ignore
75/// dispatcher.set_confirmation_provider(std::sync::Arc::new(
76///     |req: ConfirmationRequest| async move {
77///         if req.tool_name == "delete_account" {
78///             ToolConfirmation::denied("blocked by policy")
79///         } else {
80///             ToolConfirmation::confirmed()
81///         }
82///     },
83/// ));
84/// ```
85#[async_trait]
86impl<F, Fut> ConfirmationProvider for F
87where
88    F: Fn(ConfirmationRequest) -> Fut + Send + Sync,
89    Fut: Future<Output = ToolConfirmation> + Send,
90{
91    async fn confirm(&self, request: ConfirmationRequest) -> ToolConfirmation {
92        self(request).await
93    }
94}
95
/// A [`ConfirmationProvider`] that approves or denies every request uniformly —
/// handy for tests and "deny-all" / "allow-all" defaults.
#[derive(Debug, Clone)]
pub struct StaticConfirmation {
    /// The fixed decision returned for every request.
    confirmed: bool,
    /// Hint text copied into every decision, if any.
    hint: Option<String>,
}
102
103impl StaticConfirmation {
104    /// Approve every confirmation request.
105    pub fn allow_all() -> Arc<dyn ConfirmationProvider> {
106        Arc::new(Self {
107            confirmed: true,
108            hint: None,
109        })
110    }
111
112    /// Deny every confirmation request with an optional hint.
113    pub fn deny_all(hint: impl Into<String>) -> Arc<dyn ConfirmationProvider> {
114        Arc::new(Self {
115            confirmed: false,
116            hint: Some(hint.into()),
117        })
118    }
119}
120
121#[async_trait]
122impl ConfirmationProvider for StaticConfirmation {
123    async fn confirm(&self, _request: ConfirmationRequest) -> ToolConfirmation {
124        ToolConfirmation {
125            hint: self.hint.clone(),
126            confirmed: self.confirmed,
127            payload: None,
128        }
129    }
130}
131
#[cfg(test)]
mod tests {
    use super::*;

    /// `confirmed()` yields an approval carrying neither hint nor payload.
    #[test]
    fn confirmed_constructor() {
        let ToolConfirmation {
            hint,
            confirmed,
            payload,
        } = ToolConfirmation::confirmed();
        assert!(confirmed);
        assert!(hint.is_none());
        assert!(payload.is_none());
    }

    /// `denied()` yields a refusal whose hint is the supplied text.
    #[test]
    fn denied_constructor() {
        let outcome = ToolConfirmation::denied("Too dangerous");
        assert!(!outcome.confirmed);
        assert_eq!(outcome.hint.as_deref(), Some("Too dangerous"));
    }

    /// `with_payload()` attaches the payload and leaves the decision intact.
    #[test]
    fn with_payload() {
        let body = serde_json::json!({"reason": "approved"});
        let outcome = ToolConfirmation::confirmed().with_payload(body);
        assert!(outcome.confirmed);
        assert_eq!(outcome.payload.unwrap()["reason"], "approved");
    }

    /// Decision, hint and payload all survive a JSON encode/decode cycle.
    #[test]
    fn serde_roundtrip() {
        let body = serde_json::json!({"level": "high"});
        let original = ToolConfirmation::denied("risky").with_payload(body);
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: ToolConfirmation = serde_json::from_str(&encoded).unwrap();
        assert!(!decoded.confirmed);
        assert_eq!(decoded.hint.as_deref(), Some("risky"));
        assert_eq!(decoded.payload.unwrap()["level"], "high");
    }
}