binaryninja/download/
instance.rs

1use crate::download::DownloadProvider;
2use crate::headless::is_shutdown_requested;
3use crate::rc::{Ref, RefCountable};
4use crate::string::{strings_to_string_list, BnString, IntoCStr};
5use binaryninjacore_sys::*;
6use std::cell::RefCell;
7use std::collections::HashMap;
8use std::ffi::{c_void, CStr};
9use std::mem::{ManuallyDrop, MaybeUninit};
10use std::os::raw::c_char;
11use std::ptr::null_mut;
12use std::rc::Rc;
13use std::slice;
14
15pub trait CustomDownloadInstance: Sized {
16    fn new_with_provider(provider: DownloadProvider) -> Result<Ref<DownloadInstance>, ()> {
17        let instance_uninit = MaybeUninit::uninit();
18        // SAFETY: Download instance is freed by cb_destroy_instance
19        let leaked_instance = Box::leak(Box::new(instance_uninit));
20        let mut callbacks = BNDownloadInstanceCallbacks {
21            context: leaked_instance as *mut _ as *mut c_void,
22            destroyInstance: Some(cb_destroy_instance::<Self>),
23            performRequest: Some(cb_perform_request::<Self>),
24            performCustomRequest: Some(cb_perform_custom_request::<Self>),
25            freeResponse: Some(cb_free_response),
26        };
27        let instance_ptr = unsafe { BNInitDownloadInstance(provider.handle, &mut callbacks) };
28        // TODO: If possible pass a sensible error back...
29        let instance_ref = unsafe { DownloadInstance::ref_from_raw(instance_ptr) };
30        // We now have the core instance, so we can actually construct the object.
31        leaked_instance.write(Self::from_core(instance_ref.clone()));
32        Ok(instance_ref)
33    }
34
35    /// Construct the object now that the core object has been created.
36    fn from_core(core: Ref<DownloadInstance>) -> Self;
37
38    /// Get the core object, typically the handle is stored directly on the object.
39    fn handle(&self) -> Ref<DownloadInstance>;
40
41    /// Send an HTTP request on behalf of the caller.
42    ///
43    /// The caller will expect you to inform them of progress via the following:
44    ///
45    /// - [DownloadInstance::read_callback]
46    /// - [DownloadInstance::write_callback]
47    /// - [DownloadInstance::progress_callback]
48    fn perform_request(&self, url: &str) -> Result<(), String> {
49        self.perform_custom_request("GET", url, vec![])?;
50        Ok(())
51    }
52
53    /// Send an HTTP request on behalf of the caller.
54    ///
55    /// The caller will expect you to inform them of progress via the following:
56    ///
57    /// - [DownloadInstance::read_callback]
58    /// - [DownloadInstance::write_callback]
59    /// - [DownloadInstance::progress_callback]
60    fn perform_custom_request<I>(
61        &self,
62        method: &str,
63        url: &str,
64        headers: I,
65    ) -> Result<DownloadResponse, String>
66    where
67        I: IntoIterator<Item = (String, String)>;
68}
69
// TODO: Change this to a trait?
/// Closures invoked by the core during [`DownloadInstance::perform_request`].
///
/// Any callback left as `None` falls back to a default (`write` reports 0 bytes accepted,
/// `progress` continues). `Default` yields a value with every callback unset.
#[derive(Default)]
pub struct DownloadInstanceOutputCallbacks {
    /// Receives a chunk of downloaded data and returns how many bytes were consumed.
    pub write: Option<Box<dyn FnMut(&[u8]) -> usize>>,
    /// Receives `(progress, total)`; returning `false` asks for the transfer to stop.
    pub progress: Option<Box<dyn FnMut(usize, usize) -> bool>>,
}
75
// TODO: Change this to a trait?
/// Closures invoked by the core during [`DownloadInstance::perform_custom_request`].
///
/// Any callback left as `None` falls back to a default (`read` reports 0 bytes, `write`
/// reports 0 bytes accepted, `progress` continues). `Default` yields a value with every
/// callback unset.
#[derive(Default)]
pub struct DownloadInstanceInputOutputCallbacks {
    /// Fills the buffer with request-body data and returns the byte count, or `None` on failure.
    pub read: Option<Box<dyn FnMut(&mut [u8]) -> Option<usize>>>,
    /// Receives a chunk of response data and returns how many bytes were consumed.
    pub write: Option<Box<dyn FnMut(&[u8]) -> usize>>,
    /// Receives `(progress, total)`; returning `false` asks for the transfer to stop.
    pub progress: Option<Box<dyn FnMut(usize, usize) -> bool>>,
}
82
/// Status line and headers of a completed request; the body is delivered separately through
/// the instance's write callback.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DownloadResponse {
    /// HTTP status code.
    pub status_code: u16,
    /// Response headers, keyed as reported by the download provider.
    pub headers: HashMap<String, String>,
}
87
/// A completed request together with its fully buffered response body.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OwnedDownloadResponse {
    /// The raw response body.
    pub data: Vec<u8>,
    /// HTTP status code.
    pub status_code: u16,
    /// Response headers, keyed as reported by the download provider.
    pub headers: HashMap<String, String>,
}
93
94impl OwnedDownloadResponse {
95    /// Attempt to parse the response body as UTF-8.
96    pub fn text(&self) -> Result<String, std::string::FromUtf8Error> {
97        String::from_utf8(self.data.clone())
98    }
99
100    /// Attempt to deserialize the response body as JSON into T.
101    pub fn json<T: serde::de::DeserializeOwned>(&self) -> Result<T, serde_json::Error> {
102        serde_json::from_slice(&self.data)
103    }
104
105    /// Convenience to get a header value by case-insensitive name.
106    pub fn header(&self, name: &str) -> Option<&str> {
107        self.headers
108            .get(&name.to_ascii_lowercase())
109            .map(|s| s.as_str())
110    }
111
112    /// True if the status code is in the 2xx range.
113    pub fn is_success(&self) -> bool {
114        (200..300).contains(&self.status_code)
115    }
116}
117
118/// A reader for a [`DownloadInstance`].
119pub struct DownloadInstanceReader {
120    pub instance: Ref<DownloadInstance>,
121}
122
123impl DownloadInstanceReader {
124    pub fn new(instance: Ref<DownloadInstance>) -> Self {
125        Self { instance }
126    }
127}
128
129impl std::io::Read for DownloadInstanceReader {
130    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
131        let length = self.instance.read_callback(buf);
132        if length < 0 {
133            Err(std::io::Error::new(
134                std::io::ErrorKind::Interrupted,
135                "Connection interrupted",
136            ))
137        } else {
138            Ok(length as usize)
139        }
140    }
141}
142
143/// A writer for a [`DownloadInstance`].
144pub struct DownloadInstanceWriter {
145    pub instance: Ref<DownloadInstance>,
146    /// The expected length of the download.
147    pub total_length: Option<u64>,
148    /// The current progress of the download.
149    pub progress: u64,
150}
151
152impl DownloadInstanceWriter {
153    pub fn new(instance: Ref<DownloadInstance>, total_length: Option<u64>) -> Self {
154        Self {
155            instance,
156            total_length,
157            progress: 0,
158        }
159    }
160}
161
162impl std::io::Write for DownloadInstanceWriter {
163    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
164        let length = self.instance.write_callback(buf);
165        if is_shutdown_requested() || length == 0 {
166            Err(std::io::Error::from(std::io::ErrorKind::ConnectionAborted))
167        } else {
168            self.progress += buf.len() as u64;
169            if self
170                .instance
171                .progress_callback(self.progress, self.total_length.unwrap_or(u64::MAX))
172            {
173                Ok(length as usize)
174            } else {
175                Err(std::io::Error::from(std::io::ErrorKind::ConnectionAborted))
176            }
177        }
178    }
179
180    fn flush(&mut self) -> std::io::Result<()> {
181        Ok(())
182    }
183}
184
185pub struct DownloadInstance {
186    pub(crate) handle: *mut BNDownloadInstance,
187}
188
189impl DownloadInstance {
190    pub(crate) unsafe fn from_raw(handle: *mut BNDownloadInstance) -> Self {
191        debug_assert!(!handle.is_null());
192        Self { handle }
193    }
194
195    pub(crate) unsafe fn ref_from_raw(handle: *mut BNDownloadInstance) -> Ref<Self> {
196        Ref::new(Self::from_raw(handle))
197    }
198
199    fn get_error(&self) -> String {
200        let err: *mut c_char = unsafe { BNGetErrorForDownloadInstance(self.handle) };
201        unsafe { BnString::into_string(err) }
202    }
203
204    /// Sets the error for the instance, any later call to [`DownloadInstance::get_error`] will
205    /// return the string passed here.
206    fn set_error(&self, err: &str) {
207        let err = err.to_cstr();
208        unsafe { BNSetErrorForDownloadInstance(self.handle, err.as_ptr()) };
209    }
210
211    /// Use inside [`CustomDownloadInstance::perform_custom_request`] to pass data back to the caller.
212    pub fn write_callback(&self, data: &[u8]) -> u64 {
213        unsafe {
214            BNWriteDataForDownloadInstance(self.handle, data.as_ptr() as *mut _, data.len() as u64)
215        }
216    }
217
218    /// Use inside [`CustomDownloadInstance::perform_custom_request`] to read data from the caller.
219    pub fn read_callback(&self, data: &mut [u8]) -> i64 {
220        unsafe {
221            BNReadDataForDownloadInstance(
222                self.handle,
223                data.as_mut_ptr() as *mut _,
224                data.len() as u64,
225            )
226        }
227    }
228
229    /// Use inside [`CustomDownloadInstance::perform_custom_request`] to inform the caller of the request progress.
230    pub fn progress_callback(&self, progress: u64, total: u64) -> bool {
231        unsafe { BNNotifyProgressForDownloadInstance(self.handle, progress, total) }
232    }
233
234    pub fn get<I>(&mut self, url: &str, headers: I) -> Result<OwnedDownloadResponse, String>
235    where
236        I: IntoIterator<Item = (String, String)>,
237    {
238        let buf: Rc<RefCell<Vec<u8>>> = Rc::new(RefCell::new(Vec::new()));
239        let buf_closure = Rc::clone(&buf);
240        let callbacks = DownloadInstanceInputOutputCallbacks {
241            read: None,
242            write: Some(Box::new(move |data: &[u8]| {
243                buf_closure.borrow_mut().extend_from_slice(data);
244                data.len()
245            })),
246            progress: Some(Box::new(|_, _| true)),
247        };
248
249        let resp = self.perform_custom_request("GET", url, headers, &callbacks)?;
250        drop(callbacks);
251        let out = Rc::try_unwrap(buf).map_err(|_| "Buffer held with strong reference")?;
252        Ok(OwnedDownloadResponse {
253            data: out.into_inner(),
254            status_code: resp.status_code,
255            headers: resp.headers,
256        })
257    }
258
259    pub fn post<I>(
260        &mut self,
261        url: &str,
262        headers: I,
263        body: Vec<u8>,
264    ) -> Result<OwnedDownloadResponse, String>
265    where
266        I: IntoIterator<Item = (String, String)>,
267    {
268        let resp_buf: Rc<RefCell<Vec<u8>>> = Rc::new(RefCell::new(Vec::new()));
269        let resp_buf_closure = Rc::clone(&resp_buf);
270        // Request body position tracker captured by the read closure
271        let mut pos = 0usize;
272        let total = body.len();
273        let callbacks = DownloadInstanceInputOutputCallbacks {
274            // Supply request body to the core
275            read: Some(Box::new(move |dst: &mut [u8]| -> Option<usize> {
276                if pos >= total {
277                    return Some(0);
278                }
279                let remaining = total - pos;
280                let to_copy = remaining.min(dst.len());
281                dst[..to_copy].copy_from_slice(&body[pos..pos + to_copy]);
282                pos += to_copy;
283                Some(to_copy)
284            })),
285            // Collect response body
286            write: Some(Box::new(move |data: &[u8]| {
287                resp_buf_closure.borrow_mut().extend_from_slice(data);
288                data.len()
289            })),
290            progress: Some(Box::new(|_, _| true)),
291        };
292
293        let resp = self.perform_custom_request("POST", url, headers, &callbacks)?;
294        drop(callbacks);
295        if !(200..300).contains(&(resp.status_code as i32)) {
296            return Err(format!("HTTP error: {}", resp.status_code));
297        }
298
299        let out = Rc::try_unwrap(resp_buf).map_err(|_| "Buffer held with strong reference")?;
300        Ok(OwnedDownloadResponse {
301            data: out.into_inner(),
302            status_code: resp.status_code,
303            headers: resp.headers,
304        })
305    }
306
307    pub fn post_json<I, T>(
308        &mut self,
309        url: &str,
310        headers: I,
311        body: &T,
312    ) -> Result<OwnedDownloadResponse, String>
313    where
314        I: IntoIterator<Item = (String, String)>,
315        T: serde::Serialize,
316    {
317        let mut headers: Vec<(String, String)> = headers.into_iter().collect();
318        if !headers
319            .iter()
320            .any(|(k, _)| k.eq_ignore_ascii_case("content-type"))
321        {
322            headers.push(("content-type".into(), "application/json".into()));
323        }
324        let bytes = serde_json::to_vec(body).map_err(|e| e.to_string())?;
325        self.post(url, headers, bytes)
326    }
327
328    pub fn perform_request(
329        &mut self,
330        url: &str,
331        callbacks: &DownloadInstanceOutputCallbacks,
332    ) -> Result<(), String> {
333        let mut cbs = BNDownloadInstanceOutputCallbacks {
334            writeCallback: Some(cb_write_output),
335            writeContext: callbacks as *const _ as *mut c_void,
336            progressCallback: Some(cb_progress_output),
337            progressContext: callbacks as *const _ as *mut c_void,
338        };
339
340        let url_raw = url.to_cstr();
341        let result = unsafe {
342            BNPerformDownloadRequest(
343                self.handle,
344                url_raw.as_ptr(),
345                &mut cbs as *mut BNDownloadInstanceOutputCallbacks,
346            )
347        };
348
349        if result < 0 {
350            Err(self.get_error())
351        } else {
352            Ok(())
353        }
354    }
355
356    pub fn perform_custom_request<I>(
357        &mut self,
358        method: &str,
359        url: &str,
360        headers: I,
361        callbacks: &DownloadInstanceInputOutputCallbacks,
362    ) -> Result<DownloadResponse, String>
363    where
364        I: IntoIterator<Item = (String, String)>,
365    {
366        let mut header_keys = vec![];
367        let mut header_values = vec![];
368        for (key, value) in headers {
369            header_keys.push(key.to_cstr());
370            header_values.push(value.to_cstr());
371        }
372
373        let mut header_key_ptrs = vec![];
374        let mut header_value_ptrs = vec![];
375
376        for (key, value) in header_keys.iter().zip(header_values.iter()) {
377            header_key_ptrs.push(key.as_ptr());
378            header_value_ptrs.push(value.as_ptr());
379        }
380
381        let mut cbs = BNDownloadInstanceInputOutputCallbacks {
382            readCallback: Some(cb_read_input),
383            readContext: callbacks as *const _ as *mut c_void,
384            writeCallback: Some(cb_write_input),
385            writeContext: callbacks as *const _ as *mut c_void,
386            progressCallback: Some(cb_progress_input),
387            progressContext: callbacks as *const _ as *mut c_void,
388        };
389
390        let mut response: *mut BNDownloadInstanceResponse = null_mut();
391
392        let method_raw = method.to_cstr();
393        let url_raw = url.to_cstr();
394        let result = unsafe {
395            BNPerformCustomRequest(
396                self.handle,
397                method_raw.as_ptr(),
398                url_raw.as_ptr(),
399                header_key_ptrs.len() as u64,
400                header_key_ptrs.as_ptr(),
401                header_value_ptrs.as_ptr(),
402                &mut response as *mut *mut BNDownloadInstanceResponse,
403                &mut cbs as *mut BNDownloadInstanceInputOutputCallbacks,
404            )
405        };
406
407        if result < 0 {
408            unsafe { BNFreeDownloadInstanceResponse(response) };
409            return Err(self.get_error());
410        }
411
412        let mut response_headers = HashMap::new();
413        unsafe {
414            let response_header_keys: &[*mut c_char] =
415                slice::from_raw_parts((*response).headerKeys, (*response).headerCount as usize);
416            let response_header_values: &[*mut c_char] =
417                slice::from_raw_parts((*response).headerValues, (*response).headerCount as usize);
418
419            for (key, value) in response_header_keys
420                .iter()
421                .zip(response_header_values.iter())
422            {
423                response_headers.insert(
424                    CStr::from_ptr(*key).to_str().unwrap().to_owned(),
425                    CStr::from_ptr(*value).to_str().unwrap().to_owned(),
426                );
427            }
428        }
429
430        let r = DownloadResponse {
431            status_code: unsafe { (*response).statusCode },
432            headers: response_headers,
433        };
434
435        unsafe { BNFreeDownloadInstanceResponse(response) };
436
437        Ok(r)
438    }
439}
440
// TODO: Verify the object is thread safe in the core (hint: it's not).
// NOTE(review): these impls let `DownloadInstance` (and `Ref<DownloadInstance>`) cross threads.
// Nothing in this file synchronizes access to `handle`, or to the callback context pointers
// handed to the core by `perform_request` / `perform_custom_request`. Confirm the core's
// thread-safety guarantees before relying on these impls.
unsafe impl Send for DownloadInstance {}
unsafe impl Sync for DownloadInstance {}
444
445impl ToOwned for DownloadInstance {
446    type Owned = Ref<Self>;
447
448    fn to_owned(&self) -> Self::Owned {
449        unsafe { RefCountable::inc_ref(self) }
450    }
451}
452
453unsafe impl RefCountable for DownloadInstance {
454    unsafe fn inc_ref(handle: &Self) -> Ref<Self> {
455        Ref::new(Self {
456            handle: BNNewDownloadInstanceReference(handle.handle),
457        })
458    }
459
460    unsafe fn dec_ref(handle: &Self) {
461        BNFreeDownloadInstance(handle.handle);
462    }
463}
464
465unsafe extern "C" fn cb_read_input(data: *mut u8, len: u64, ctxt: *mut c_void) -> i64 {
466    let callbacks = ctxt as *mut DownloadInstanceInputOutputCallbacks;
467    if let Some(func) = &mut (*callbacks).read {
468        let slice = slice::from_raw_parts_mut(data, len as usize);
469        let result = (func)(slice);
470        if let Some(count) = result {
471            count as i64
472        } else {
473            -1
474        }
475    } else {
476        0
477    }
478}
479
480unsafe extern "C" fn cb_write_input(data: *mut u8, len: u64, ctxt: *mut c_void) -> u64 {
481    let callbacks = ctxt as *mut DownloadInstanceInputOutputCallbacks;
482    if let Some(func) = &mut (*callbacks).write {
483        let slice = slice::from_raw_parts(data, len as usize);
484        let result = (func)(slice);
485        result as u64
486    } else {
487        0
488    }
489}
490
491unsafe extern "C" fn cb_progress_input(ctxt: *mut c_void, progress: usize, total: usize) -> bool {
492    let callbacks = ctxt as *mut DownloadInstanceInputOutputCallbacks;
493    if let Some(func) = &mut (*callbacks).progress {
494        (func)(progress, total)
495    } else {
496        true
497    }
498}
499
500unsafe extern "C" fn cb_write_output(data: *mut u8, len: u64, ctxt: *mut c_void) -> u64 {
501    let callbacks = ctxt as *mut DownloadInstanceOutputCallbacks;
502    if let Some(func) = &mut (*callbacks).write {
503        let slice = slice::from_raw_parts(data, len as usize);
504        let result = (func)(slice);
505        result as u64
506    } else {
507        0u64
508    }
509}
510
511unsafe extern "C" fn cb_progress_output(ctxt: *mut c_void, progress: usize, total: usize) -> bool {
512    let callbacks = ctxt as *mut DownloadInstanceOutputCallbacks;
513    if let Some(func) = &mut (*callbacks).progress {
514        (func)(progress, total)
515    } else {
516        true
517    }
518}
519
520pub unsafe extern "C" fn cb_destroy_instance<C: CustomDownloadInstance>(ctxt: *mut c_void) {
521    let _ = Box::from_raw(ctxt as *mut C);
522}
523
524pub unsafe extern "C" fn cb_perform_request<C: CustomDownloadInstance>(
525    ctxt: *mut c_void,
526    url: *const c_char,
527) -> i32 {
528    let c = ManuallyDrop::new(Box::from_raw(ctxt as *mut C));
529
530    let url = match CStr::from_ptr(url).to_str() {
531        Ok(url) => url,
532        Err(e) => {
533            c.handle().set_error(&format!("Invalid URL: {}", e));
534            return -1;
535        }
536    };
537
538    match c.perform_request(url) {
539        Ok(()) => 0,
540        Err(e) => {
541            c.handle().set_error(&e);
542            -1
543        }
544    }
545}
546
547pub unsafe extern "C" fn cb_perform_custom_request<C: CustomDownloadInstance>(
548    ctxt: *mut c_void,
549    method: *const c_char,
550    url: *const c_char,
551    header_count: u64,
552    header_keys: *const *const c_char,
553    header_values: *const *const c_char,
554    response: *mut *mut BNDownloadInstanceResponse,
555) -> i32 {
556    let c = ManuallyDrop::new(Box::from_raw(ctxt as *mut C));
557
558    let method = match CStr::from_ptr(method).to_str() {
559        Ok(method) => method,
560        Err(e) => {
561            c.handle().set_error(&format!("Invalid Method: {}", e));
562            return -1;
563        }
564    };
565
566    let url = match CStr::from_ptr(url).to_str() {
567        Ok(url) => url,
568        Err(e) => {
569            c.handle().set_error(&format!("Invalid URL: {}", e));
570            return -1;
571        }
572    };
573
574    // SAFETY BnString and *mut c_char are transparent
575    let header_count = usize::try_from(header_count).unwrap();
576    let header_keys = slice::from_raw_parts(header_keys as *const BnString, header_count);
577    let header_values = slice::from_raw_parts(header_values as *const BnString, header_count);
578    let header_keys_str = header_keys.iter().map(|s| s.to_string_lossy().to_string());
579    let header_values_str = header_values
580        .iter()
581        .map(|s| s.to_string_lossy().to_string());
582    let headers = header_keys_str.zip(header_values_str);
583
584    match c.perform_custom_request(method, url, headers) {
585        Ok(res) => {
586            let res_header_keys_ptr = strings_to_string_list(res.headers.keys());
587            let res_header_values_ptr = strings_to_string_list(res.headers.values());
588            let raw_response = BNDownloadInstanceResponse {
589                statusCode: res.status_code,
590                headerCount: res.headers.len() as u64,
591                headerKeys: res_header_keys_ptr,
592                headerValues: res_header_values_ptr,
593            };
594            // Leak the response and free it with cb_free_response
595            unsafe { *response = Box::leak(Box::new(raw_response)) };
596            0
597        }
598        Err(e) => {
599            c.handle().set_error(&e);
600            -1
601        }
602    }
603}
604
605unsafe extern "C" fn cb_free_response(
606    _ctxt: *mut c_void,
607    response: *mut BNDownloadInstanceResponse,
608) {
609    let _ = Box::from_raw(response);
610}