1use proc_macro::TokenStream;
26use quote::{format_ident, quote};
27use syn::{
28 parse_macro_input, Data, DeriveInput, Expr, ExprLit, Fields, FnArg, ItemFn, Lit, LitInt,
29 LitStr, Meta, Pat, PatType, ReturnType, Type, TypePath,
30};
31
32#[proc_macro_attribute]
84pub fn tool(attr: TokenStream, item: TokenStream) -> TokenStream {
85 let description = parse_macro_input!(attr as LitStr);
86 let func = parse_macro_input!(item as ItemFn);
87
88 match expand(description, func) {
89 Ok(ts) => ts.into(),
90 Err(e) => e.to_compile_error().into(),
91 }
92}
93
94fn expand(description: LitStr, func: ItemFn) -> syn::Result<proc_macro2::TokenStream> {
95 let sig = &func.sig;
96
97 if sig.asyncness.is_none() {
98 return Err(syn::Error::new_spanned(
99 sig.fn_token,
100 "#[tool] requires an `async fn`",
101 ));
102 }
103 if let Some(variadic) = &sig.variadic {
104 return Err(syn::Error::new_spanned(
105 variadic,
106 "#[tool] does not support variadic functions",
107 ));
108 }
109 if !sig.generics.params.is_empty() {
110 return Err(syn::Error::new_spanned(
111 sig.generics.clone(),
112 "#[tool] does not support generic functions",
113 ));
114 }
115
116 let fn_name = &sig.ident;
117 let vis = &func.vis;
118 let body = &func.block;
119 let output = &sig.output;
120
121 let mut field_idents = Vec::new();
123 let mut field_types = Vec::new();
124 for input in &sig.inputs {
125 match input {
126 FnArg::Receiver(r) => {
127 return Err(syn::Error::new_spanned(
128 r,
129 "#[tool] cannot be applied to methods taking `self`",
130 ));
131 }
132 FnArg::Typed(PatType { pat, ty, .. }) => {
133 let ident = match pat.as_ref() {
134 Pat::Ident(pat_ident) => pat_ident.ident.clone(),
135 other => {
136 return Err(syn::Error::new_spanned(
137 other,
138 "#[tool] parameters must be simple identifiers (no patterns)",
139 ));
140 }
141 };
142 field_idents.push(ident);
143 field_types.push((*ty).clone());
144 }
145 }
146 }
147
148 let return_type: proc_macro2::TokenStream = match output {
151 ReturnType::Default => {
152 return Err(syn::Error::new_spanned(
153 sig,
154 "#[tool] requires a return type of `Result<serde_json::Value, ToolError>`",
155 ));
156 }
157 ReturnType::Type(_, ty) => quote! { #ty },
158 };
159
160 let pascal = to_pascal_case(&fn_name.to_string());
162 let args_struct = format_ident!("__{}Args", pascal);
163 let tool_struct = format_ident!("__{}Tool", pascal);
164 let inner_fn = format_ident!("__{}_impl", fn_name);
166
167 let fn_name_str = fn_name.to_string();
168
169 let struct_fields = field_idents
171 .iter()
172 .zip(field_types.iter())
173 .map(|(ident, ty)| {
174 if is_option(ty) {
177 quote! {
178 #[serde(default)]
179 #ident: #ty
180 }
181 } else {
182 quote! { #ident: #ty }
183 }
184 });
185
186 let destructure = &field_idents;
189 let forward_args = &field_idents;
190
191 let serde = quote! { ::gemini_adk_rs::__macros::serde };
194 let schemars = quote! { ::gemini_adk_rs::__macros::schemars };
195 let async_trait = quote! { ::gemini_adk_rs::__macros::async_trait };
196 let serde_json = quote! { ::gemini_adk_rs::__macros::serde_json };
197
198 let expanded = quote! {
199 #[derive(#serde::Deserialize, #schemars::JsonSchema)]
201 #[serde(crate = "gemini_adk_rs::__macros::serde")]
202 #[allow(non_camel_case_types, non_snake_case)]
203 struct #args_struct {
204 #(#struct_fields),*
205 }
206
207 #[allow(non_snake_case)]
209 async fn #inner_fn ( #(#field_idents : #field_types),* ) -> #return_type #body
210
211 #[allow(non_camel_case_types)]
213 #vis struct #tool_struct;
214
215 #[#async_trait::async_trait]
216 impl ::gemini_adk_rs::tool::ToolFunction for #tool_struct {
217 fn name(&self) -> &str {
218 #fn_name_str
219 }
220
221 fn description(&self) -> &str {
222 #description
223 }
224
225 fn parameters(&self) -> ::core::option::Option<#serde_json::Value> {
226 let root = #schemars::schema_for!(#args_struct);
227 ::core::option::Option::Some(
228 #serde_json::to_value(root)
229 .expect("schemars schema should serialize to JSON"),
230 )
231 }
232
233 async fn call(
234 &self,
235 args: #serde_json::Value,
236 ) -> ::core::result::Result<#serde_json::Value, ::gemini_adk_rs::error::ToolError> {
237 let #args_struct { #(#destructure),* } =
238 #serde_json::from_value(args).map_err(|e| {
239 ::gemini_adk_rs::error::ToolError::InvalidArgs(
240 ::std::format!("Failed to deserialize arguments: {e}"),
241 )
242 })?;
243 #inner_fn ( #(#forward_args),* ).await
244 }
245 }
246
247 #[allow(non_snake_case)]
249 #vis fn #fn_name () -> #tool_struct {
250 #tool_struct
251 }
252 };
253
254 Ok(expanded)
255}
256
257#[proc_macro_derive(Extract, attributes(recognize, extract))]
307pub fn derive_extract(item: TokenStream) -> TokenStream {
308 let input = parse_macro_input!(item as DeriveInput);
309 match expand_extract(input) {
310 Ok(ts) => ts.into(),
311 Err(e) => e.to_compile_error().into(),
312 }
313}
314
315fn expand_extract(input: DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
316 let ident = &input.ident;
317
318 let fields = match &input.data {
319 Data::Struct(s) => match &s.fields {
320 Fields::Named(named) => &named.named,
321 _ => {
322 return Err(syn::Error::new_spanned(
323 ident,
324 "#[derive(Extract)] requires a struct with named fields",
325 ));
326 }
327 },
328 _ => {
329 return Err(syn::Error::new_spanned(
330 ident,
331 "#[derive(Extract)] can only be applied to structs",
332 ));
333 }
334 };
335
336 let mut name = to_snake_case(&ident.to_string());
338 let mut window: usize = 3;
339 for attr in &input.attrs {
340 if attr.path().is_ident("extract") {
341 attr.parse_nested_meta(|meta| {
342 if meta.path.is_ident("name") {
343 let v: LitStr = meta.value()?.parse()?;
344 name = v.value();
345 } else if meta.path.is_ident("window") {
346 let v: LitInt = meta.value()?.parse()?;
347 window = v.base10_parse()?;
348 } else {
349 return Err(
350 meta.error("unknown `extract` option (expected `name` or `window`)")
351 );
352 }
353 Ok(())
354 })?;
355 }
356 }
357
358 let all_field_idents: Vec<_> = fields.iter().filter_map(|f| f.ident.clone()).collect();
361
362 let mut field_calls = Vec::new();
364 for field in fields {
365 let Some(recognize) = field.attrs.iter().find(|a| a.path().is_ident("recognize")) else {
366 continue;
367 };
368 let fname = field.ident.as_ref().expect("named field").to_string();
369 let recognizer = recognizer_expr(recognize)?;
370
371 let mut state_key: Option<String> = None;
373 for attr in &field.attrs {
374 if attr.path().is_ident("extract") {
375 attr.parse_nested_meta(|meta| {
376 if meta.path.is_ident("state") {
377 let v: LitStr = meta.value()?.parse()?;
378 state_key = Some(v.value());
379 } else {
380 return Err(meta.error("unknown field `extract` option (expected `state`)"));
381 }
382 Ok(())
383 })?;
384 }
385 }
386
387 field_calls.push(match state_key {
388 Some(sk) => quote! { .field_to(#fname, #sk, #recognizer) },
389 None => quote! { .field(#fname, #recognizer) },
390 });
391 }
392
393 let doc = format!("The `Extract` record derived from `{ident}`'s `#[recognize(..)]` fields.");
394 Ok(quote! {
395 impl #ident {
396 #[doc = #doc]
397 pub fn extract() -> ::gemini_adk_rs::extract::Extract {
398 ::gemini_adk_rs::extract::Extract::record(#name)
399 #(#field_calls)*
400 .window(#window)
401 .build()
402 }
403
404 #[allow(dead_code)]
405 #[doc(hidden)]
406 fn __extract_mark_fields_used(&self) {
407 #( let _ = &self.#all_field_idents; )*
408 }
409 }
410 })
411}
412
413#[proc_macro_derive(Frame, attributes(slot, frame, recognize))]
435pub fn derive_frame(item: TokenStream) -> TokenStream {
436 let input = parse_macro_input!(item as DeriveInput);
437 match expand_frame(input) {
438 Ok(ts) => ts.into(),
439 Err(e) => e.to_compile_error().into(),
440 }
441}
442
443fn expand_frame(input: DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
444 let ident = &input.ident;
445
446 let fields = match &input.data {
447 Data::Struct(s) => match &s.fields {
448 Fields::Named(named) => &named.named,
449 _ => {
450 return Err(syn::Error::new_spanned(
451 ident,
452 "#[derive(Frame)] requires a struct with named fields",
453 ))
454 }
455 },
456 _ => {
457 return Err(syn::Error::new_spanned(
458 ident,
459 "#[derive(Frame)] can only be applied to structs",
460 ))
461 }
462 };
463
464 let mut name = to_snake_case(&ident.to_string());
466 for attr in &input.attrs {
467 if attr.path().is_ident("frame") {
468 attr.parse_nested_meta(|meta| {
469 if meta.path.is_ident("name") {
470 let v: LitStr = meta.value()?.parse()?;
471 name = v.value();
472 Ok(())
473 } else {
474 Err(meta.error("unknown `frame` option (expected `name`)"))
475 }
476 })?;
477 }
478 }
479
480 let all_field_idents: Vec<_> = fields.iter().filter_map(|f| f.ident.clone()).collect();
481
482 let mut slot_exprs = Vec::new();
483 for field in fields {
484 let fname = field.ident.as_ref().expect("named field").to_string();
485 let mut state_key = fname.clone();
486 let mut prompt: Option<String> = None;
487 let mut reprompt: Option<String> = None;
488 let mut confirm = quote! { ::gemini_adk_rs::frame::ConfirmPolicy::Never };
489 let mut pii = false;
490 let mut min: Option<f64> = None;
491 let mut max: Option<f64> = None;
492 let mut non_empty = false;
493
494 let recognizer = match field.attrs.iter().find(|a| a.path().is_ident("recognize")) {
496 Some(attr) => {
497 let r = slot_recognizer_expr(attr)?;
498 quote! { Some(#r) }
499 }
500 None => quote! { None },
501 };
502
503 for attr in &field.attrs {
504 if !attr.path().is_ident("slot") {
505 continue;
506 }
507 attr.parse_nested_meta(|meta| {
508 if meta.path.is_ident("prompt") {
509 let v: LitStr = meta.value()?.parse()?;
510 prompt = Some(v.value());
511 } else if meta.path.is_ident("reprompt") {
512 let v: LitStr = meta.value()?.parse()?;
513 reprompt = Some(v.value());
514 } else if meta.path.is_ident("state") {
515 let v: LitStr = meta.value()?.parse()?;
516 state_key = v.value();
517 } else if meta.path.is_ident("confirm") {
518 let v: LitStr = meta.value()?.parse()?;
519 confirm = match v.value().as_str() {
520 "never" => quote! { ::gemini_adk_rs::frame::ConfirmPolicy::Never },
521 "low_confidence" => {
522 quote! { ::gemini_adk_rs::frame::ConfirmPolicy::LowConfidence }
523 }
524 "always" => quote! { ::gemini_adk_rs::frame::ConfirmPolicy::Always },
525 other => {
526 return Err(meta.error(format!(
527 "unknown confirm policy '{other}' (expected never/low_confidence/always)"
528 )))
529 }
530 };
531 } else if meta.path.is_ident("pii") {
532 pii = true;
533 } else if meta.path.is_ident("min") {
534 min = Some(lit_to_f64(&meta.value()?.parse()?)?);
535 } else if meta.path.is_ident("max") {
536 max = Some(lit_to_f64(&meta.value()?.parse()?)?);
537 } else if meta.path.is_ident("non_empty") {
538 non_empty = true;
539 } else {
540 return Err(meta.error(
541 "unknown `slot` option (expected prompt/reprompt/state/confirm/pii/min/max/non_empty)",
542 ));
543 }
544 Ok(())
545 })?;
546 }
547
548 let validate = if min.is_some() || max.is_some() {
550 let min_tok = match min {
551 Some(v) => quote! { Some(#v) },
552 None => quote! { None },
553 };
554 let max_tok = match max {
555 Some(v) => quote! { Some(#v) },
556 None => quote! { None },
557 };
558 quote! { Some(::gemini_adk_rs::frame::SlotValidator::Range { min: #min_tok, max: #max_tok }) }
559 } else if non_empty {
560 quote! { Some(::gemini_adk_rs::frame::SlotValidator::NonEmpty) }
561 } else {
562 quote! { None }
563 };
564
565 let prompt_tok = match prompt {
566 Some(p) => quote! { Some(#p.to_string()) },
567 None => quote! { None },
568 };
569 let reprompt_tok = match reprompt {
570 Some(p) => quote! { Some(#p.to_string()) },
571 None => quote! { None },
572 };
573 slot_exprs.push(quote! {
574 ::gemini_adk_rs::frame::SlotSpec {
575 name: #fname.to_string(),
576 state_key: #state_key.to_string(),
577 prompt: #prompt_tok,
578 reprompt: #reprompt_tok,
579 confirm: #confirm,
580 pii: #pii,
581 recognizer: #recognizer,
582 validate: #validate,
583 }
584 });
585 }
586
587 let doc = format!("The `FrameSpec` derived from `{ident}`'s `#[slot(..)]` fields.");
588 Ok(quote! {
589 impl ::gemini_adk_rs::frame::Frame for #ident {
590 #[doc = #doc]
591 fn frame() -> ::gemini_adk_rs::frame::FrameSpec {
592 ::gemini_adk_rs::frame::FrameSpec {
593 name: #name.to_string(),
594 slots: ::std::vec![ #(#slot_exprs),* ],
595 }
596 }
597 }
598
599 impl #ident {
600 #[allow(dead_code)]
601 #[doc(hidden)]
602 fn __frame_mark_fields_used(&self) {
603 #( let _ = &self.#all_field_idents; )*
604 }
605 }
606 })
607}
608
609fn recognizer_expr(attr: &syn::Attribute) -> syn::Result<proc_macro2::TokenStream> {
611 let r = quote! { ::gemini_adk_rs::extract::Recognizer };
612 let meta: Meta = attr.parse_args()?;
613 match meta {
614 Meta::Path(p) => {
615 let id = p
616 .get_ident()
617 .ok_or_else(|| syn::Error::new_spanned(&p, "expected a recognizer name"))?;
618 match id.to_string().as_str() {
619 "integer" => Ok(quote! { #r::integer() }),
620 "money" => Ok(quote! { #r::money() }),
621 "yes_no" => Ok(quote! { #r::yes_no() }),
622 "datetime" => Ok(quote! { #r::datetime() }),
623 other => Err(syn::Error::new_spanned(
624 &p,
625 format!("unknown recognizer `{other}`"),
626 )),
627 }
628 }
629 Meta::NameValue(nv) => {
630 let id = nv
631 .path
632 .get_ident()
633 .ok_or_else(|| syn::Error::new_spanned(&nv.path, "expected a recognizer name"))?;
634 match id.to_string().as_str() {
635 "integer_near" => {
636 let a = str_array(&nv.value)?;
637 Ok(quote! { #r::integer_near([ #(#a),* ]) })
638 }
639 "one_of" => {
640 let a = str_array(&nv.value)?;
641 Ok(quote! { #r::one_of([ #(#a),* ]) })
642 }
643 "fuzzy" => {
644 let a = str_array(&nv.value)?;
645 Ok(quote! { #r::fuzzy([ #(#a),* ]) })
646 }
647 "regex" => {
648 let s = str_lit(&nv.value)?;
649 Ok(quote! { #r::regex(#s) })
650 }
651 other => Err(syn::Error::new_spanned(
652 &nv.path,
653 format!("`{other}` does not take a value"),
654 )),
655 }
656 }
657 Meta::List(l) => Err(syn::Error::new_spanned(
658 l,
659 "unexpected nested list in `#[recognize(..)]`",
660 )),
661 }
662}
663
664fn slot_recognizer_expr(attr: &syn::Attribute) -> syn::Result<proc_macro2::TokenStream> {
667 let r = quote! { ::gemini_adk_rs::frame::SlotRecognizer };
668 let meta: Meta = attr.parse_args()?;
669 match meta {
670 Meta::Path(p) => {
671 let id = p
672 .get_ident()
673 .ok_or_else(|| syn::Error::new_spanned(&p, "expected a recognizer name"))?;
674 match id.to_string().as_str() {
675 "integer" => Ok(quote! { #r::Integer }),
676 "money" => Ok(quote! { #r::Money }),
677 "yes_no" => Ok(quote! { #r::YesNo }),
678 "datetime" => Ok(quote! { #r::DateTime }),
679 other => Err(syn::Error::new_spanned(
680 &p,
681 format!("unknown recognizer `{other}`"),
682 )),
683 }
684 }
685 Meta::NameValue(nv) => {
686 let id = nv
687 .path
688 .get_ident()
689 .ok_or_else(|| syn::Error::new_spanned(&nv.path, "expected a recognizer name"))?;
690 match id.to_string().as_str() {
691 "integer_near" => {
692 let a = str_array(&nv.value)?;
693 Ok(quote! { #r::IntegerNear(::std::vec![ #(#a.to_string()),* ]) })
694 }
695 "one_of" => {
696 let a = str_array(&nv.value)?;
697 Ok(quote! { #r::OneOf(::std::vec![ #(#a.to_string()),* ]) })
698 }
699 "fuzzy" => {
700 let a = str_array(&nv.value)?;
701 Ok(quote! { #r::Fuzzy(::std::vec![ #(#a.to_string()),* ]) })
702 }
703 "regex" => {
704 let s = str_lit(&nv.value)?;
705 Ok(quote! { #r::Regex(#s.to_string()) })
706 }
707 other => Err(syn::Error::new_spanned(
708 &nv.path,
709 format!("`{other}` does not take a value"),
710 )),
711 }
712 }
713 Meta::List(l) => Err(syn::Error::new_spanned(
714 l,
715 "unexpected nested list in `#[recognize(..)]`",
716 )),
717 }
718}
719
720fn lit_to_f64(lit: &Lit) -> syn::Result<f64> {
722 match lit {
723 Lit::Int(i) => i.base10_parse::<f64>(),
724 Lit::Float(f) => f.base10_parse::<f64>(),
725 other => Err(syn::Error::new_spanned(
726 other,
727 "expected a numeric literal for `min`/`max`",
728 )),
729 }
730}
731
732fn str_array(expr: &Expr) -> syn::Result<Vec<LitStr>> {
734 match expr {
735 Expr::Array(arr) => arr
736 .elems
737 .iter()
738 .map(|e| match e {
739 Expr::Lit(ExprLit {
740 lit: Lit::Str(s), ..
741 }) => Ok(s.clone()),
742 other => Err(syn::Error::new_spanned(
743 other,
744 "expected a string literal in the array",
745 )),
746 })
747 .collect(),
748 other => Err(syn::Error::new_spanned(
749 other,
750 "expected an array of string literals, e.g. [\"a\", \"b\"]",
751 )),
752 }
753}
754
755fn str_lit(expr: &Expr) -> syn::Result<LitStr> {
757 match expr {
758 Expr::Lit(ExprLit {
759 lit: Lit::Str(s), ..
760 }) => Ok(s.clone()),
761 other => Err(syn::Error::new_spanned(other, "expected a string literal")),
762 }
763}
764
/// Converts a `PascalCase` name to `snake_case`.
///
/// Every uppercase character is lowercased. Unless it is the very first character, it is
/// also preceded by `_`. All other characters are copied unchanged. Uppercase runs are not
/// treated as acronyms, so `HTTPRequest` becomes `h_t_t_p_request`.
fn to_snake_case(s: &str) -> String {
    let mut snake = String::with_capacity(s.len() + 4);
    for (byte_pos, ch) in s.char_indices() {
        if !ch.is_uppercase() {
            snake.push(ch);
            continue;
        }
        // Only the first character sits at byte offset 0.
        if byte_pos > 0 {
            snake.push('_');
        }
        snake.extend(ch.to_lowercase());
    }
    snake
}
780
781fn is_option(ty: &Type) -> bool {
789 let Type::Path(TypePath { qself: None, path }) = ty else {
790 return false;
791 };
792 if path
794 .segments
795 .iter()
796 .rev()
797 .skip(1)
798 .any(|seg| !seg.arguments.is_none())
799 {
800 return false;
801 }
802 let idents: Vec<&syn::Ident> = path.segments.iter().map(|seg| &seg.ident).collect();
803 match idents.as_slice() {
804 [opt] => path.leading_colon.is_none() && *opt == "Option",
807 [module, opt] => path.leading_colon.is_none() && *module == "option" && *opt == "Option",
808 [root, module, opt] => {
810 (*root == "std" || *root == "core") && *module == "option" && *opt == "Option"
811 }
812 _ => false,
813 }
814}
815
/// Converts a `snake_case` name to `PascalCase`.
///
/// The input is split on `_`. The first character of every non-empty piece is uppercased,
/// and the pieces are concatenated. Leading, trailing and repeated underscores therefore
/// vanish (`__a__b_` → `AB`). All other characters are copied unchanged.
fn to_pascal_case(s: &str) -> String {
    s.split('_')
        .flat_map(|word| {
            let mut chars = word.chars();
            let head = chars.next().into_iter().flat_map(char::to_uppercase);
            head.chain(chars)
        })
        .collect()
}
832
#[cfg(test)]
mod tests {
    use super::{is_option, to_pascal_case, to_snake_case};
    use syn::parse_quote;

    #[test]
    fn is_option_accepts_std_core_paths() {
        assert!(is_option(&parse_quote!(Option<String>)));
        assert!(is_option(&parse_quote!(option::Option<String>)));
        assert!(is_option(&parse_quote!(std::option::Option<String>)));
        assert!(is_option(&parse_quote!(core::option::Option<String>)));
        assert!(is_option(&parse_quote!(::std::option::Option<String>)));
        assert!(is_option(&parse_quote!(::core::option::Option<String>)));
    }

    #[test]
    fn is_option_rejects_lookalikes() {
        assert!(!is_option(&parse_quote!(String)));
        assert!(!is_option(&parse_quote!(Vec<Option<String>>)));
        assert!(!is_option(&parse_quote!(my::Option<String>)));
        assert!(!is_option(&parse_quote!(my::option::Option<String>)));
        assert!(!is_option(&parse_quote!(::option::Option<String>)));
        assert!(!is_option(&parse_quote!(<T as Trait>::Option)));
    }

    // The case helpers derive generated identifiers and default record/frame names, so
    // pin their exact output.
    #[test]
    fn to_snake_case_splits_before_uppercase() {
        assert_eq!(to_snake_case("BookingRequest"), "booking_request");
        assert_eq!(to_snake_case("Order"), "order");
        assert_eq!(to_snake_case("already_snake"), "already_snake");
        assert_eq!(to_snake_case(""), "");
    }

    #[test]
    fn to_pascal_case_capitalizes_each_word() {
        assert_eq!(to_pascal_case("get_weather"), "GetWeather");
        assert_eq!(to_pascal_case("__private_tool"), "PrivateTool");
        assert_eq!(to_pascal_case("a__b_"), "AB");
        assert_eq!(to_pascal_case(""), "");
    }
}