mod request;
mod response;

use self::request::{
    ChatCompletionRequest,
    CompletionRequest,
};
pub use self::response::{
    ChatCompletionResponse,
    ChatCompletionResponseChoice,
    CompletionResponse,
    CompletionResponseChoice,
};
use std::sync::Arc;
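
// The request payload types live in the (unshown) `request` module. A
// plausible sketch, inferred from the construction sites below and the
// OpenAI HTTP API; the exact field types and serde attributes here are
// assumptions:
//
// #[derive(Debug, serde::Serialize)]
// pub(crate) struct CompletionRequest<'a> {
//     pub model: &'a str,
//     pub max_tokens: u16,
//     pub prompt: &'a str,
// }
//
// #[derive(Debug, serde::Serialize)]
// pub(crate) struct ChatCompletionRequest {
//     pub model: Box<str>,
//     pub messages: Box<[ChatMessage]>,
//     #[serde(skip_serializing_if = "Option::is_none")]
//     pub max_tokens: Option<u16>,
// }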

/// A chat completion request message
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ChatMessage {
    /// The role of the message author, e.g. "system", "user", or "assistant"
    pub role: Box<str>,

    /// The text content of the message
    pub content: Box<str>,
}

/// The library error type
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// An HTTP error from `reqwest`.
    ///
    /// Non-success status codes are also surfaced here via
    /// [`reqwest::Response::error_for_status`].
    #[error(transparent)]
    Reqwest(#[from] reqwest::Error),
}

/// An OpenAI API client
#[derive(Debug, Clone)]
pub struct Client {
    /// The inner HTTP client
    pub client: reqwest::Client,

    /// The API key
    key: Arc<str>,
}

impl Client {
    /// Make a new client from the given API key.
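    ///
    /// # Example
    ///
    /// A minimal sketch; the crate name `openai` is assumed:
    ///
    /// ```ignore
    /// let client = openai::Client::new("sk-your-api-key");
    /// ```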
    pub fn new(key: &str) -> Self {
        Self {
            client: reqwest::Client::new(),
            key: key.into(),
        }
    }

    /// Perform a text completion against the `/v1/completions` endpoint.
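    ///
    /// # Example
    ///
    /// A minimal sketch; the model name is illustrative and `client` is a
    /// [`Client`]:
    ///
    /// ```ignore
    /// let response = client
    ///     .completion("gpt-3.5-turbo-instruct", 16, "Say hello.")
    ///     .await?;
    /// ```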
    pub async fn completion(
        &self,
        model: &str,
        max_tokens: u16,
        prompt: &str,
    ) -> Result<CompletionResponse, Error> {
        Ok(self
            .client
            .post("https://api.openai.com/v1/completions")
            .header(
                reqwest::header::AUTHORIZATION,
                format!("Bearer {}", self.key),
            )
            .json(&CompletionRequest {
                model,
                max_tokens,
                prompt,
            })
            .send()
            .await?
            .error_for_status()?
            .json()
            .await?)
    }

    /// Perform a chat completion against the `/v1/chat/completions` endpoint.
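    ///
    /// # Example
    ///
    /// A minimal sketch mirroring the test below; `client` is a [`Client`]:
    ///
    /// ```ignore
    /// let response = client
    ///     .chat_completion(
    ///         "gpt-3.5-turbo",
    ///         &[ChatMessage {
    ///             role: "user".into(),
    ///             content: "Hello! How are you today?".into(),
    ///         }],
    ///         None,
    ///     )
    ///     .await?;
    /// ```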
    pub async fn chat_completion(
        &self,
        model: &str,
        messages: &[ChatMessage],
        max_tokens: Option<u16>,
    ) -> Result<ChatCompletionResponse, Error> {
        Ok(self
            .client
            .post("https://api.openai.com/v1/chat/completions")
            .header(
                reqwest::header::AUTHORIZATION,
                format!("Bearer {}", self.key),
            )
            .json(&ChatCompletionRequest {
                model: model.into(),
                messages: messages.into(),
                max_tokens,
            })
            .send()
            .await?
            .error_for_status()?
            .json()
            .await?)
    }
}

#[cfg(test)]
mod test {
    use super::*;
    use once_cell::sync::Lazy;

    static KEY: Lazy<String> = Lazy::new(|| {
        // Trim trailing whitespace so a newline in `key.txt` doesn't break
        // the `Authorization` header.
        std::fs::read_to_string("key.txt")
            .expect("failed to read api key")
            .trim()
            .to_string()
    });

    #[ignore]
    #[tokio::test]
    async fn it_works() {
        let client = Client::new(&KEY);
        let response = client
            .chat_completion(
                "gpt-3.5-turbo",
                &[ChatMessage {
                    role: "user".into(),
                    content: "Hello! How are you today?".into(),
                }],
                None,
            )
            .await
            .expect("failed to get response");
        dbg!(&response);
    }

    #[test]
    fn parse_completion_response() {
        let text = include_str!("../test_data/completion_response.json");
        let response: CompletionResponse = serde_json::from_str(text).expect("failed to parse");
        dbg!(&response);
    }

    #[test]
    fn parse_chat_completion_response() {
        let text = include_str!("../test_data/chat_completion_response.json");
        let response: ChatCompletionResponse = serde_json::from_str(text).expect("failed to parse");
        dbg!(&response);
    }
}