gemini_genai_rs/generate/
mod.rs1mod config;
19mod response;
20
21pub use config::GenerateContentConfig;
22pub use response::{BlockReason, Candidate, GenerateContentResponse, PromptFeedback};
23
24use crate::client::http::HttpError;
25use crate::client::Client;
26use crate::protocol::types::GeminiModel;
27use crate::transport::auth::ServiceEndpoint;
28
29impl Client {
30 pub async fn generate_content(
32 &self,
33 prompt: impl Into<String>,
34 ) -> Result<GenerateContentResponse, GenerateError> {
35 let config = GenerateContentConfig::from_text(prompt);
36 self.generate_content_with(config, None).await
37 }
38
39 pub async fn generate_content_with(
41 &self,
42 config: GenerateContentConfig,
43 model: Option<&GeminiModel>,
44 ) -> Result<GenerateContentResponse, GenerateError> {
45 let model = model.unwrap_or(self.default_model());
46 let url = self.rest_url_for(ServiceEndpoint::GenerateContent, model);
47 let headers = self
48 .auth_headers()
49 .await
50 .map_err(|e| GenerateError::Auth(e.to_string()))?;
51
52 let body = config.to_request_body();
53 let json = self
54 .http_client()
55 .post_json(&url, headers, &body)
56 .await
57 .map_err(GenerateError::from)?;
58
59 let response: GenerateContentResponse = serde_json::from_value(json)?;
60 Ok(response)
61 }
62}
63
64#[derive(Debug, thiserror::Error)]
66pub enum GenerateError {
67 #[error(transparent)]
69 Http(#[from] HttpError),
70
71 #[error("Failed to parse response: {0}")]
73 Parse(#[from] serde_json::Error),
74
75 #[error("Auth error: {0}")]
77 Auth(String),
78
79 #[error("Content blocked: {reason:?}")]
81 SafetyBlocked {
82 reason: BlockReason,
84 },
85
86 #[error("Prompt blocked: {reason:?}")]
88 PromptBlocked {
89 reason: BlockReason,
91 },
92}
93
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn generate_error_display() {
        let err = GenerateError::SafetyBlocked {
            reason: BlockReason::Safety,
        };
        assert!(err.to_string().contains("blocked"));
    }

    #[test]
    fn prompt_blocked_error_display() {
        let err = GenerateError::PromptBlocked {
            reason: BlockReason::Safety,
        };
        assert!(err.to_string().starts_with("Prompt blocked: "));
    }

    #[test]
    fn auth_error_display() {
        let err = GenerateError::Auth("token expired".to_string());
        assert_eq!(err.to_string(), "Auth error: token expired");
    }

    #[test]
    fn parse_error_converts_from_serde_json() {
        let json_err = serde_json::from_str::<serde_json::Value>("not json")
            .expect_err("`not json` must fail to parse");
        let err = GenerateError::from(json_err);
        assert!(matches!(err, GenerateError::Parse(_)));
        assert!(err.to_string().starts_with("Failed to parse response: "));
    }

    #[test]
    fn generate_content_config_from_text() {
        let config = GenerateContentConfig::from_text("Hello");
        let body = config.to_request_body();
        let contents = body
            .get("contents")
            .expect("request body must contain `contents`");
        assert!(contents.is_array());
        let parts = contents[0]
            .get("parts")
            .expect("first content entry must contain `parts`");
        // assert_eq! reports both sides on failure, unlike assert!(a == b).
        assert_eq!(
            parts[0].get("text").and_then(|text| text.as_str()),
            Some("Hello")
        );
    }

    #[test]
    fn generate_content_config_with_system() {
        let config = GenerateContentConfig::from_text("Hello")
            .system_instruction("You are a helpful assistant");
        let body = config.to_request_body();
        assert!(body.get("systemInstruction").is_some());
    }

    #[test]
    fn parse_generate_response() {
        let json = serde_json::json!({
            "candidates": [{
                "content": {
                    "parts": [{"text": "Hello world!"}],
                    "role": "model"
                },
                "finishReason": "STOP",
                "safetyRatings": [{
                    "category": "HARM_CATEGORY_HARASSMENT",
                    "probability": "NEGLIGIBLE"
                }]
            }],
            "usageMetadata": {
                "promptTokenCount": 5,
                "candidatesTokenCount": 10,
                "totalTokenCount": 15
            }
        });

        let resp: GenerateContentResponse =
            serde_json::from_value(json).expect("fixture must deserialize");
        assert_eq!(resp.candidates.len(), 1);
        assert_eq!(resp.text().expect("response must have text"), "Hello world!");
        assert!(resp.usage_metadata.is_some());
    }
}