1use crate::download::DownloadProvider;
2use crate::headless::is_shutdown_requested;
3use crate::rc::{Ref, RefCountable};
4use crate::string::{strings_to_string_list, BnString, IntoCStr};
5use binaryninjacore_sys::*;
6use std::cell::RefCell;
7use std::collections::HashMap;
8use std::ffi::{c_void, CStr};
9use std::mem::{ManuallyDrop, MaybeUninit};
10use std::os::raw::c_char;
11use std::ptr::null_mut;
12use std::rc::Rc;
13use std::slice;
14
15pub trait CustomDownloadInstance: Sized {
16 fn new_with_provider(provider: DownloadProvider) -> Result<Ref<DownloadInstance>, ()> {
17 let instance_uninit = MaybeUninit::uninit();
18 let leaked_instance = Box::leak(Box::new(instance_uninit));
20 let mut callbacks = BNDownloadInstanceCallbacks {
21 context: leaked_instance as *mut _ as *mut c_void,
22 destroyInstance: Some(cb_destroy_instance::<Self>),
23 performRequest: Some(cb_perform_request::<Self>),
24 performCustomRequest: Some(cb_perform_custom_request::<Self>),
25 freeResponse: Some(cb_free_response),
26 };
27 let instance_ptr = unsafe { BNInitDownloadInstance(provider.handle, &mut callbacks) };
28 let instance_ref = unsafe { DownloadInstance::ref_from_raw(instance_ptr) };
30 leaked_instance.write(Self::from_core(instance_ref.clone()));
32 Ok(instance_ref)
33 }
34
35 fn from_core(core: Ref<DownloadInstance>) -> Self;
37
38 fn handle(&self) -> Ref<DownloadInstance>;
40
41 fn perform_request(&self, url: &str) -> Result<(), String> {
49 self.perform_custom_request("GET", url, vec![])?;
50 Ok(())
51 }
52
53 fn perform_custom_request<I>(
61 &self,
62 method: &str,
63 url: &str,
64 headers: I,
65 ) -> Result<DownloadResponse, String>
66 where
67 I: IntoIterator<Item = (String, String)>;
68}
69
/// Rust closures the core invokes while running [`DownloadInstance::perform_request`].
///
/// A missing closure (`None`) falls back to the trampoline's default: `write` reports
/// `0` bytes consumed and `progress` reports `true`.
pub struct DownloadInstanceOutputCallbacks {
    /// Receives a chunk of downloaded data; its result is passed back to the core as
    /// the number of bytes consumed.
    pub write: Option<Box<dyn FnMut(&[u8]) -> usize>>,
    /// Receives `(progress, total)` updates; its `bool` result is passed back to the
    /// core (presumably `false` requests cancellation — confirm against the core).
    pub progress: Option<Box<dyn FnMut(usize, usize) -> bool>>,
}
75
/// Rust closures the core invokes while running
/// [`DownloadInstance::perform_custom_request`].
///
/// A missing closure (`None`) falls back to the trampoline's default: `read` and
/// `write` report `0` bytes and `progress` reports `true`.
pub struct DownloadInstanceInputOutputCallbacks {
    /// Fills the buffer with request-body bytes and returns how many were written,
    /// `Some(0)` at end of body, or `None` on failure (reported to the core as `-1`).
    pub read: Option<Box<dyn FnMut(&mut [u8]) -> Option<usize>>>,
    /// Receives a chunk of response data; its result is passed back to the core as
    /// the number of bytes consumed.
    pub write: Option<Box<dyn FnMut(&[u8]) -> usize>>,
    /// Receives `(progress, total)` updates; its `bool` result is passed back to the
    /// core (presumably `false` requests cancellation — confirm against the core).
    pub progress: Option<Box<dyn FnMut(usize, usize) -> bool>>,
}
82
/// Status line and headers of a completed custom request.
///
/// The body is not included; it is delivered separately through the `write` callback.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DownloadResponse {
    /// HTTP status code reported by the provider.
    pub status_code: u16,
    /// Response headers, keyed by name exactly as the provider reported them.
    pub headers: HashMap<String, String>,
}
87
/// A fully buffered response: body bytes plus status code and headers.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OwnedDownloadResponse {
    /// Complete response body.
    pub data: Vec<u8>,
    /// HTTP status code reported by the provider.
    pub status_code: u16,
    /// Response headers, keyed by name exactly as the provider reported them.
    pub headers: HashMap<String, String>,
}
93
94impl OwnedDownloadResponse {
95 pub fn text(&self) -> Result<String, std::string::FromUtf8Error> {
97 String::from_utf8(self.data.clone())
98 }
99
100 pub fn json<T: serde::de::DeserializeOwned>(&self) -> Result<T, serde_json::Error> {
102 serde_json::from_slice(&self.data)
103 }
104
105 pub fn header(&self, name: &str) -> Option<&str> {
107 self.headers
108 .get(&name.to_ascii_lowercase())
109 .map(|s| s.as_str())
110 }
111
112 pub fn is_success(&self) -> bool {
114 (200..300).contains(&self.status_code)
115 }
116}
117
118pub struct DownloadInstanceReader {
120 pub instance: Ref<DownloadInstance>,
121}
122
123impl DownloadInstanceReader {
124 pub fn new(instance: Ref<DownloadInstance>) -> Self {
125 Self { instance }
126 }
127}
128
129impl std::io::Read for DownloadInstanceReader {
130 fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
131 let length = self.instance.read_callback(buf);
132 if length < 0 {
133 Err(std::io::Error::new(
134 std::io::ErrorKind::Interrupted,
135 "Connection interrupted",
136 ))
137 } else {
138 Ok(length as usize)
139 }
140 }
141}
142
143pub struct DownloadInstanceWriter {
145 pub instance: Ref<DownloadInstance>,
146 pub total_length: Option<u64>,
148 pub progress: u64,
150}
151
152impl DownloadInstanceWriter {
153 pub fn new(instance: Ref<DownloadInstance>, total_length: Option<u64>) -> Self {
154 Self {
155 instance,
156 total_length,
157 progress: 0,
158 }
159 }
160}
161
162impl std::io::Write for DownloadInstanceWriter {
163 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
164 let length = self.instance.write_callback(buf);
165 if is_shutdown_requested() || length == 0 {
166 Err(std::io::Error::from(std::io::ErrorKind::ConnectionAborted))
167 } else {
168 self.progress += buf.len() as u64;
169 if self
170 .instance
171 .progress_callback(self.progress, self.total_length.unwrap_or(u64::MAX))
172 {
173 Ok(length as usize)
174 } else {
175 Err(std::io::Error::from(std::io::ErrorKind::ConnectionAborted))
176 }
177 }
178 }
179
180 fn flush(&mut self) -> std::io::Result<()> {
181 Ok(())
182 }
183}
184
/// Wrapper around a core `BNDownloadInstance` handle.
///
/// Reference counted through [`RefCountable`]; normally held as `Ref<DownloadInstance>`.
pub struct DownloadInstance {
    /// Raw core handle; non-null for instances built via `from_raw`.
    pub(crate) handle: *mut BNDownloadInstance,
}
188
189impl DownloadInstance {
190 pub(crate) unsafe fn from_raw(handle: *mut BNDownloadInstance) -> Self {
191 debug_assert!(!handle.is_null());
192 Self { handle }
193 }
194
195 pub(crate) unsafe fn ref_from_raw(handle: *mut BNDownloadInstance) -> Ref<Self> {
196 Ref::new(Self::from_raw(handle))
197 }
198
199 fn get_error(&self) -> String {
200 let err: *mut c_char = unsafe { BNGetErrorForDownloadInstance(self.handle) };
201 unsafe { BnString::into_string(err) }
202 }
203
204 fn set_error(&self, err: &str) {
207 let err = err.to_cstr();
208 unsafe { BNSetErrorForDownloadInstance(self.handle, err.as_ptr()) };
209 }
210
211 pub fn write_callback(&self, data: &[u8]) -> u64 {
213 unsafe {
214 BNWriteDataForDownloadInstance(self.handle, data.as_ptr() as *mut _, data.len() as u64)
215 }
216 }
217
218 pub fn read_callback(&self, data: &mut [u8]) -> i64 {
220 unsafe {
221 BNReadDataForDownloadInstance(
222 self.handle,
223 data.as_mut_ptr() as *mut _,
224 data.len() as u64,
225 )
226 }
227 }
228
229 pub fn progress_callback(&self, progress: u64, total: u64) -> bool {
231 unsafe { BNNotifyProgressForDownloadInstance(self.handle, progress, total) }
232 }
233
234 pub fn get<I>(&mut self, url: &str, headers: I) -> Result<OwnedDownloadResponse, String>
235 where
236 I: IntoIterator<Item = (String, String)>,
237 {
238 let buf: Rc<RefCell<Vec<u8>>> = Rc::new(RefCell::new(Vec::new()));
239 let buf_closure = Rc::clone(&buf);
240 let callbacks = DownloadInstanceInputOutputCallbacks {
241 read: None,
242 write: Some(Box::new(move |data: &[u8]| {
243 buf_closure.borrow_mut().extend_from_slice(data);
244 data.len()
245 })),
246 progress: Some(Box::new(|_, _| true)),
247 };
248
249 let resp = self.perform_custom_request("GET", url, headers, &callbacks)?;
250 drop(callbacks);
251 let out = Rc::try_unwrap(buf).map_err(|_| "Buffer held with strong reference")?;
252 Ok(OwnedDownloadResponse {
253 data: out.into_inner(),
254 status_code: resp.status_code,
255 headers: resp.headers,
256 })
257 }
258
259 pub fn post<I>(
260 &mut self,
261 url: &str,
262 headers: I,
263 body: Vec<u8>,
264 ) -> Result<OwnedDownloadResponse, String>
265 where
266 I: IntoIterator<Item = (String, String)>,
267 {
268 let resp_buf: Rc<RefCell<Vec<u8>>> = Rc::new(RefCell::new(Vec::new()));
269 let resp_buf_closure = Rc::clone(&resp_buf);
270 let mut pos = 0usize;
272 let total = body.len();
273 let callbacks = DownloadInstanceInputOutputCallbacks {
274 read: Some(Box::new(move |dst: &mut [u8]| -> Option<usize> {
276 if pos >= total {
277 return Some(0);
278 }
279 let remaining = total - pos;
280 let to_copy = remaining.min(dst.len());
281 dst[..to_copy].copy_from_slice(&body[pos..pos + to_copy]);
282 pos += to_copy;
283 Some(to_copy)
284 })),
285 write: Some(Box::new(move |data: &[u8]| {
287 resp_buf_closure.borrow_mut().extend_from_slice(data);
288 data.len()
289 })),
290 progress: Some(Box::new(|_, _| true)),
291 };
292
293 let resp = self.perform_custom_request("POST", url, headers, &callbacks)?;
294 drop(callbacks);
295 if !(200..300).contains(&(resp.status_code as i32)) {
296 return Err(format!("HTTP error: {}", resp.status_code));
297 }
298
299 let out = Rc::try_unwrap(resp_buf).map_err(|_| "Buffer held with strong reference")?;
300 Ok(OwnedDownloadResponse {
301 data: out.into_inner(),
302 status_code: resp.status_code,
303 headers: resp.headers,
304 })
305 }
306
307 pub fn post_json<I, T>(
308 &mut self,
309 url: &str,
310 headers: I,
311 body: &T,
312 ) -> Result<OwnedDownloadResponse, String>
313 where
314 I: IntoIterator<Item = (String, String)>,
315 T: serde::Serialize,
316 {
317 let mut headers: Vec<(String, String)> = headers.into_iter().collect();
318 if !headers
319 .iter()
320 .any(|(k, _)| k.eq_ignore_ascii_case("content-type"))
321 {
322 headers.push(("content-type".into(), "application/json".into()));
323 }
324 let bytes = serde_json::to_vec(body).map_err(|e| e.to_string())?;
325 self.post(url, headers, bytes)
326 }
327
328 pub fn perform_request(
329 &mut self,
330 url: &str,
331 callbacks: &DownloadInstanceOutputCallbacks,
332 ) -> Result<(), String> {
333 let mut cbs = BNDownloadInstanceOutputCallbacks {
334 writeCallback: Some(cb_write_output),
335 writeContext: callbacks as *const _ as *mut c_void,
336 progressCallback: Some(cb_progress_output),
337 progressContext: callbacks as *const _ as *mut c_void,
338 };
339
340 let url_raw = url.to_cstr();
341 let result = unsafe {
342 BNPerformDownloadRequest(
343 self.handle,
344 url_raw.as_ptr(),
345 &mut cbs as *mut BNDownloadInstanceOutputCallbacks,
346 )
347 };
348
349 if result < 0 {
350 Err(self.get_error())
351 } else {
352 Ok(())
353 }
354 }
355
356 pub fn perform_custom_request<I>(
357 &mut self,
358 method: &str,
359 url: &str,
360 headers: I,
361 callbacks: &DownloadInstanceInputOutputCallbacks,
362 ) -> Result<DownloadResponse, String>
363 where
364 I: IntoIterator<Item = (String, String)>,
365 {
366 let mut header_keys = vec![];
367 let mut header_values = vec![];
368 for (key, value) in headers {
369 header_keys.push(key.to_cstr());
370 header_values.push(value.to_cstr());
371 }
372
373 let mut header_key_ptrs = vec![];
374 let mut header_value_ptrs = vec![];
375
376 for (key, value) in header_keys.iter().zip(header_values.iter()) {
377 header_key_ptrs.push(key.as_ptr());
378 header_value_ptrs.push(value.as_ptr());
379 }
380
381 let mut cbs = BNDownloadInstanceInputOutputCallbacks {
382 readCallback: Some(cb_read_input),
383 readContext: callbacks as *const _ as *mut c_void,
384 writeCallback: Some(cb_write_input),
385 writeContext: callbacks as *const _ as *mut c_void,
386 progressCallback: Some(cb_progress_input),
387 progressContext: callbacks as *const _ as *mut c_void,
388 };
389
390 let mut response: *mut BNDownloadInstanceResponse = null_mut();
391
392 let method_raw = method.to_cstr();
393 let url_raw = url.to_cstr();
394 let result = unsafe {
395 BNPerformCustomRequest(
396 self.handle,
397 method_raw.as_ptr(),
398 url_raw.as_ptr(),
399 header_key_ptrs.len() as u64,
400 header_key_ptrs.as_ptr(),
401 header_value_ptrs.as_ptr(),
402 &mut response as *mut *mut BNDownloadInstanceResponse,
403 &mut cbs as *mut BNDownloadInstanceInputOutputCallbacks,
404 )
405 };
406
407 if result < 0 {
408 unsafe { BNFreeDownloadInstanceResponse(response) };
409 return Err(self.get_error());
410 }
411
412 let mut response_headers = HashMap::new();
413 unsafe {
414 let response_header_keys: &[*mut c_char] =
415 slice::from_raw_parts((*response).headerKeys, (*response).headerCount as usize);
416 let response_header_values: &[*mut c_char] =
417 slice::from_raw_parts((*response).headerValues, (*response).headerCount as usize);
418
419 for (key, value) in response_header_keys
420 .iter()
421 .zip(response_header_values.iter())
422 {
423 response_headers.insert(
424 CStr::from_ptr(*key).to_str().unwrap().to_owned(),
425 CStr::from_ptr(*value).to_str().unwrap().to_owned(),
426 );
427 }
428 }
429
430 let r = DownloadResponse {
431 status_code: unsafe { (*response).statusCode },
432 headers: response_headers,
433 };
434
435 unsafe { BNFreeDownloadInstanceResponse(response) };
436
437 Ok(r)
438 }
439}
440
// SAFETY (NOTE(review)): these impls assert that the core handle may be sent to and
// shared between threads. That depends on the core `BNDownloadInstance` being
// thread-safe, which this file does not demonstrate — confirm against the core.
unsafe impl Send for DownloadInstance {}
unsafe impl Sync for DownloadInstance {}
444
445impl ToOwned for DownloadInstance {
446 type Owned = Ref<Self>;
447
448 fn to_owned(&self) -> Self::Owned {
449 unsafe { RefCountable::inc_ref(self) }
450 }
451}
452
453unsafe impl RefCountable for DownloadInstance {
454 unsafe fn inc_ref(handle: &Self) -> Ref<Self> {
455 Ref::new(Self {
456 handle: BNNewDownloadInstanceReference(handle.handle),
457 })
458 }
459
460 unsafe fn dec_ref(handle: &Self) {
461 BNFreeDownloadInstance(handle.handle);
462 }
463}
464
465unsafe extern "C" fn cb_read_input(data: *mut u8, len: u64, ctxt: *mut c_void) -> i64 {
466 let callbacks = ctxt as *mut DownloadInstanceInputOutputCallbacks;
467 if let Some(func) = &mut (*callbacks).read {
468 let slice = slice::from_raw_parts_mut(data, len as usize);
469 let result = (func)(slice);
470 if let Some(count) = result {
471 count as i64
472 } else {
473 -1
474 }
475 } else {
476 0
477 }
478}
479
480unsafe extern "C" fn cb_write_input(data: *mut u8, len: u64, ctxt: *mut c_void) -> u64 {
481 let callbacks = ctxt as *mut DownloadInstanceInputOutputCallbacks;
482 if let Some(func) = &mut (*callbacks).write {
483 let slice = slice::from_raw_parts(data, len as usize);
484 let result = (func)(slice);
485 result as u64
486 } else {
487 0
488 }
489}
490
491unsafe extern "C" fn cb_progress_input(ctxt: *mut c_void, progress: usize, total: usize) -> bool {
492 let callbacks = ctxt as *mut DownloadInstanceInputOutputCallbacks;
493 if let Some(func) = &mut (*callbacks).progress {
494 (func)(progress, total)
495 } else {
496 true
497 }
498}
499
500unsafe extern "C" fn cb_write_output(data: *mut u8, len: u64, ctxt: *mut c_void) -> u64 {
501 let callbacks = ctxt as *mut DownloadInstanceOutputCallbacks;
502 if let Some(func) = &mut (*callbacks).write {
503 let slice = slice::from_raw_parts(data, len as usize);
504 let result = (func)(slice);
505 result as u64
506 } else {
507 0u64
508 }
509}
510
511unsafe extern "C" fn cb_progress_output(ctxt: *mut c_void, progress: usize, total: usize) -> bool {
512 let callbacks = ctxt as *mut DownloadInstanceOutputCallbacks;
513 if let Some(func) = &mut (*callbacks).progress {
514 (func)(progress, total)
515 } else {
516 true
517 }
518}
519
520pub unsafe extern "C" fn cb_destroy_instance<C: CustomDownloadInstance>(ctxt: *mut c_void) {
521 let _ = Box::from_raw(ctxt as *mut C);
522}
523
524pub unsafe extern "C" fn cb_perform_request<C: CustomDownloadInstance>(
525 ctxt: *mut c_void,
526 url: *const c_char,
527) -> i32 {
528 let c = ManuallyDrop::new(Box::from_raw(ctxt as *mut C));
529
530 let url = match CStr::from_ptr(url).to_str() {
531 Ok(url) => url,
532 Err(e) => {
533 c.handle().set_error(&format!("Invalid URL: {}", e));
534 return -1;
535 }
536 };
537
538 match c.perform_request(url) {
539 Ok(()) => 0,
540 Err(e) => {
541 c.handle().set_error(&e);
542 -1
543 }
544 }
545}
546
547pub unsafe extern "C" fn cb_perform_custom_request<C: CustomDownloadInstance>(
548 ctxt: *mut c_void,
549 method: *const c_char,
550 url: *const c_char,
551 header_count: u64,
552 header_keys: *const *const c_char,
553 header_values: *const *const c_char,
554 response: *mut *mut BNDownloadInstanceResponse,
555) -> i32 {
556 let c = ManuallyDrop::new(Box::from_raw(ctxt as *mut C));
557
558 let method = match CStr::from_ptr(method).to_str() {
559 Ok(method) => method,
560 Err(e) => {
561 c.handle().set_error(&format!("Invalid Method: {}", e));
562 return -1;
563 }
564 };
565
566 let url = match CStr::from_ptr(url).to_str() {
567 Ok(url) => url,
568 Err(e) => {
569 c.handle().set_error(&format!("Invalid URL: {}", e));
570 return -1;
571 }
572 };
573
574 let header_count = usize::try_from(header_count).unwrap();
576 let header_keys = slice::from_raw_parts(header_keys as *const BnString, header_count);
577 let header_values = slice::from_raw_parts(header_values as *const BnString, header_count);
578 let header_keys_str = header_keys.iter().map(|s| s.to_string_lossy().to_string());
579 let header_values_str = header_values
580 .iter()
581 .map(|s| s.to_string_lossy().to_string());
582 let headers = header_keys_str.zip(header_values_str);
583
584 match c.perform_custom_request(method, url, headers) {
585 Ok(res) => {
586 let res_header_keys_ptr = strings_to_string_list(res.headers.keys());
587 let res_header_values_ptr = strings_to_string_list(res.headers.values());
588 let raw_response = BNDownloadInstanceResponse {
589 statusCode: res.status_code,
590 headerCount: res.headers.len() as u64,
591 headerKeys: res_header_keys_ptr,
592 headerValues: res_header_values_ptr,
593 };
594 unsafe { *response = Box::leak(Box::new(raw_response)) };
596 0
597 }
598 Err(e) => {
599 c.handle().set_error(&e);
600 -1
601 }
602 }
603}
604
605unsafe extern "C" fn cb_free_response(
606 _ctxt: *mut c_void,
607 response: *mut BNDownloadInstanceResponse,
608) {
609 let _ = Box::from_raw(response);
610}