gemini_adk_fluent_rs/
testing.rs1use std::collections::{HashMap, HashSet};
4
5use crate::builder::AgentBuilder;
6
/// A problem found by [`check_contracts`] in the state-key declarations of a
/// set of agents.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ContractViolation {
    /// An agent reads a state key that no agent writes.
    UnproducedKey {
        /// Name of the agent that reads the key.
        consumer: String,
        /// The key that is read but never written.
        key: String,
    },
    /// More than one agent writes the same state key.
    DuplicateWrite {
        /// Names of the agents that write the key.
        agents: Vec<String>,
        /// The key written by several agents.
        key: String,
    },
    /// An agent writes a state key that no agent reads.
    OrphanedOutput {
        /// Name of the agent that writes the key.
        producer: String,
        /// The key that is written but never read.
        key: String,
    },
}
32
33pub fn check_contracts(agents: &[AgentBuilder]) -> Vec<ContractViolation> {
40 let mut violations = Vec::new();
41
42 let mut all_writes: HashMap<String, Vec<String>> = HashMap::new();
44 let mut all_reads: HashSet<String> = HashSet::new();
45 let mut all_written_keys: HashSet<String> = HashSet::new();
46
47 for agent in agents {
48 for key in agent.get_writes() {
49 all_writes
50 .entry(key.clone())
51 .or_default()
52 .push(agent.name().to_string());
53 all_written_keys.insert(key.clone());
54 }
55 for key in agent.get_reads() {
56 all_reads.insert(key.clone());
57 }
58 }
59
60 for agent in agents {
62 for key in agent.get_reads() {
63 if !all_written_keys.contains(key) {
64 violations.push(ContractViolation::UnproducedKey {
65 consumer: agent.name().to_string(),
66 key: key.clone(),
67 });
68 }
69 }
70 }
71
72 for (key, writers) in &all_writes {
74 if writers.len() > 1 {
75 violations.push(ContractViolation::DuplicateWrite {
76 agents: writers.clone(),
77 key: key.clone(),
78 });
79 }
80 }
81
82 for agent in agents {
84 for key in agent.get_writes() {
85 if !all_reads.contains(key) {
86 violations.push(ContractViolation::OrphanedOutput {
87 producer: agent.name().to_string(),
88 key: key.clone(),
89 });
90 }
91 }
92 }
93
94 violations
95}
96
97pub fn infer_data_flow(agents: &[AgentBuilder]) -> Vec<DataFlowEdge> {
101 let mut edges = Vec::new();
102
103 for producer in agents {
104 for consumer in agents {
105 if producer.name() == consumer.name() {
106 continue;
107 }
108 for write_key in producer.get_writes() {
109 if consumer.get_reads().contains(write_key) {
110 edges.push(DataFlowEdge {
111 producer: producer.name().to_string(),
112 consumer: consumer.name().to_string(),
113 key: write_key.clone(),
114 });
115 }
116 }
117 }
118 }
119
120 edges
121}
122
/// A single inferred data dependency: `producer` writes `key` and `consumer`
/// reads it. Produced by [`infer_data_flow`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct DataFlowEdge {
    /// Name of the agent that writes `key`.
    pub producer: String,
    /// Name of the agent that reads `key`.
    pub consumer: String,
    /// The state key passed from `producer` to `consumer`.
    pub key: String,
}
133
134pub struct AgentHarness {
136 state: gemini_adk_rs::State,
137}
138
139impl AgentHarness {
140 pub fn new() -> Self {
142 Self {
143 state: gemini_adk_rs::State::new(),
144 }
145 }
146
147 pub fn set<V: serde::Serialize>(self, key: &str, value: V) -> Self {
149 self.state.set(key, value);
150 self
151 }
152
153 pub fn state(&self) -> &gemini_adk_rs::State {
155 &self.state
156 }
157
158 pub async fn run(
160 &self,
161 agent: &dyn gemini_adk_rs::text::TextAgent,
162 ) -> Result<String, gemini_adk_rs::error::AgentError> {
163 agent.run(&self.state).await
164 }
165}
166
167impl Default for AgentHarness {
168 fn default() -> Self {
169 Self::new()
170 }
171}
172
173pub fn diagnose(agent: &AgentBuilder) -> String {
175 let mut lines = Vec::new();
176 lines.push(format!("Agent: {}", agent.name()));
177
178 if let Some(model) = agent.get_model() {
179 lines.push(format!(" Model: {:?}", model));
180 }
181 if let Some(inst) = agent.get_instruction() {
182 let truncated = if inst.len() > 80 {
183 format!("{}...", &inst[..80])
184 } else {
185 inst.to_string()
186 };
187 lines.push(format!(" Instruction: {}", truncated));
188 }
189 if let Some(t) = agent.get_temperature() {
190 lines.push(format!(" Temperature: {}", t));
191 }
192 if agent.tool_count() > 0 {
193 lines.push(format!(" Tools: {}", agent.tool_count()));
194 }
195 if !agent.get_writes().is_empty() {
196 lines.push(format!(" Writes: {:?}", agent.get_writes()));
197 }
198 if !agent.get_reads().is_empty() {
199 lines.push(format!(" Reads: {:?}", agent.get_reads()));
200 }
201 if !agent.get_sub_agents().is_empty() {
202 lines.push(format!(" Sub-agents: {}", agent.get_sub_agents().len()));
203 }
204
205 lines.join("\n")
206}
207
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn no_violations_for_matching_contracts() {
        let writer = AgentBuilder::new("writer").writes("output");
        let reader = AgentBuilder::new("reader").reads("output");
        let violations = check_contracts(&[writer, reader]);
        assert!(violations.is_empty(), "unexpected violations: {:?}", violations);
    }

    #[test]
    fn detects_unproduced_key() {
        let reader = AgentBuilder::new("reader").reads("missing");
        let violations = check_contracts(&[reader]);
        assert_eq!(
            violations,
            vec![ContractViolation::UnproducedKey {
                consumer: "reader".into(),
                key: "missing".into(),
            }]
        );
    }

    #[test]
    fn detects_duplicate_write() {
        let a = AgentBuilder::new("a").writes("shared");
        let b = AgentBuilder::new("b").writes("shared").reads("shared");
        let violations = check_contracts(&[a, b]);
        // Both writers are listed, in agent order, and nothing else is reported.
        assert_eq!(
            violations,
            vec![ContractViolation::DuplicateWrite {
                agents: vec!["a".into(), "b".into()],
                key: "shared".into(),
            }]
        );
    }

    #[test]
    fn detects_orphaned_output() {
        let writer = AgentBuilder::new("writer").writes("unused");
        let violations = check_contracts(&[writer]);
        assert_eq!(
            violations,
            vec![ContractViolation::OrphanedOutput {
                producer: "writer".into(),
                key: "unused".into(),
            }]
        );
    }

    #[test]
    fn multiple_violations() {
        let a = AgentBuilder::new("a").writes("orphan");
        let b = AgentBuilder::new("b").reads("missing");
        let violations = check_contracts(&[a, b]);
        // Unproduced reads are reported before orphaned writes.
        assert_eq!(
            violations,
            vec![
                ContractViolation::UnproducedKey {
                    consumer: "b".into(),
                    key: "missing".into(),
                },
                ContractViolation::OrphanedOutput {
                    producer: "a".into(),
                    key: "orphan".into(),
                },
            ]
        );
    }

    #[test]
    fn empty_agents_no_violations() {
        let violations = check_contracts(&[]);
        assert!(violations.is_empty());
    }

    #[test]
    fn infer_data_flow_finds_edges() {
        let writer = AgentBuilder::new("writer").writes("output");
        let reader = AgentBuilder::new("reader").reads("output");
        let edges = infer_data_flow(&[writer, reader]);
        assert_eq!(
            edges,
            vec![DataFlowEdge {
                producer: "writer".into(),
                consumer: "reader".into(),
                key: "output".into(),
            }]
        );
    }

    #[test]
    fn infer_data_flow_no_self_edges() {
        let agent = AgentBuilder::new("self").writes("key").reads("key");
        let edges = infer_data_flow(&[agent]);
        assert!(edges.is_empty());
    }

    #[test]
    fn infer_data_flow_follows_pipeline_order() {
        let first = AgentBuilder::new("first").writes("x");
        let middle = AgentBuilder::new("middle").reads("x").writes("y");
        let last = AgentBuilder::new("last").reads("y");
        let edges = infer_data_flow(&[first, middle, last]);
        assert_eq!(
            edges,
            vec![
                DataFlowEdge {
                    producer: "first".into(),
                    consumer: "middle".into(),
                    key: "x".into(),
                },
                DataFlowEdge {
                    producer: "middle".into(),
                    consumer: "last".into(),
                    key: "y".into(),
                },
            ]
        );
    }

    #[test]
    fn diagnose_basic() {
        let agent = AgentBuilder::new("test")
            .instruction("Be helpful")
            .temperature(0.5)
            .writes("output");
        let diag = diagnose(&agent);
        assert!(diag.contains("test"));
        assert!(diag.contains("Be helpful"));
        assert!(diag.contains("0.5"));
        assert!(diag.contains("Writes"));
        assert!(diag.contains("output"));
    }

    #[test]
    fn diagnose_truncates_long_instruction() {
        let long = "a".repeat(100);
        let agent = AgentBuilder::new("verbose").instruction(long.as_str());
        let diag = diagnose(&agent);
        assert!(diag.contains(&format!("{}...", "a".repeat(80))));
        assert!(!diag.contains(&"a".repeat(81)));
    }

    #[test]
    fn harness_sets_state() {
        let harness = AgentHarness::new().set("key", "value");
        let val: Option<String> = harness.state().get("key");
        assert_eq!(val, Some("value".into()));
    }

    #[test]
    fn complex_pipeline_contracts() {
        let researcher = AgentBuilder::new("researcher")
            .writes("findings")
            .writes("sources");
        let writer = AgentBuilder::new("writer")
            .reads("findings")
            .writes("draft");
        let reviewer = AgentBuilder::new("reviewer")
            .reads("draft")
            .writes("quality");

        let violations = check_contracts(&[researcher, writer, reviewer]);
        // Only the two terminal outputs are unread; every read is satisfied.
        assert_eq!(
            violations,
            vec![
                ContractViolation::OrphanedOutput {
                    producer: "researcher".into(),
                    key: "sources".into(),
                },
                ContractViolation::OrphanedOutput {
                    producer: "reviewer".into(),
                    key: "quality".into(),
                },
            ]
        );
    }
}