gemini_genai_rs/generate/
config.rs1use crate::protocol::types::{Content, GenerationConfig, Part, SafetySetting, Tool, ToolConfig};
4
5#[derive(Debug, Clone)]
10pub struct GenerateContentConfig {
11 pub contents: Vec<Content>,
13 pub generation_config: Option<GenerationConfig>,
15 pub safety_settings: Vec<SafetySetting>,
17 pub tools: Vec<Tool>,
19 pub tool_config: Option<ToolConfig>,
21 pub system_instruction: Option<Content>,
23}
24
25impl GenerateContentConfig {
26 pub fn from_text(text: impl Into<String>) -> Self {
28 Self {
29 contents: vec![Content::user(text)],
30 generation_config: None,
31 safety_settings: vec![],
32 tools: vec![],
33 tool_config: None,
34 system_instruction: None,
35 }
36 }
37
38 pub fn from_parts(parts: Vec<Part>) -> Self {
40 Self {
41 contents: vec![Content {
42 role: Some(crate::protocol::types::Role::User),
43 parts,
44 }],
45 generation_config: None,
46 safety_settings: vec![],
47 tools: vec![],
48 tool_config: None,
49 system_instruction: None,
50 }
51 }
52
53 pub fn from_contents(contents: Vec<Content>) -> Self {
55 Self {
56 contents,
57 generation_config: None,
58 safety_settings: vec![],
59 tools: vec![],
60 tool_config: None,
61 system_instruction: None,
62 }
63 }
64
65 pub fn generation_config(mut self, config: GenerationConfig) -> Self {
67 self.generation_config = Some(config);
68 self
69 }
70
71 pub fn temperature(mut self, temp: f32) -> Self {
73 self.generation_config
74 .get_or_insert_with(GenerationConfig::default)
75 .temperature = Some(temp);
76 self
77 }
78
79 pub fn max_output_tokens(mut self, max: u32) -> Self {
81 self.generation_config
82 .get_or_insert_with(GenerationConfig::default)
83 .max_output_tokens = Some(max);
84 self
85 }
86
87 pub fn top_p(mut self, top_p: f32) -> Self {
89 self.generation_config
90 .get_or_insert_with(GenerationConfig::default)
91 .top_p = Some(top_p);
92 self
93 }
94
95 pub fn top_k(mut self, top_k: u32) -> Self {
97 self.generation_config
98 .get_or_insert_with(GenerationConfig::default)
99 .top_k = Some(top_k);
100 self
101 }
102
103 pub fn safety_setting(mut self, setting: SafetySetting) -> Self {
105 self.safety_settings.push(setting);
106 self
107 }
108
109 pub fn tool(mut self, tool: Tool) -> Self {
111 self.tools.push(tool);
112 self
113 }
114
115 pub fn tool_config(mut self, config: ToolConfig) -> Self {
117 self.tool_config = Some(config);
118 self
119 }
120
121 pub fn json_output(mut self, schema: Option<serde_json::Value>) -> Self {
127 let gc = self
128 .generation_config
129 .get_or_insert_with(GenerationConfig::default);
130 gc.response_mime_type = Some("application/json".to_string());
131 gc.response_json_schema = schema;
132 self
133 }
134
135 pub fn system_instruction(mut self, text: impl Into<String>) -> Self {
137 self.system_instruction = Some(Content {
138 role: None,
139 parts: vec![Part::text(text)],
140 });
141 self
142 }
143
144 pub fn to_request_body(&self) -> serde_json::Value {
146 let mut body = serde_json::json!({
147 "contents": self.contents,
148 });
149
150 if let Some(ref gc) = self.generation_config {
151 body["generationConfig"] = serde_json::to_value(gc).unwrap_or_default();
152 }
153
154 if !self.safety_settings.is_empty() {
155 body["safetySettings"] =
156 serde_json::to_value(&self.safety_settings).unwrap_or_default();
157 }
158
159 if !self.tools.is_empty() {
160 body["tools"] = serde_json::to_value(&self.tools).unwrap_or_default();
161 }
162
163 if let Some(ref tc) = self.tool_config {
164 body["toolConfig"] = serde_json::to_value(tc).unwrap_or_default();
165 }
166
167 if let Some(ref si) = self.system_instruction {
168 body["systemInstruction"] = serde_json::to_value(si).unwrap_or_default();
169 }
170
171 body
172 }
173}
174
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::types::{HarmBlockThreshold, HarmCategory, Role};

    #[test]
    fn from_text_basic() {
        let config = GenerateContentConfig::from_text("Hello");
        assert_eq!(config.contents.len(), 1);
        let body = config.to_request_body();
        let text = body["contents"][0]["parts"][0]["text"].as_str().unwrap();
        assert_eq!(text, "Hello");
    }

    #[test]
    fn from_contents_empty_serializes_empty_array() {
        let body = GenerateContentConfig::from_contents(Vec::new()).to_request_body();
        assert_eq!(body["contents"].as_array().map(Vec::len), Some(0));
    }

    #[test]
    fn from_parts_builds_single_user_turn() {
        let config = GenerateContentConfig::from_parts(vec![Part::text("a"), Part::text("b")]);
        assert_eq!(config.contents.len(), 1);
        assert!(matches!(config.contents[0].role, Some(Role::User)));
        assert_eq!(config.contents[0].parts.len(), 2);
    }

    #[test]
    fn optional_sections_omitted_by_default() {
        let body = GenerateContentConfig::from_text("Hello").to_request_body();
        for key in [
            "generationConfig",
            "safetySettings",
            "tools",
            "toolConfig",
            "systemInstruction",
        ] {
            assert!(body.get(key).is_none(), "unexpected key {key}");
        }
    }

    #[test]
    fn with_temperature_and_max_tokens() {
        let config = GenerateContentConfig::from_text("Hello")
            .temperature(0.5)
            .max_output_tokens(1024);
        let body = config.to_request_body();
        assert_eq!(body["generationConfig"]["temperature"], 0.5);
        assert_eq!(body["generationConfig"]["maxOutputTokens"], 1024);
    }

    #[test]
    fn sampling_setters_accumulate() {
        let config = GenerateContentConfig::from_text("Hello")
            .temperature(0.5)
            .top_p(0.25)
            .top_k(40)
            .max_output_tokens(256);
        let gc = config
            .generation_config
            .as_ref()
            .expect("generation config should be set");
        assert_eq!(gc.temperature, Some(0.5));
        assert_eq!(gc.top_p, Some(0.25));
        assert_eq!(gc.top_k, Some(40));
        assert_eq!(gc.max_output_tokens, Some(256));
    }

    #[test]
    fn json_output_sets_mime_type_and_keeps_other_settings() {
        let schema = serde_json::json!({ "type": "object" });
        let config = GenerateContentConfig::from_text("Hello")
            .temperature(0.5)
            .json_output(Some(schema.clone()));
        let gc = config
            .generation_config
            .as_ref()
            .expect("generation config should be set");
        assert_eq!(gc.response_mime_type.as_deref(), Some("application/json"));
        assert_eq!(gc.response_json_schema, Some(schema));
        assert_eq!(gc.temperature, Some(0.5));
    }

    #[test]
    fn with_safety_settings() {
        let config = GenerateContentConfig::from_text("Hello").safety_setting(SafetySetting {
            category: HarmCategory::HarmCategoryHarassment,
            threshold: HarmBlockThreshold::BlockOnlyHigh,
        });
        let body = config.to_request_body();
        assert!(body["safetySettings"].is_array());
        assert_eq!(
            body["safetySettings"][0]["category"],
            "HARM_CATEGORY_HARASSMENT"
        );
    }

    #[test]
    fn safety_settings_accumulate() {
        let setting = || SafetySetting {
            category: HarmCategory::HarmCategoryHarassment,
            threshold: HarmBlockThreshold::BlockOnlyHigh,
        };
        let config = GenerateContentConfig::from_text("Hello")
            .safety_setting(setting())
            .safety_setting(setting());
        assert_eq!(config.safety_settings.len(), 2);
    }

    #[test]
    fn with_system_instruction() {
        let config =
            GenerateContentConfig::from_text("Hello").system_instruction("You are helpful");
        let body = config.to_request_body();
        let si = &body["systemInstruction"];
        assert!(si["parts"][0]["text"].as_str().unwrap().contains("helpful"));
    }

    #[test]
    fn system_instruction_has_no_role_and_one_part() {
        let config = GenerateContentConfig::from_text("Hello").system_instruction("Be brief");
        let si = config
            .system_instruction
            .as_ref()
            .expect("system instruction should be set");
        assert!(si.role.is_none());
        assert_eq!(si.parts.len(), 1);
    }
}