gemini_adk_rs/
confirmation.rs1use std::future::Future;
4use std::sync::Arc;
5
6use async_trait::async_trait;
7use serde::{Deserialize, Serialize};
8
9#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct ToolConfirmation {
12 pub hint: Option<String>,
14 pub confirmed: bool,
16 pub payload: Option<serde_json::Value>,
18}
19
20impl ToolConfirmation {
21 pub fn confirmed() -> Self {
23 Self {
24 hint: None,
25 confirmed: true,
26 payload: None,
27 }
28 }
29
30 pub fn denied(hint: impl Into<String>) -> Self {
32 Self {
33 hint: Some(hint.into()),
34 confirmed: false,
35 payload: None,
36 }
37 }
38
39 pub fn with_payload(mut self, payload: serde_json::Value) -> Self {
41 self.payload = Some(payload);
42 self
43 }
44}
45
46#[derive(Debug, Clone)]
49pub struct ConfirmationRequest {
50 pub tool_name: String,
52 pub args: serde_json::Value,
54 pub message: Option<String>,
56}
57
58#[async_trait]
67pub trait ConfirmationProvider: Send + Sync {
68 async fn confirm(&self, request: ConfirmationRequest) -> ToolConfirmation;
70}
71
72#[async_trait]
86impl<F, Fut> ConfirmationProvider for F
87where
88 F: Fn(ConfirmationRequest) -> Fut + Send + Sync,
89 Fut: Future<Output = ToolConfirmation> + Send,
90{
91 async fn confirm(&self, request: ConfirmationRequest) -> ToolConfirmation {
92 self(request).await
93 }
94}
95
/// Provider that gives the same answer to every request, ignoring its contents.
///
/// Build one with [`StaticConfirmation::allow_all`] or [`StaticConfirmation::deny_all`].
/// `Debug` and `Clone` are derived so the provider can be logged and duplicated,
/// like the other public types in this module.
#[derive(Debug, Clone)]
pub struct StaticConfirmation {
    /// Copied into the `confirmed` field of every returned `ToolConfirmation`.
    confirmed: bool,
    /// Cloned into the `hint` field of every returned `ToolConfirmation`.
    hint: Option<String>,
}
102
103impl StaticConfirmation {
104 pub fn allow_all() -> Arc<dyn ConfirmationProvider> {
106 Arc::new(Self {
107 confirmed: true,
108 hint: None,
109 })
110 }
111
112 pub fn deny_all(hint: impl Into<String>) -> Arc<dyn ConfirmationProvider> {
114 Arc::new(Self {
115 confirmed: false,
116 hint: Some(hint.into()),
117 })
118 }
119}
120
121#[async_trait]
122impl ConfirmationProvider for StaticConfirmation {
123 async fn confirm(&self, _request: ConfirmationRequest) -> ToolConfirmation {
124 ToolConfirmation {
125 hint: self.hint.clone(),
126 confirmed: self.confirmed,
127 payload: None,
128 }
129 }
130}
131
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// `confirmed()` approves and leaves both optional fields empty.
    #[test]
    fn confirmed_constructor() {
        let ToolConfirmation {
            hint,
            confirmed,
            payload,
        } = ToolConfirmation::confirmed();
        assert!(confirmed);
        assert_eq!(hint, None);
        assert_eq!(payload, None);
    }

    /// `denied()` rejects and keeps the supplied reason as the hint.
    #[test]
    fn denied_constructor() {
        let denial = ToolConfirmation::denied("Too dangerous");
        assert!(!denial.confirmed);
        assert_eq!(denial.hint, Some(String::from("Too dangerous")));
    }

    /// `with_payload()` attaches JSON without touching the decision.
    #[test]
    fn with_payload() {
        let approved = ToolConfirmation::confirmed().with_payload(json!({"reason": "approved"}));
        assert!(approved.confirmed);
        let attached = approved.payload.unwrap();
        assert_eq!(attached["reason"], "approved");
    }

    /// Serializing to a JSON string and back preserves every field.
    #[test]
    fn serde_roundtrip() {
        let before = ToolConfirmation::denied("risky").with_payload(json!({"level": "high"}));
        let encoded = serde_json::to_string(&before).unwrap();
        let after: ToolConfirmation = serde_json::from_str(&encoded).unwrap();
        assert!(!after.confirmed);
        assert_eq!(after.hint.as_deref(), Some("risky"));
        let attached = after.payload.unwrap();
        assert_eq!(attached["level"], "high");
    }
}