gemini_genai_rs/generate/
response.rs

//! Response types for the `generateContent` endpoint.
2
3use serde::{Deserialize, Serialize};
4
5use crate::protocol::types::{
6    CitationMetadata, Content, FinishReason, GroundingMetadata, SafetyRating, UsageMetadata,
7};
8
/// Top-level response body returned by `generateContent`.
///
/// JSON keys are camelCase (`promptFeedback`, `usageMetadata`, `modelVersion`).
/// A missing `candidates` key deserializes to an empty list, and optional fields
/// that are `None` are omitted when the value is serialized again.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerateContentResponse {
    /// Response candidates (usually 1). May be empty, e.g. when the prompt was
    /// blocked; see `prompt_feedback` in that case.
    #[serde(default)]
    pub candidates: Vec<Candidate>,

    /// Feedback about the prompt (may indicate blocking).
    /// Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_feedback: Option<PromptFeedback>,

    /// Token usage statistics. Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage_metadata: Option<UsageMetadata>,

    /// Model version that generated the response.
    /// Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model_version: Option<String>,
}
29
30impl GenerateContentResponse {
31    /// Extract the text from the first candidate's first text part.
32    ///
33    /// Returns `None` if there are no candidates or no text parts.
34    pub fn text(&self) -> Option<&str> {
35        self.candidates
36            .first()
37            .and_then(|c| c.content.as_ref())
38            .and_then(|content| {
39                content.parts.iter().find_map(|part| {
40                    if let crate::protocol::types::Part::Text { text } = part {
41                        Some(text.as_str())
42                    } else {
43                        None
44                    }
45                })
46            })
47    }
48
49    /// Check if the prompt was blocked.
50    pub fn is_prompt_blocked(&self) -> bool {
51        self.prompt_feedback
52            .as_ref()
53            .and_then(|f| f.block_reason.as_ref())
54            .is_some()
55    }
56
57    /// Get the finish reason of the first candidate.
58    pub fn finish_reason(&self) -> Option<FinishReason> {
59        self.candidates.first().and_then(|c| c.finish_reason)
60    }
61
62    /// Get all function calls from the first candidate.
63    pub fn function_calls(&self) -> Vec<&crate::protocol::types::FunctionCall> {
64        self.candidates
65            .first()
66            .and_then(|c| c.content.as_ref())
67            .map(|content| {
68                content
69                    .parts
70                    .iter()
71                    .filter_map(|part| {
72                        if let crate::protocol::types::Part::FunctionCall { function_call } = part {
73                            Some(function_call)
74                        } else {
75                            None
76                        }
77                    })
78                    .collect()
79            })
80            .unwrap_or_default()
81    }
82}
83
/// A single response candidate produced by the model.
///
/// JSON keys are camelCase (e.g. `finishReason`, `citationMetadata`). Absent
/// `Option` fields deserialize to `None` and are omitted again on serialization;
/// an absent `safetyRatings` deserializes to an empty list.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Candidate {
    /// The generated content. NOTE(review): may be absent, presumably when the
    /// candidate stopped before producing output — confirm against API docs.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<Content>,

    /// Why the model stopped generating.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub finish_reason: Option<FinishReason>,

    /// Safety ratings for this candidate (empty when the field is absent).
    #[serde(default)]
    pub safety_ratings: Vec<SafetyRating>,

    /// Citation information.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub citation_metadata: Option<CitationMetadata>,

    /// Token count for this candidate.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub token_count: Option<u32>,

    /// Grounding metadata (when search grounding is used).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub grounding_metadata: Option<GroundingMetadata>,

    /// Index of this candidate within the response.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub index: Option<u32>,
}
116
/// Feedback about the prompt itself, deserialized from the `promptFeedback` object.
///
/// JSON keys are camelCase (`blockReason`, `safetyRatings`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PromptFeedback {
    /// If set, the prompt was blocked for this reason; this is the field
    /// `GenerateContentResponse::is_prompt_blocked` checks.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub block_reason: Option<BlockReason>,

    /// Safety ratings for the prompt (empty when the field is absent).
    #[serde(default)]
    pub safety_ratings: Vec<SafetyRating>,
}
129
130/// Reason a prompt was blocked.
131#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
132#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
133pub enum BlockReason {
134    /// Block reason not specified.
135    BlockReasonUnspecified,
136    /// Blocked due to safety filters.
137    Safety,
138    /// Blocked for other reasons.
139    Other,
140    /// Blocked due to blocklist match.
141    Blocklist,
142    /// Blocked due to prohibited content.
143    ProhibitedContent,
144}
145
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_minimal_response() {
        let json = serde_json::json!({
            "candidates": [{
                "content": {
                    "parts": [{"text": "Hi"}],
                    "role": "model"
                },
                "finishReason": "STOP"
            }]
        });
        let resp: GenerateContentResponse = serde_json::from_value(json).unwrap();
        assert_eq!(resp.text().unwrap(), "Hi");
        assert_eq!(resp.finish_reason(), Some(FinishReason::Stop));
        assert!(!resp.is_prompt_blocked());
    }

    #[test]
    fn parse_blocked_prompt() {
        let json = serde_json::json!({
            "candidates": [],
            "promptFeedback": {
                "blockReason": "SAFETY",
                "safetyRatings": []
            }
        });
        let resp: GenerateContentResponse = serde_json::from_value(json).unwrap();
        assert!(resp.is_prompt_blocked());
        assert!(resp.text().is_none());
        // The concrete reason must survive deserialization, not just its presence.
        assert_eq!(
            resp.prompt_feedback.and_then(|f| f.block_reason),
            Some(BlockReason::Safety)
        );
    }

    #[test]
    fn parse_empty_response_uses_defaults() {
        // `candidates` is `#[serde(default)]`; every other field is optional.
        let resp: GenerateContentResponse =
            serde_json::from_value(serde_json::json!({})).unwrap();
        assert!(resp.candidates.is_empty());
        assert!(resp.text().is_none());
        assert!(resp.function_calls().is_empty());
        assert_eq!(resp.finish_reason(), None);
        assert!(!resp.is_prompt_blocked());
    }

    #[test]
    fn parse_with_function_calls() {
        let json = serde_json::json!({
            "candidates": [{
                "content": {
                    "parts": [{
                        "functionCall": {
                            "name": "get_weather",
                            "args": {"city": "London"}
                        }
                    }],
                    "role": "model"
                },
                "finishReason": "STOP"
            }]
        });
        let resp: GenerateContentResponse = serde_json::from_value(json).unwrap();
        let fns = resp.function_calls();
        assert_eq!(fns.len(), 1);
        assert_eq!(fns[0].name, "get_weather");
    }

    #[test]
    fn text_and_function_calls_skip_other_part_kinds() {
        let json = serde_json::json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {"functionCall": {"name": "lookup", "args": {"q": "x"}}},
                        {"text": "first"},
                        {"text": "second"}
                    ],
                    "role": "model"
                }
            }]
        });
        let resp: GenerateContentResponse = serde_json::from_value(json).unwrap();
        // `text()` returns only the first text part, skipping non-text parts.
        assert_eq!(resp.text(), Some("first"));
        let fns = resp.function_calls();
        assert_eq!(fns.len(), 1);
        assert_eq!(fns[0].name, "lookup");
        assert_eq!(resp.finish_reason(), None);
    }

    #[test]
    fn parse_with_usage_metadata() {
        let json = serde_json::json!({
            "candidates": [{
                "content": {
                    "parts": [{"text": "Ok"}]
                },
                "finishReason": "STOP"
            }],
            "usageMetadata": {
                "promptTokenCount": 10,
                "candidatesTokenCount": 5,
                "totalTokenCount": 15
            }
        });
        let resp: GenerateContentResponse = serde_json::from_value(json).unwrap();
        let usage = resp.usage_metadata.unwrap();
        assert_eq!(usage.prompt_token_count, Some(10));
        assert_eq!(usage.total_token_count, Some(15));
    }

    #[test]
    fn parse_with_safety_ratings() {
        let json = serde_json::json!({
            "candidates": [{
                "content": {
                    "parts": [{"text": "Hello"}]
                },
                "finishReason": "STOP",
                "safetyRatings": [{
                    "category": "HARM_CATEGORY_HARASSMENT",
                    "probability": "NEGLIGIBLE"
                }, {
                    "category": "HARM_CATEGORY_HATE_SPEECH",
                    "probability": "LOW"
                }]
            }]
        });
        let resp: GenerateContentResponse = serde_json::from_value(json).unwrap();
        assert_eq!(resp.candidates[0].safety_ratings.len(), 2);
    }

    #[test]
    fn parse_unknown_finish_reason() {
        let json = serde_json::json!({
            "candidates": [{
                "content": {"parts": [{"text": "x"}]},
                "finishReason": "SOME_FUTURE_REASON"
            }]
        });
        let resp: GenerateContentResponse = serde_json::from_value(json).unwrap();
        assert_eq!(
            resp.finish_reason(),
            Some(FinishReason::FinishReasonUnspecified)
        );
    }

    #[test]
    fn serialize_omits_absent_optional_fields() {
        let json = serde_json::json!({
            "candidates": [{
                "content": {"parts": [{"text": "Hi"}]}
            }]
        });
        let resp: GenerateContentResponse = serde_json::from_value(json).unwrap();
        let out = serde_json::to_value(&resp).unwrap();
        // `skip_serializing_if = "Option::is_none"` keeps these keys out entirely.
        assert!(out.get("promptFeedback").is_none());
        assert!(out.get("usageMetadata").is_none());
        assert!(out.get("modelVersion").is_none());
        assert!(out["candidates"][0].get("citationMetadata").is_none());
        assert!(out["candidates"][0].get("index").is_none());
    }
}