gemini_genai_rs/generate/
response.rs1use serde::{Deserialize, Serialize};
4
5use crate::protocol::types::{
6 CitationMetadata, Content, FinishReason, GroundingMetadata, SafetyRating, UsageMetadata,
7};
8
/// Response payload of a content-generation call.
///
/// JSON keys are camelCase (`promptFeedback`, `usageMetadata`, ...). Every
/// field may be absent from the payload: `candidates` then becomes an empty
/// vector and the optional fields become `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerateContentResponse {
    /// Candidates carried by the response. Empty when the `candidates` key is
    /// missing; always written out when serializing, even if empty.
    #[serde(default)]
    pub candidates: Vec<Candidate>,

    /// Prompt-level feedback (`promptFeedback`), if the payload had any.
    /// Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_feedback: Option<PromptFeedback>,

    /// Usage figures (`usageMetadata`), if the payload had any.
    /// Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage_metadata: Option<UsageMetadata>,

    /// Model version string (`modelVersion`), if the payload had one.
    /// Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model_version: Option<String>,
}
29
30impl GenerateContentResponse {
31 pub fn text(&self) -> Option<&str> {
35 self.candidates
36 .first()
37 .and_then(|c| c.content.as_ref())
38 .and_then(|content| {
39 content.parts.iter().find_map(|part| {
40 if let crate::protocol::types::Part::Text { text } = part {
41 Some(text.as_str())
42 } else {
43 None
44 }
45 })
46 })
47 }
48
49 pub fn is_prompt_blocked(&self) -> bool {
51 self.prompt_feedback
52 .as_ref()
53 .and_then(|f| f.block_reason.as_ref())
54 .is_some()
55 }
56
57 pub fn finish_reason(&self) -> Option<FinishReason> {
59 self.candidates.first().and_then(|c| c.finish_reason)
60 }
61
62 pub fn function_calls(&self) -> Vec<&crate::protocol::types::FunctionCall> {
64 self.candidates
65 .first()
66 .and_then(|c| c.content.as_ref())
67 .map(|content| {
68 content
69 .parts
70 .iter()
71 .filter_map(|part| {
72 if let crate::protocol::types::Part::FunctionCall { function_call } = part {
73 Some(function_call)
74 } else {
75 None
76 }
77 })
78 .collect()
79 })
80 .unwrap_or_default()
81 }
82}
83
/// One entry of the `candidates` array in a generation response.
///
/// JSON keys are camelCase (`finishReason`, `safetyRatings`, ...). All fields
/// may be missing from the payload; optional ones are skipped on
/// serialization when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Candidate {
    /// Content of this candidate, if the payload carried any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<Content>,

    /// Finish reason (`finishReason`) reported for this candidate, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub finish_reason: Option<FinishReason>,

    /// Safety ratings (`safetyRatings`) attached to this candidate. Empty
    /// when the key is missing; always written out when serializing.
    #[serde(default)]
    pub safety_ratings: Vec<SafetyRating>,

    /// Citation metadata (`citationMetadata`), if present.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub citation_metadata: Option<CitationMetadata>,

    /// Token count (`tokenCount`) reported for this candidate, if present.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub token_count: Option<u32>,

    /// Grounding metadata (`groundingMetadata`), if present.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub grounding_metadata: Option<GroundingMetadata>,

    /// Index (`index`) reported for this candidate, if present.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub index: Option<u32>,
}
116
/// Prompt-level feedback carried in a response's `promptFeedback` key.
///
/// `GenerateContentResponse::is_prompt_blocked` treats the prompt as blocked
/// whenever `block_reason` is set.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PromptFeedback {
    /// Block reason (`blockReason`), if the payload carried one. Omitted from
    /// serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub block_reason: Option<BlockReason>,

    /// Safety ratings (`safetyRatings`) attached to the prompt. Empty when the
    /// key is missing; always written out when serializing.
    #[serde(default)]
    pub safety_ratings: Vec<SafetyRating>,
}
129
130#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
132#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
133pub enum BlockReason {
134 BlockReasonUnspecified,
136 Safety,
138 Other,
140 Blocklist,
142 ProhibitedContent,
144}
145
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::{json, Value};

    /// Deserializes `payload` into a response, panicking when it does not fit.
    ///
    /// `#[track_caller]` makes a failure point at the calling test rather
    /// than at this helper.
    #[track_caller]
    fn parse(payload: Value) -> GenerateContentResponse {
        serde_json::from_value(payload).unwrap()
    }

    #[test]
    fn parse_minimal_response() {
        let response = parse(json!({
            "candidates": [{
                "content": {
                    "parts": [{"text": "Hi"}],
                    "role": "model"
                },
                "finishReason": "STOP"
            }]
        }));

        assert_eq!(response.text(), Some("Hi"));
        assert_eq!(response.finish_reason(), Some(FinishReason::Stop));
        assert!(!response.is_prompt_blocked());
    }

    #[test]
    fn parse_blocked_prompt() {
        let response = parse(json!({
            "candidates": [],
            "promptFeedback": {
                "blockReason": "SAFETY",
                "safetyRatings": []
            }
        }));

        assert!(response.is_prompt_blocked());
        assert_eq!(response.text(), None);
    }

    #[test]
    fn parse_with_function_calls() {
        let response = parse(json!({
            "candidates": [{
                "content": {
                    "parts": [{
                        "functionCall": {
                            "name": "get_weather",
                            "args": {"city": "London"}
                        }
                    }],
                    "role": "model"
                },
                "finishReason": "STOP"
            }]
        }));

        let calls = response.function_calls();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "get_weather");
    }

    #[test]
    fn parse_with_usage_metadata() {
        let response = parse(json!({
            "candidates": [{
                "content": {
                    "parts": [{"text": "Ok"}]
                },
                "finishReason": "STOP"
            }],
            "usageMetadata": {
                "promptTokenCount": 10,
                "candidatesTokenCount": 5,
                "totalTokenCount": 15
            }
        }));

        let usage = response.usage_metadata.unwrap();
        assert_eq!(usage.prompt_token_count, Some(10));
        assert_eq!(usage.total_token_count, Some(15));
    }

    #[test]
    fn parse_with_safety_ratings() {
        let response = parse(json!({
            "candidates": [{
                "content": {
                    "parts": [{"text": "Hello"}]
                },
                "finishReason": "STOP",
                "safetyRatings": [{
                    "category": "HARM_CATEGORY_HARASSMENT",
                    "probability": "NEGLIGIBLE"
                }, {
                    "category": "HARM_CATEGORY_HATE_SPEECH",
                    "probability": "LOW"
                }]
            }]
        }));

        let rating_count = response
            .candidates
            .first()
            .map(|candidate| candidate.safety_ratings.len());
        assert_eq!(rating_count, Some(2));
    }

    #[test]
    fn parse_unknown_finish_reason() {
        // A finish reason the client does not know must not break parsing.
        let response = parse(json!({
            "candidates": [{
                "content": {"parts": [{"text": "x"}]},
                "finishReason": "SOME_FUTURE_REASON"
            }]
        }));

        assert_eq!(
            response.finish_reason(),
            Some(FinishReason::FinishReasonUnspecified)
        );
    }
}