-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feature: code gen based on func signature
- Loading branch information
Showing
12 changed files
with
835 additions
and
205 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
use crate::python_file::PythonFile; | ||
use proc_macro2::{Literal, TokenStream as TokenStream2}; | ||
use quote::quote; | ||
|
||
pub struct PythonFileLoader { | ||
module_path: Literal, | ||
module_name: Literal, | ||
file_name: Literal, | ||
} | ||
|
||
impl<'a> From<&'a PythonFile> for PythonFileLoader { | ||
fn from(python_file: &'a PythonFile) -> PythonFileLoader { | ||
PythonFileLoader { | ||
module_path: Literal::string(&format!( | ||
"{}", | ||
std::fs::canonicalize(&python_file.path) | ||
.unwrap() | ||
.to_str() | ||
.unwrap() | ||
)), | ||
module_name: Literal::string(&python_file.uuid), | ||
file_name: Literal::string(&python_file.file_name), | ||
} | ||
} | ||
} | ||
|
||
impl Into<TokenStream2> for PythonFileLoader { | ||
fn into(self) -> TokenStream2 { | ||
let file_name: Literal = self.file_name; | ||
let module_name: Literal = self.module_name; | ||
let module_path: Literal = self.module_path; | ||
|
||
quote! { | ||
pyo3::types::PyModule::from_code( | ||
py, | ||
include_str!(#module_path), | ||
#file_name, | ||
#module_name, | ||
); | ||
} | ||
} | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
use super::*; | ||
use std::path::PathBuf; | ||
|
||
#[test] | ||
fn test_py_file_loader() { | ||
let py_file = PythonFile { | ||
file_name: "test.py".into(), | ||
file_stem: "test".into(), | ||
uuid: "27597466".into(), | ||
main_func_args: vec![], | ||
path: PathBuf::from("test_py/test.py"), | ||
}; | ||
|
||
let token_stream: TokenStream2 = PythonFileLoader::from(&py_file).into(); | ||
|
||
let full_file_path = Literal::string(&format!( | ||
"{}", | ||
std::fs::canonicalize(&py_file.path) | ||
.unwrap() | ||
.to_str() | ||
.unwrap() | ||
)); | ||
|
||
let target_ts = quote! { | ||
pyo3::types::PyModule::from_code( | ||
py, | ||
include_str!(#full_file_path), | ||
"test.py", | ||
"27597466", | ||
); | ||
}; | ||
|
||
assert_eq!(token_stream.to_string(), target_ts.to_string()); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,164 @@ | ||
use crate::python_file::PythonFile; | ||
use proc_macro2::{Ident, Literal, Span, TokenStream as TokenStream2}; | ||
use quote::quote; | ||
|
||
use crate::py_arg::PyArg; | ||
|
||
pub struct FormIdent { | ||
ident: Ident, | ||
} | ||
|
||
impl From<&PythonFile> for FormIdent { | ||
fn from(python_file: &PythonFile) -> FormIdent { | ||
FormIdent { | ||
ident: Ident::new(&format!("Form_{}", python_file.uuid), Span::call_site()), | ||
} | ||
} | ||
} | ||
|
||
impl Into<Ident> for FormIdent { | ||
fn into(self) -> Ident { | ||
self.ident | ||
} | ||
} | ||
|
||
pub struct Form { | ||
ident: FormIdent, | ||
variants: Vec<PyArg>, | ||
} | ||
|
||
impl From<&PythonFile> for Form { | ||
fn from(python_file: &PythonFile) -> Form { | ||
Form { | ||
ident: FormIdent::from(python_file), | ||
variants: python_file.main_func_args.clone(), | ||
} | ||
} | ||
} | ||
|
||
impl Into<TokenStream2> for Form { | ||
fn into(self) -> TokenStream2 { | ||
let form_ident: Ident = self.ident.into(); | ||
let struct_fields: Vec<TokenStream2> = self | ||
.variants | ||
.clone() | ||
.into_iter() | ||
.map(|variant| variant.into()) | ||
.collect(); | ||
|
||
let struct_fields_names: Vec<TokenStream2> = self | ||
.variants | ||
.iter() | ||
.map(|variant| { | ||
let literal_name = Literal::string(&variant.name); | ||
quote! { | ||
#literal_name | ||
} | ||
}) | ||
.collect(); | ||
|
||
let struct_fields_values: Vec<TokenStream2> = self | ||
.variants | ||
.into_iter() | ||
.map(|field| { | ||
let variant_ident: Ident = field.into(); | ||
quote! { | ||
self.#variant_ident | ||
} | ||
}) | ||
.collect(); | ||
|
||
quote! { | ||
#[derive(rocket::form::FromForm)] | ||
struct #form_ident { | ||
#(#struct_fields),* | ||
} | ||
|
||
impl #form_ident { | ||
pub fn kwargs(self, py: pyo3::prelude::Python) -> &pyo3::types::PyDict { | ||
use pyo3::types::IntoPyDict; | ||
use pyo3::conversion::IntoPy; | ||
|
||
let mut args : Vec<(&str, pyo3::Py<pyo3::PyAny>)> = vec!(); | ||
|
||
#( | ||
let struct_field_value = #struct_fields_values; | ||
let py_any : pyo3::Py<pyo3::PyAny> = struct_field_value.into_py(py); | ||
|
||
if !py_any.is_none(py) { | ||
args.push((#struct_fields_names, py_any)); | ||
} | ||
)* | ||
|
||
args.into_py_dict(py) | ||
} | ||
} | ||
} | ||
} | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
use super::*; | ||
use crate::py_arg::{PyArg, PyPrimitiveDataType}; | ||
use std::path::PathBuf; | ||
|
||
#[test] | ||
fn test_form() { | ||
let py_file = PythonFile { | ||
file_name: "test.py".into(), | ||
file_stem: "test".into(), | ||
uuid: "27597466".into(), | ||
main_func_args: vec![ | ||
PyArg { | ||
name: "input".into(), | ||
data_type: PyPrimitiveDataType::Str, | ||
optional: false, | ||
}, | ||
PyArg { | ||
name: "score".into(), | ||
data_type: PyPrimitiveDataType::Int, | ||
optional: true, | ||
}, | ||
], | ||
path: PathBuf::from("test_py/test.py"), | ||
}; | ||
|
||
let token_stream: TokenStream2 = Form::from(&py_file).into(); | ||
|
||
let target_ts = quote! { | ||
#[derive(rocket::form::FromForm)] | ||
struct Form_27597466 { | ||
input: String, | ||
score: Option<usize> | ||
} | ||
|
||
impl Form_27597466 { | ||
pub fn kwargs(self, py: pyo3::prelude::Python) -> &pyo3::types::PyDict { | ||
use pyo3::types::IntoPyDict; | ||
use pyo3::conversion::IntoPy; | ||
|
||
let mut args : Vec<(&str, pyo3::Py<pyo3::PyAny>)> = vec!(); | ||
|
||
let struct_field_value = self.input; | ||
let py_any : pyo3::Py<pyo3::PyAny> = struct_field_value.into_py(py); | ||
|
||
if !py_any.is_none(py) { | ||
args.push(("input", py_any)); | ||
} | ||
|
||
let struct_field_value = self.score; | ||
let py_any : pyo3::Py<pyo3::PyAny> = struct_field_value.into_py(py); | ||
|
||
if !py_any.is_none(py) { | ||
args.push(("score", py_any)); | ||
} | ||
|
||
args.into_py_dict(py) | ||
} | ||
} | ||
}; | ||
|
||
assert_eq!(token_stream.to_string(), target_ts.to_string()); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
use crate::python_file::PythonFile; | ||
use proc_macro2::{Ident, Literal, Span, TokenStream as TokenStream2}; | ||
use quote::quote; | ||
|
||
use crate::form::FormIdent; | ||
|
||
pub struct HookFunctionIdent { | ||
ident: Ident, | ||
} | ||
|
||
impl From<&PythonFile> for HookFunctionIdent { | ||
fn from(python_file: &PythonFile) -> HookFunctionIdent { | ||
HookFunctionIdent { | ||
ident: Ident::new(&format!("hook_{}", python_file.uuid), Span::call_site()), | ||
} | ||
} | ||
} | ||
|
||
impl Into<Ident> for HookFunctionIdent { | ||
fn into(self) -> Ident { | ||
self.ident | ||
} | ||
} | ||
|
||
pub struct Hook { | ||
ident: HookFunctionIdent, | ||
py_module_name: Literal, | ||
py_file_name: Literal, | ||
form_ident: FormIdent, | ||
} | ||
|
||
impl From<&PythonFile> for Hook { | ||
fn from(python_file: &PythonFile) -> Hook { | ||
Hook { | ||
ident: HookFunctionIdent::from(python_file), | ||
py_module_name: Literal::string(&python_file.uuid), | ||
py_file_name: Literal::string(&python_file.file_name), | ||
form_ident: FormIdent::from(python_file), | ||
} | ||
} | ||
} | ||
|
||
impl Into<TokenStream2> for Hook { | ||
fn into(self) -> TokenStream2 { | ||
let hook_function_ident: Ident = self.ident.into(); | ||
let module_name = self.py_module_name; | ||
let file_name = self.py_file_name; | ||
let form_ident: Ident = self.form_ident.into(); | ||
|
||
quote! { | ||
fn #hook_function_ident(py_lock: pyo3::Python, input: #form_ident) -> String { | ||
let kwargs : &pyo3::types::PyDict = input.kwargs(py_lock); | ||
|
||
let nlp = pyo3::types::PyModule::import( | ||
py_lock, | ||
#module_name, | ||
) | ||
.expect("failed to import PyModule"); | ||
|
||
match nlp | ||
.getattr("call") | ||
.expect(&format!("`call` function was not found in {}. Your python file must include a `call` function that returns json data:\n\ndef call(input):\n\tjson.dumps('{{'foo': 'bar'}}')\n\n", #file_name)) | ||
.call((), Some(kwargs)) { | ||
Ok(result) => result.extract().unwrap_or("{}".to_string()), | ||
Err(e) => format!("{{\"error\": \"{}\"}}", e.to_string()) | ||
} | ||
} | ||
} | ||
} | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
use super::*; | ||
use std::path::PathBuf; | ||
|
||
#[test] | ||
fn test_hook() { | ||
let py_file = PythonFile { | ||
file_name: "test.py".into(), | ||
file_stem: "test".into(), | ||
uuid: "27597466".into(), | ||
main_func_args: vec![], | ||
path: PathBuf::from("/test.py"), | ||
}; | ||
|
||
let token_stream: TokenStream2 = Hook::from(&py_file).into(); | ||
|
||
let target_ts = quote! { | ||
fn hook_27597466(py_lock: pyo3::Python, input: Form_27597466) -> String { | ||
let kwargs : &pyo3::types::PyDict = input.kwargs(py_lock); | ||
|
||
let nlp = pyo3::types::PyModule::import( | ||
py_lock, | ||
"27597466", | ||
) | ||
.expect("failed to import PyModule"); | ||
|
||
match nlp | ||
.getattr("call") | ||
.expect(&format!("`call` function was not found in {}. Your python file must include a `call` function that returns json data:\n\ndef call(input):\n\tjson.dumps('{{'foo': 'bar'}}')\n\n", "test.py")) | ||
.call((), Some(kwargs)) { | ||
Ok(result) => result.extract().unwrap_or("{}".to_string()), | ||
Err(e) => format!("{{\"error\": \"{}\"}}", e.to_string()) | ||
} | ||
} | ||
}; | ||
|
||
assert_eq!(token_stream.to_string(), target_ts.to_string()); | ||
} | ||
} |
Oops, something went wrong.