1use crate::error::Result;
7use crate::llm::chat_session::ChatSession;
8use crate::llm::tools::LlmTool;
9use crate::llm::LlmBroker;
10use tracing::{info, warn};
11
12pub struct IterativeProblemSolver {
46 chat: ChatSession,
47 max_iterations: usize,
48}
49
50impl IterativeProblemSolver {
51 pub fn new(broker: LlmBroker) -> Self {
66 Self::builder(broker).build()
67 }
68
69 pub fn builder(broker: LlmBroker) -> IterativeProblemSolverBuilder {
87 IterativeProblemSolverBuilder::new(broker)
88 }
89
90 pub async fn solve(&mut self, problem: &str) -> Result<String> {
115 let mut iterations_remaining = self.max_iterations;
116
117 loop {
118 let result = self.step(problem).await?;
119
120 if result.to_lowercase().contains("fail") {
122 info!(user_request = problem, result = result.as_str(), "Task failed");
123 break;
124 }
125
126 if result.to_lowercase().contains("done") {
128 info!(user_request = problem, result = result.as_str(), "Task completed");
129 break;
130 }
131
132 iterations_remaining -= 1;
133 if iterations_remaining == 0 {
134 warn!(
135 max_iterations = self.max_iterations,
136 user_request = problem,
137 result = result.as_str(),
138 "Max iterations reached"
139 );
140 break;
141 }
142 }
143
144 let summary = self
146 .chat
147 .send(
148 "Summarize the final result, and only the final result, \
149 without commenting on the process by which you achieved it.",
150 )
151 .await?;
152
153 Ok(summary)
154 }
155
156 async fn step(&mut self, problem: &str) -> Result<String> {
169 let prompt = format!(
170 "Given the user request:\n\
171 {}\n\
172 \n\
173 Use the tools at your disposal to act on their request. \
174 You may wish to create a step-by-step plan for more complicated requests.\n\
175 \n\
176 If you cannot provide an answer, say only \"FAIL\".\n\
177 If you have the answer, say only \"DONE\".",
178 problem
179 );
180
181 self.chat.send(&prompt).await
182 }
183}
184
185pub struct IterativeProblemSolverBuilder {
187 broker: LlmBroker,
188 tools: Option<Vec<Box<dyn LlmTool>>>,
189 max_iterations: usize,
190 system_prompt: Option<String>,
191}
192
193impl IterativeProblemSolverBuilder {
194 fn new(broker: LlmBroker) -> Self {
196 Self {
197 broker,
198 tools: None,
199 max_iterations: 3,
200 system_prompt: None,
201 }
202 }
203
204 pub fn tools(mut self, tools: Vec<Box<dyn LlmTool>>) -> Self {
206 self.tools = Some(tools);
207 self
208 }
209
210 pub fn max_iterations(mut self, max_iterations: usize) -> Self {
212 self.max_iterations = max_iterations;
213 self
214 }
215
216 pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
218 self.system_prompt = Some(prompt.into());
219 self
220 }
221
222 pub fn build(self) -> IterativeProblemSolver {
224 let system_prompt = self.system_prompt.unwrap_or_else(|| {
225 "You are a problem-solving assistant that can solve complex problems step by step. \
226 You analyze problems, break them down into smaller parts, and solve them systematically. \
227 If you cannot solve a problem completely in one step, you make progress and identify what to do next."
228 .to_string()
229 });
230
231 let mut chat_builder = ChatSession::builder(self.broker).system_prompt(system_prompt);
232
233 if let Some(tools) = self.tools {
234 chat_builder = chat_builder.tools(tools);
235 }
236
237 IterativeProblemSolver {
238 chat: chat_builder.build(),
239 max_iterations: self.max_iterations,
240 }
241 }
242}
243
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::gateway::{CompletionConfig, LlmGateway, StreamChunk};
    use crate::llm::models::{LlmGatewayResponse, LlmMessage};
    use crate::llm::tools::{FunctionDescriptor, ToolDescriptor};
    use futures::stream::{self, Stream};
    use serde_json::{json, Value};
    use std::collections::HashMap;
    use std::pin::Pin;
    use std::sync::{Arc, Mutex};

    /// Gateway double that replays scripted completions in call order and answers
    /// `"default response"` once the script runs out.
    struct MockGateway {
        responses: Vec<String>,
        call_count: Arc<Mutex<usize>>,
    }

    impl MockGateway {
        fn new(responses: Vec<String>) -> Self {
            let call_count = Arc::new(Mutex::new(0));
            Self {
                responses,
                call_count,
            }
        }
    }

    #[async_trait::async_trait]
    impl LlmGateway for MockGateway {
        async fn complete(
            &self,
            _model: &str,
            _messages: &[LlmMessage],
            _tools: Option<&[Box<dyn LlmTool>]>,
            _config: &CompletionConfig,
        ) -> Result<LlmGatewayResponse> {
            // Claim the next script slot; the lock guard is released at the end of this block.
            let index = {
                let mut calls = self.call_count.lock().unwrap();
                let current = *calls;
                *calls += 1;
                current
            };

            let content = self
                .responses
                .get(index)
                .cloned()
                .unwrap_or_else(|| "default response".to_string());

            Ok(LlmGatewayResponse {
                content: Some(content),
                object: None,
                tool_calls: vec![],
                thinking: None,
            })
        }

        async fn complete_json(
            &self,
            _model: &str,
            _messages: &[LlmMessage],
            _schema: Value,
            _config: &CompletionConfig,
        ) -> Result<Value> {
            Ok(json!({}))
        }

        async fn get_available_models(&self) -> Result<Vec<String>> {
            Ok(vec!["test-model".to_string()])
        }

        async fn calculate_embeddings(
            &self,
            _text: &str,
            _model: Option<&str>,
        ) -> Result<Vec<f32>> {
            Ok(vec![0.1, 0.2, 0.3])
        }

        fn complete_stream<'a>(
            &'a self,
            _model: &'a str,
            _messages: &'a [LlmMessage],
            _tools: Option<&'a [Box<dyn LlmTool>]>,
            _config: &'a CompletionConfig,
        ) -> Pin<Box<dyn Stream<Item = Result<StreamChunk>> + Send + 'a>> {
            Box::pin(stream::iter(vec![Ok(StreamChunk::Content("test".to_string()))]))
        }
    }

    /// Tool double that always reports success.
    #[derive(Clone)]
    struct MockTool {
        name: String,
    }

    impl LlmTool for MockTool {
        fn run(&self, _args: &HashMap<String, Value>) -> Result<Value> {
            Ok(json!({"result": "success"}))
        }

        fn descriptor(&self) -> ToolDescriptor {
            let function = FunctionDescriptor {
                name: self.name.clone(),
                description: "A mock tool".to_string(),
                parameters: json!({}),
            };
            ToolDescriptor {
                r#type: "function".to_string(),
                function,
            }
        }

        fn clone_box(&self) -> Box<dyn LlmTool> {
            Box::new(self.clone())
        }
    }

    /// Builds a broker whose gateway replays `responses` in order.
    fn scripted_broker(responses: &[&str]) -> LlmBroker {
        let script: Vec<String> = responses.iter().map(|&r| String::from(r)).collect();
        let gateway = Arc::new(MockGateway::new(script));
        LlmBroker::new("test-model", gateway, None)
    }

    #[tokio::test]
    async fn test_builder_default_settings() {
        let solver = IterativeProblemSolver::new(scripted_broker(&[]));

        assert_eq!(solver.max_iterations, 3);
    }

    #[tokio::test]
    async fn test_builder_custom_max_iterations() {
        let solver = IterativeProblemSolver::builder(scripted_broker(&[]))
            .max_iterations(5)
            .build();

        assert_eq!(solver.max_iterations, 5);
    }

    #[tokio::test]
    async fn test_builder_with_tools() {
        let tools: Vec<Box<dyn LlmTool>> = vec![Box::new(MockTool {
            name: "test_tool".to_string(),
        })];

        let _solver = IterativeProblemSolver::builder(scripted_broker(&[]))
            .tools(tools)
            .build();
    }

    #[tokio::test]
    async fn test_solve_completes_with_done() {
        let broker = scripted_broker(&["Working on it...", "DONE", "The answer is 42"]);
        let mut solver = IterativeProblemSolver::new(broker);

        let answer = solver.solve("Test problem").await.unwrap();

        assert_eq!(answer, "The answer is 42");
    }

    #[tokio::test]
    async fn test_solve_fails_with_fail() {
        let broker = scripted_broker(&["Trying...", "FAIL", "Could not solve the problem"]);
        let mut solver = IterativeProblemSolver::new(broker);

        let answer = solver.solve("Impossible problem").await.unwrap();

        assert_eq!(answer, "Could not solve the problem");
    }

    #[tokio::test]
    async fn test_solve_stops_at_max_iterations() {
        let broker = scripted_broker(&["Step 1", "Step 2", "Step 3", "Final summary"]);
        let mut solver = IterativeProblemSolver::builder(broker).max_iterations(3).build();

        let answer = solver.solve("Long problem").await.unwrap();

        assert_eq!(answer, "Final summary");
    }

    #[tokio::test]
    async fn test_solve_case_insensitive_done() {
        let broker = scripted_broker(&["done", "The task is complete"]);
        let mut solver = IterativeProblemSolver::new(broker);

        let answer = solver.solve("Test problem").await.unwrap();

        assert_eq!(answer, "The task is complete");
    }

    #[tokio::test]
    async fn test_solve_case_insensitive_fail() {
        let broker = scripted_broker(&["fail", "Unable to complete task"]);
        let mut solver = IterativeProblemSolver::new(broker);

        let answer = solver.solve("Test problem").await.unwrap();

        assert_eq!(answer, "Unable to complete task");
    }

    #[tokio::test]
    async fn test_custom_system_prompt() {
        let broker = scripted_broker(&["DONE", "Custom response"]);
        let mut solver = IterativeProblemSolver::builder(broker)
            .system_prompt("Custom system prompt for testing")
            .build();

        let answer = solver.solve("Test problem").await.unwrap();

        assert_eq!(answer, "Custom response");
    }

    #[tokio::test]
    async fn test_step_method() {
        let mut solver = IterativeProblemSolver::new(scripted_broker(&["Step response"]));

        let reply = solver.step("Test problem").await.unwrap();

        assert_eq!(reply, "Step response");
    }

    #[tokio::test]
    async fn test_multiple_iterations_before_done() {
        let broker = scripted_broker(&[
            "Working...",
            "Still working...",
            "Almost there...",
            "DONE",
            "Completed successfully",
        ]);
        let mut solver = IterativeProblemSolver::builder(broker).max_iterations(5).build();

        let answer = solver.solve("Complex problem").await.unwrap();

        assert_eq!(answer, "Completed successfully");
    }

    #[tokio::test]
    async fn test_done_substring_detection() {
        let broker = scripted_broker(&["I'm DONE with this task", "Task completed"]);
        let mut solver = IterativeProblemSolver::new(broker);

        let answer = solver.solve("Test problem").await.unwrap();

        assert_eq!(answer, "Task completed");
    }

    #[tokio::test]
    async fn test_fail_substring_detection() {
        let broker = scripted_broker(&["This will FAIL", "Failed to complete"]);
        let mut solver = IterativeProblemSolver::new(broker);

        let answer = solver.solve("Test problem").await.unwrap();

        assert_eq!(answer, "Failed to complete");
    }
}