// gemini_genai_rs/generate/config.rs

//! Request configuration for generateContent.
2
3use crate::protocol::types::{Content, GenerationConfig, Part, SafetySetting, Tool, ToolConfig};
4
5/// Configuration for a generateContent request.
6///
7/// Wraps the existing `GenerationConfig` plus safety settings, tools,
8/// system instruction, and content turns.
9#[derive(Debug, Clone)]
10pub struct GenerateContentConfig {
11    /// The conversation turns to send.
12    pub contents: Vec<Content>,
13    /// Generation parameters (temperature, top_p, max_output_tokens, etc.).
14    pub generation_config: Option<GenerationConfig>,
15    /// Per-category safety thresholds.
16    pub safety_settings: Vec<SafetySetting>,
17    /// Tools available to the model.
18    pub tools: Vec<Tool>,
19    /// Tool invocation configuration.
20    pub tool_config: Option<ToolConfig>,
21    /// System instruction (prepended to the conversation).
22    pub system_instruction: Option<Content>,
23}
24
25impl GenerateContentConfig {
26    /// Create a config from a simple text prompt.
27    pub fn from_text(text: impl Into<String>) -> Self {
28        Self {
29            contents: vec![Content::user(text)],
30            generation_config: None,
31            safety_settings: vec![],
32            tools: vec![],
33            tool_config: None,
34            system_instruction: None,
35        }
36    }
37
38    /// Create a config from a list of content parts (e.g., text + image).
39    pub fn from_parts(parts: Vec<Part>) -> Self {
40        Self {
41            contents: vec![Content {
42                role: Some(crate::protocol::types::Role::User),
43                parts,
44            }],
45            generation_config: None,
46            safety_settings: vec![],
47            tools: vec![],
48            tool_config: None,
49            system_instruction: None,
50        }
51    }
52
53    /// Create a config from existing conversation contents.
54    pub fn from_contents(contents: Vec<Content>) -> Self {
55        Self {
56            contents,
57            generation_config: None,
58            safety_settings: vec![],
59            tools: vec![],
60            tool_config: None,
61            system_instruction: None,
62        }
63    }
64
65    /// Set generation config.
66    pub fn generation_config(mut self, config: GenerationConfig) -> Self {
67        self.generation_config = Some(config);
68        self
69    }
70
71    /// Set temperature.
72    pub fn temperature(mut self, temp: f32) -> Self {
73        self.generation_config
74            .get_or_insert_with(GenerationConfig::default)
75            .temperature = Some(temp);
76        self
77    }
78
79    /// Set max output tokens.
80    pub fn max_output_tokens(mut self, max: u32) -> Self {
81        self.generation_config
82            .get_or_insert_with(GenerationConfig::default)
83            .max_output_tokens = Some(max);
84        self
85    }
86
87    /// Set top_p.
88    pub fn top_p(mut self, top_p: f32) -> Self {
89        self.generation_config
90            .get_or_insert_with(GenerationConfig::default)
91            .top_p = Some(top_p);
92        self
93    }
94
95    /// Set top_k.
96    pub fn top_k(mut self, top_k: u32) -> Self {
97        self.generation_config
98            .get_or_insert_with(GenerationConfig::default)
99            .top_k = Some(top_k);
100        self
101    }
102
103    /// Add a safety setting.
104    pub fn safety_setting(mut self, setting: SafetySetting) -> Self {
105        self.safety_settings.push(setting);
106        self
107    }
108
109    /// Add a tool.
110    pub fn tool(mut self, tool: Tool) -> Self {
111        self.tools.push(tool);
112        self
113    }
114
115    /// Set tool config.
116    pub fn tool_config(mut self, config: ToolConfig) -> Self {
117        self.tool_config = Some(config);
118        self
119    }
120
121    /// Set JSON output mode with an optional JSON Schema.
122    ///
123    /// Sets `responseMimeType` to `"application/json"` and, if a schema is
124    /// provided, sets `responseJsonSchema` so the model is constrained to
125    /// produce valid JSON matching the schema.
126    pub fn json_output(mut self, schema: Option<serde_json::Value>) -> Self {
127        let gc = self
128            .generation_config
129            .get_or_insert_with(GenerationConfig::default);
130        gc.response_mime_type = Some("application/json".to_string());
131        gc.response_json_schema = schema;
132        self
133    }
134
135    /// Set system instruction from text.
136    pub fn system_instruction(mut self, text: impl Into<String>) -> Self {
137        self.system_instruction = Some(Content {
138            role: None,
139            parts: vec![Part::text(text)],
140        });
141        self
142    }
143
144    /// Serialize to the JSON request body expected by the REST API.
145    pub fn to_request_body(&self) -> serde_json::Value {
146        let mut body = serde_json::json!({
147            "contents": self.contents,
148        });
149
150        if let Some(ref gc) = self.generation_config {
151            body["generationConfig"] = serde_json::to_value(gc).unwrap_or_default();
152        }
153
154        if !self.safety_settings.is_empty() {
155            body["safetySettings"] =
156                serde_json::to_value(&self.safety_settings).unwrap_or_default();
157        }
158
159        if !self.tools.is_empty() {
160            body["tools"] = serde_json::to_value(&self.tools).unwrap_or_default();
161        }
162
163        if let Some(ref tc) = self.tool_config {
164            body["toolConfig"] = serde_json::to_value(tc).unwrap_or_default();
165        }
166
167        if let Some(ref si) = self.system_instruction {
168            body["systemInstruction"] = serde_json::to_value(si).unwrap_or_default();
169        }
170
171        body
172    }
173}
174
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::types::{HarmBlockThreshold, HarmCategory};

    /// A plain text prompt becomes one turn whose first part carries the text.
    #[test]
    fn from_text_basic() {
        let config = GenerateContentConfig::from_text("Hello");
        assert_eq!(config.contents.len(), 1);
        let body = config.to_request_body();
        let text = body["contents"][0]["parts"][0]["text"].as_str().unwrap();
        assert_eq!(text, "Hello");
    }

    /// Individual setters create the generation config on demand and both
    /// values end up in the serialized body.
    #[test]
    fn with_temperature_and_max_tokens() {
        let config = GenerateContentConfig::from_text("Hello")
            .temperature(0.5)
            .max_output_tokens(1024);
        let body = config.to_request_body();
        assert_eq!(body["generationConfig"]["temperature"], 0.5);
        assert_eq!(body["generationConfig"]["maxOutputTokens"], 1024);
    }

    /// Safety settings are serialized as an array under `safetySettings`.
    #[test]
    fn with_safety_settings() {
        let config = GenerateContentConfig::from_text("Hello").safety_setting(SafetySetting {
            category: HarmCategory::HarmCategoryHarassment,
            threshold: HarmBlockThreshold::BlockOnlyHigh,
        });
        let body = config.to_request_body();
        assert!(body["safetySettings"].is_array());
        assert_eq!(
            body["safetySettings"][0]["category"],
            "HARM_CATEGORY_HARASSMENT"
        );
    }

    /// The system instruction text is carried in the first part of
    /// `systemInstruction`.
    #[test]
    fn with_system_instruction() {
        let config =
            GenerateContentConfig::from_text("Hello").system_instruction("You are helpful");
        let body = config.to_request_body();
        let si = &body["systemInstruction"];
        assert!(si["parts"][0]["text"].as_str().unwrap().contains("helpful"));
    }

    /// Sections that were never configured must not appear in the body at all.
    #[test]
    fn unset_optional_sections_are_omitted() {
        let body = GenerateContentConfig::from_text("Hello").to_request_body();
        let optional_keys = [
            "generationConfig",
            "safetySettings",
            "tools",
            "toolConfig",
            "systemInstruction",
        ];
        for key in optional_keys.iter() {
            assert!(
                body.get(*key).is_none(),
                "unexpected key `{}` in request body",
                key
            );
        }
    }

    /// `from_contents` keeps every supplied turn, in order.
    #[test]
    fn from_contents_keeps_every_turn() {
        let config = GenerateContentConfig::from_contents(vec![
            Content::user("first"),
            Content::user("second"),
        ]);
        let body = config.to_request_body();
        let turns = body["contents"].as_array().unwrap();
        assert_eq!(turns.len(), 2);
        assert_eq!(turns[0]["parts"][0]["text"], "first");
        assert_eq!(turns[1]["parts"][0]["text"], "second");
    }

    /// JSON output mode always sets the response MIME type, even without a schema.
    #[test]
    fn json_output_sets_mime_type() {
        let body = GenerateContentConfig::from_text("Hello")
            .json_output(None)
            .to_request_body();
        assert_eq!(
            body["generationConfig"]["responseMimeType"],
            "application/json"
        );
    }
}
221}