gemini_genai_rs/generate/
mod.rs

//! generateContent and streamGenerateContent REST API.
//!
//! This module provides typed request/response types and client methods for the
//! Gemini generateContent REST API. It is feature-gated behind `generate`.
//!
//! # Usage
//!
//! ```ignore
//! use gemini_genai_rs::prelude::*;
//!
//! let client = Client::from_api_key("your-key")
//!     .model(GeminiModel::Custom("gemini-2.5-flash".into()));
//!
//! let response = client.generate_content("What is Rust?").await?;
//! println!("{}", response.text().unwrap_or_default());
//! ```

18mod config;
19mod response;
20
21pub use config::GenerateContentConfig;
22pub use response::{BlockReason, Candidate, GenerateContentResponse, PromptFeedback};
23
24use crate::client::http::HttpError;
25use crate::client::Client;
26use crate::protocol::types::GeminiModel;
27use crate::transport::auth::ServiceEndpoint;
28
29impl Client {
30    /// Generate content from a text prompt using the default model.
31    pub async fn generate_content(
32        &self,
33        prompt: impl Into<String>,
34    ) -> Result<GenerateContentResponse, GenerateError> {
35        let config = GenerateContentConfig::from_text(prompt);
36        self.generate_content_with(config, None).await
37    }
38
39    /// Generate content with full configuration and optional model override.
40    pub async fn generate_content_with(
41        &self,
42        config: GenerateContentConfig,
43        model: Option<&GeminiModel>,
44    ) -> Result<GenerateContentResponse, GenerateError> {
45        let model = model.unwrap_or(self.default_model());
46        let url = self.rest_url_for(ServiceEndpoint::GenerateContent, model);
47        let headers = self
48            .auth_headers()
49            .await
50            .map_err(|e| GenerateError::Auth(e.to_string()))?;
51
52        let body = config.to_request_body();
53        let json = self
54            .http_client()
55            .post_json(&url, headers, &body)
56            .await
57            .map_err(GenerateError::from)?;
58
59        let response: GenerateContentResponse = serde_json::from_value(json)?;
60        Ok(response)
61    }
62}
63
64/// Errors specific to the Generate API.
65#[derive(Debug, thiserror::Error)]
66pub enum GenerateError {
67    /// HTTP transport error.
68    #[error(transparent)]
69    Http(#[from] HttpError),
70
71    /// JSON deserialization error.
72    #[error("Failed to parse response: {0}")]
73    Parse(#[from] serde_json::Error),
74
75    /// Authentication error.
76    #[error("Auth error: {0}")]
77    Auth(String),
78
79    /// Content was blocked by safety filters.
80    #[error("Content blocked: {reason:?}")]
81    SafetyBlocked {
82        /// The reason the content was blocked.
83        reason: BlockReason,
84    },
85
86    /// Prompt was rejected.
87    #[error("Prompt blocked: {reason:?}")]
88    PromptBlocked {
89        /// The reason the prompt was blocked.
90        reason: BlockReason,
91    },
92}
93
#[cfg(test)]
mod tests {
    use super::*;

    /// `SafetyBlocked` renders with the prefix from its `#[error]` attribute.
    #[test]
    fn generate_error_display() {
        let err = GenerateError::SafetyBlocked {
            reason: BlockReason::Safety,
        };
        assert!(err.to_string().starts_with("Content blocked: "));
    }

    /// `PromptBlocked` is distinguishable from `SafetyBlocked` by its message.
    #[test]
    fn prompt_blocked_error_display() {
        let err = GenerateError::PromptBlocked {
            reason: BlockReason::Safety,
        };
        assert!(err.to_string().starts_with("Prompt blocked: "));
    }

    /// `Auth` embeds its message verbatim.
    #[test]
    fn auth_error_display() {
        let err = GenerateError::Auth("token expired".to_string());
        assert_eq!(err.to_string(), "Auth error: token expired");
    }

    /// `serde_json::Error` converts into `GenerateError::Parse` via `#[from]`.
    #[test]
    fn parse_error_converts_from_serde_json() {
        let json_err = serde_json::from_str::<serde_json::Value>("{").unwrap_err();
        let err = GenerateError::from(json_err);
        assert!(matches!(err, GenerateError::Parse(_)));
        assert!(err.to_string().starts_with("Failed to parse response: "));
    }

    #[test]
    fn generate_content_config_from_text() {
        let config = GenerateContentConfig::from_text("Hello");
        let body = config.to_request_body();
        let contents = body
            .get("contents")
            .expect("request body should contain `contents`");
        assert!(contents.is_array());
        let parts = contents[0]
            .get("parts")
            .expect("first content entry should contain `parts`");
        assert_eq!(parts[0].get("text").and_then(|t| t.as_str()), Some("Hello"));
    }

    #[test]
    fn generate_content_config_with_system() {
        let config = GenerateContentConfig::from_text("Hello")
            .system_instruction("You are a helpful assistant");
        let body = config.to_request_body();
        assert!(body.get("systemInstruction").is_some());
    }

    #[test]
    fn parse_generate_response() {
        let json = serde_json::json!({
            "candidates": [{
                "content": {
                    "parts": [{"text": "Hello world!"}],
                    "role": "model"
                },
                "finishReason": "STOP",
                "safetyRatings": [{
                    "category": "HARM_CATEGORY_HARASSMENT",
                    "probability": "NEGLIGIBLE"
                }]
            }],
            "usageMetadata": {
                "promptTokenCount": 5,
                "candidatesTokenCount": 10,
                "totalTokenCount": 15
            }
        });

        let resp: GenerateContentResponse =
            serde_json::from_value(json).expect("fixture should deserialize");
        assert_eq!(resp.candidates.len(), 1);
        assert_eq!(resp.text().unwrap(), "Hello world!");
        assert!(resp.usage_metadata.is_some());
    }
}