From faa53d988fdd458b3722ccc3d25f31dc229e0dd3 Mon Sep 17 00:00:00 2001 From: Mathis EON <60341438+eonm-abes@users.noreply.github.com> Date: Sun, 5 Sep 2021 22:51:41 +0200 Subject: [PATCH] feature: code gen based on func signature --- py-apify-macro/Cargo.toml | 1 + py-apify-macro/src/file_loader.rs | 80 +++++++++ py-apify-macro/src/form.rs | 164 +++++++++++++++++ py-apify-macro/src/hook.rs | 111 ++++++++++++ py-apify-macro/src/lib.rs | 243 +++++--------------------- py-apify-macro/src/mount.rs | 63 +++++++ py-apify-macro/src/py_arg.rs | 175 +++++++++++++++++++ py-apify-macro/src/python_file.rs | 63 +++++++ py-apify-macro/src/request_handler.rs | 133 ++++++++++++++ py-apify-macro/test_py/test.py | 0 src/camembert-masked-lm.py | 5 +- src/main.rs | 2 +- 12 files changed, 835 insertions(+), 205 deletions(-) create mode 100644 py-apify-macro/src/file_loader.rs create mode 100644 py-apify-macro/src/form.rs create mode 100644 py-apify-macro/src/hook.rs create mode 100644 py-apify-macro/src/mount.rs create mode 100644 py-apify-macro/src/py_arg.rs create mode 100644 py-apify-macro/src/python_file.rs create mode 100644 py-apify-macro/src/request_handler.rs create mode 100644 py-apify-macro/test_py/test.py diff --git a/py-apify-macro/Cargo.toml b/py-apify-macro/Cargo.toml index 854e579..dd49652 100644 --- a/py-apify-macro/Cargo.toml +++ b/py-apify-macro/Cargo.toml @@ -14,6 +14,7 @@ env_logger = "0.9.0" syn = "1.0.75" uuid = {version = "0.8", features = ["v4"]} glob = "0.3.0" +rustpython-parser = "0.1.2" [lib] proc-macro = true diff --git a/py-apify-macro/src/file_loader.rs b/py-apify-macro/src/file_loader.rs new file mode 100644 index 0000000..d580ae2 --- /dev/null +++ b/py-apify-macro/src/file_loader.rs @@ -0,0 +1,80 @@ +use crate::python_file::PythonFile; +use proc_macro2::{Literal, TokenStream as TokenStream2}; +use quote::quote; + +pub struct PythonFileLoader { + module_path: Literal, + module_name: Literal, + file_name: Literal, +} + +impl<'a> From<&'a PythonFile> for PythonFileLoader { + fn from(python_file: &'a PythonFile) -> PythonFileLoader { + PythonFileLoader { + module_path: Literal::string(&format!( + "{}", + std::fs::canonicalize(&python_file.path) + .unwrap() + .to_str() + .unwrap() + )), + module_name: Literal::string(&python_file.uuid), + file_name: Literal::string(&python_file.file_name), + } + } +} + +impl Into for PythonFileLoader { + fn into(self) -> TokenStream2 { + let file_name: Literal = self.file_name; + let module_name: Literal = self.module_name; + let module_path: Literal = self.module_path; + + quote! { + pyo3::types::PyModule::from_code( + py, + include_str!(#module_path), + #file_name, + #module_name, + ); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::path::PathBuf; + + #[test] + fn test_py_file_loader() { + let py_file = PythonFile { + file_name: "test.py".into(), + file_stem: "test".into(), + uuid: "27597466".into(), + main_func_args: vec![], + path: PathBuf::from("test_py/test.py"), + }; + + let token_stream: TokenStream2 = PythonFileLoader::from(&py_file).into(); + + let full_file_path = Literal::string(&format!( + "{}", + std::fs::canonicalize(&py_file.path) + .unwrap() + .to_str() + .unwrap() + )); + + let target_ts = quote! { + pyo3::types::PyModule::from_code( + py, + include_str!(#full_file_path), + "test.py", + "27597466", + ); + }; + + assert_eq!(token_stream.to_string(), target_ts.to_string()); + } +} diff --git a/py-apify-macro/src/form.rs b/py-apify-macro/src/form.rs new file mode 100644 index 0000000..baf7e13 --- /dev/null +++ b/py-apify-macro/src/form.rs @@ -0,0 +1,164 @@ +use crate::python_file::PythonFile; +use proc_macro2::{Ident, Literal, Span, TokenStream as TokenStream2}; +use quote::quote; + +use crate::py_arg::PyArg; + +pub struct FormIdent { + ident: Ident, +} + +impl From<&PythonFile> for FormIdent { + fn from(python_file: &PythonFile) -> FormIdent { + FormIdent { + ident: Ident::new(&format!("Form_{}", python_file.uuid), Span::call_site()), + } + } +} + +impl Into for FormIdent { + fn into(self) -> Ident { + self.ident + } +} + +pub struct Form { + ident: FormIdent, + variants: Vec, +} + +impl From<&PythonFile> for Form { + fn from(python_file: &PythonFile) -> Form { + Form { + ident: FormIdent::from(python_file), + variants: python_file.main_func_args.clone(), + } + } +} + +impl Into for Form { + fn into(self) -> TokenStream2 { + let form_ident: Ident = self.ident.into(); + let struct_fields: Vec = self + .variants + .clone() + .into_iter() + .map(|variant| variant.into()) + .collect(); + + let struct_fields_names: Vec = self + .variants + .iter() + .map(|variant| { + let literal_name = Literal::string(&variant.name); + quote! { + #literal_name + } + }) + .collect(); + + let struct_fields_values: Vec = self + .variants + .into_iter() + .map(|field| { + let variant_ident: Ident = field.into(); + quote! { + self.#variant_ident + } + }) + .collect(); + + quote! { + #[derive(rocket::form::FromForm)] + struct #form_ident { + #(#struct_fields),* + } + + impl #form_ident { + pub fn kwargs(self, py: pyo3::prelude::Python) -> &pyo3::types::PyDict { + use pyo3::types::IntoPyDict; + use pyo3::conversion::IntoPy; + + let mut args : Vec<(&str, pyo3::Py)> = vec!(); + + #( + let struct_field_value = #struct_fields_values; + let py_any : pyo3::Py = struct_field_value.into_py(py); + + if !py_any.is_none(py) { + args.push((#struct_fields_names, py_any)); + } + )* + + args.into_py_dict(py) + } + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::py_arg::{PyArg, PyPrimitiveDataType}; + use std::path::PathBuf; + + #[test] + fn test_form() { + let py_file = PythonFile { + file_name: "test.py".into(), + file_stem: "test".into(), + uuid: "27597466".into(), + main_func_args: vec![ + PyArg { + name: "input".into(), + data_type: PyPrimitiveDataType::Str, + optional: false, + }, + PyArg { + name: "score".into(), + data_type: PyPrimitiveDataType::Int, + optional: true, + }, + ], + path: PathBuf::from("test_py/test.py"), + }; + + let token_stream: TokenStream2 = Form::from(&py_file).into(); + + let target_ts = quote! { + #[derive(rocket::form::FromForm)] + struct Form_27597466 { + input: String, + score: Option + } + + impl Form_27597466 { + pub fn kwargs(self, py: pyo3::prelude::Python) -> &pyo3::types::PyDict { + use pyo3::types::IntoPyDict; + use pyo3::conversion::IntoPy; + + let mut args : Vec<(&str, pyo3::Py)> = vec!(); + + let struct_field_value = self.input; + let py_any : pyo3::Py = struct_field_value.into_py(py); + + if !py_any.is_none(py) { + args.push(("input", py_any)); + } + + let struct_field_value = self.score; + let py_any : pyo3::Py = struct_field_value.into_py(py); + + if !py_any.is_none(py) { + args.push(("score", py_any)); + } + + args.into_py_dict(py) + } + } + }; + + assert_eq!(token_stream.to_string(), target_ts.to_string()); + } +} diff --git a/py-apify-macro/src/hook.rs b/py-apify-macro/src/hook.rs new file mode 100644 index 0000000..e97a22a --- /dev/null +++ b/py-apify-macro/src/hook.rs @@ -0,0 +1,111 @@ +use crate::python_file::PythonFile; +use proc_macro2::{Ident, Literal, Span, TokenStream as TokenStream2}; +use quote::quote; + +use crate::form::FormIdent; + +pub struct HookFunctionIdent { + ident: Ident, +} + +impl From<&PythonFile> for HookFunctionIdent { + fn from(python_file: &PythonFile) -> HookFunctionIdent { + HookFunctionIdent { + ident: Ident::new(&format!("hook_{}", python_file.uuid), Span::call_site()), + } + } +} + +impl Into for HookFunctionIdent { + fn into(self) -> Ident { + self.ident + } +} + +pub struct Hook { + ident: HookFunctionIdent, + py_module_name: Literal, + py_file_name: Literal, + form_ident: FormIdent, +} + +impl From<&PythonFile> for Hook { + fn from(python_file: &PythonFile) -> Hook { + Hook { + ident: HookFunctionIdent::from(python_file), + py_module_name: Literal::string(&python_file.uuid), + py_file_name: Literal::string(&python_file.file_name), + form_ident: FormIdent::from(python_file), + } + } +} + +impl Into for Hook { + fn into(self) -> TokenStream2 { + let hook_function_ident: Ident = self.ident.into(); + let module_name = self.py_module_name; + let file_name = self.py_file_name; + let form_ident: Ident = self.form_ident.into(); + + quote! { + fn #hook_function_ident(py_lock: pyo3::Python, input: #form_ident) -> String { + let kwargs : &pyo3::types::PyDict = input.kwargs(py_lock); + + let nlp = pyo3::types::PyModule::import( + py_lock, + #module_name, + ) + .expect("failed to import PyModule"); + + match nlp + .getattr("call") + .expect(&format!("`call` function was not found in {}. Your python file must include a `call` function that returns json data:\n\ndef call(input):\n\tjson.dumps('{{'foo': 'bar'}}')\n\n", #file_name)) + .call((), Some(kwargs)) { + Ok(result) => result.extract().unwrap_or("{}".to_string()), + Err(e) => format!("{{\"error\": \"{}\"}}", e.to_string()) + } + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::path::PathBuf; + + #[test] + fn test_hook() { + let py_file = PythonFile { + file_name: "test.py".into(), + file_stem: "test".into(), + uuid: "27597466".into(), + main_func_args: vec![], + path: PathBuf::from("/test.py"), + }; + + let token_stream: TokenStream2 = Hook::from(&py_file).into(); + + let target_ts = quote! { + fn hook_27597466(py_lock: pyo3::Python, input: Form_27597466) -> String { + let kwargs : &pyo3::types::PyDict = input.kwargs(py_lock); + + let nlp = pyo3::types::PyModule::import( + py_lock, + "27597466", + ) + .expect("failed to import PyModule"); + + match nlp + .getattr("call") + .expect(&format!("`call` function was not found in {}. Your python file must include a `call` function that returns json data:\n\ndef call(input):\n\tjson.dumps('{{'foo': 'bar'}}')\n\n", "test.py")) + .call((), Some(kwargs)) { + Ok(result) => result.extract().unwrap_or("{}".to_string()), + Err(e) => format!("{{\"error\": \"{}\"}}", e.to_string()) + } + } + }; + + assert_eq!(token_stream.to_string(), target_ts.to_string()); + } +} diff --git a/py-apify-macro/src/lib.rs b/py-apify-macro/src/lib.rs index e3f2e6a..6c01235 100644 --- a/py-apify-macro/src/lib.rs +++ b/py-apify-macro/src/lib.rs @@ -1,226 +1,67 @@ +#![feature(type_ascription)] + extern crate proc_macro; -use glob::glob; -use log::info; use proc_macro::TokenStream; -use proc_macro2::{Ident, Literal, Span, TokenStream as TokenStream2}; +use proc_macro2::TokenStream as TokenStream2; use quote::quote; -use std::path::PathBuf; use syn::{parse::Parser, punctuated::Punctuated, LitStr, Token}; -use uuid; -use uuid::Uuid; - -#[derive(Debug, Clone)] -struct PythonFile { - path: PathBuf, - file_name: String, - file_stem: String, - uuid: String, -} - -impl PythonFile { - fn new(input: PathBuf) -> Self { - PythonFile { - file_name: input.file_name().expect("").to_str().expect("").to_string(), - file_stem: input.file_stem().expect("").to_str().expect("").to_string(), - path: input, - uuid: Uuid::new_v4().to_simple().to_string(), - } - } - #[cfg(not(feature = "no-check"))] - fn check(&self) { - let python_code = std::fs::read_to_string(self.path.clone()) - .expect("Something went wrong reading the file"); - - let display_path = self.path.display().to_string(); - - pyo3::prepare_freethreaded_python(); - - pyo3::Python::with_gil(|py| { - pyo3::types::PyModule::from_code(py, &python_code, &self.uuid, &self.uuid) - .expect("failed to import PyModule"); +mod file_loader; +mod form; +mod hook; +mod mount; +mod py_arg; +mod python_file; +mod request_handler; - let nlp = - pyo3::types::PyModule::import(py, &self.uuid).expect("failed to import PyModule"); +use file_loader::PythonFileLoader; +use form::Form; +use hook::Hook; +use mount::RocketMount; +use request_handler::RequestHandler; - nlp - .getattr("call") - .expect(&format!("`call` function was not found in {:?}. Your python file must include a `call` function that returns json data:\n\ndef call(input):\n\tjson.dumps('{{'foo': 'bar'}}')\n\n", display_path)); - }); - } - - fn request_handler_ident(&self) -> Ident { - Ident::new(&format!("route_{}", self.uuid), Span::call_site()) - } - - fn route_attribute(&self) -> Literal { - Literal::string(&format!("/{}?", self.file_stem)) - } - - fn hook_function_ident(&self) -> Ident { - Ident::new(&format!("hook_{}", self.uuid), Span::call_site()) - } - - fn module_name(&self) -> Literal { - Literal::string(&self.uuid) - } -} - -fn get_py_files(input: Vec) -> Vec { - let mut python_files: Vec = if !input.is_empty() { - input - .iter() - .flat_map(|elem| { - glob(elem) - .expect("Failed to read glob pattern") - .filter_map(|e| e.ok()) - }) - .map(|e| PythonFile::new(e)) - .collect::>() - } else { - glob("./src/*.py") - .expect("Failed to read glob pattern") - .filter_map(|e| e.ok()) - .map(|elem| PythonFile::new(elem)) - .collect::>() - }; - - python_files.sort_by(|a, b| a.file_name.partial_cmp(&b.file_name).unwrap()); - python_files.dedup_by(|a, b| a.file_name == b.file_name); - - #[cfg(not(feature = "no-check"))] - { - python_files.iter().for_each(|file| file.check()); - } +#[proc_macro] +pub fn apify(item: TokenStream) -> TokenStream { + let args: Vec = Punctuated::::parse_terminated + .parse(item) + .expect("invalid arguments") + .into_iter() + .map(|e| e.value()) + .collect(); - return python_files; -} + let python_files = python_file::get_py_files(args); -/// Generates rocket request handler -fn gen_rocket_requests_handlers(input: &Vec) -> TokenStream2 { - let routes = input + let loaders: Vec = python_files .iter() - .map(|py_file| { - let rocket_route_attribute = py_file.route_attribute(); - let route_ident = py_file.request_handler_ident(); - let hook_function_ident = py_file.hook_function_ident(); - - return quote! { - #[get(#rocket_route_attribute)] - fn #route_ident(query: String) -> rocket::response::content::Json { - return rocket::response::content::Json(format!( - "{}", - #hook_function_ident(pyo3::Python::acquire_gil().python(), query) - )); - } - } - .into(); - }) - .collect::>(); - - return quote! { - #(#routes)* - } - .into(); -} + .map(|file| PythonFileLoader::from(file).into()) + .collect(); -/// Generates a hook function that call Python -fn gen_hooks(input: &Vec) -> TokenStream2 { - let hooks = input + let hooks: Vec = python_files .iter() - .map(|py_file| { - let hook_function_ident = py_file.hook_function_ident(); - let module_name = py_file.module_name(); - let file_name = py_file.file_name.clone(); - - return quote! { - fn #hook_function_ident(py_lock: pyo3::Python, input: String) -> String { - let nlp = pyo3::types::PyModule::import( - py_lock, - #module_name, - ) - .expect("failed to import PyModule"); - - nlp - .getattr("call") - .expect(&format!("`call` function was not found in {}. Your python file must include a `call` function that returns json data:\n\ndef call(input):\n\tjson.dumps('{{'foo': 'bar'}}')\n\n", #file_name)) - .call1((input,)) - .unwrap() - .extract() - .unwrap() - } - } - .into(); - }) - .collect::>(); - - return quote! { - #(#hooks)* - } - .into(); -} + .map(|file| Hook::from(file).into()) + .collect(); -fn gen_rocket_mount(input: &Vec) -> TokenStream2 { - let routes_idents = input + let routes: Vec = python_files .iter() - .map(|py_file| py_file.request_handler_ident()) - .collect::>(); + .map(|file| RequestHandler::from(file).into()) + .collect(); - return quote! { - rocket::build().mount("/", routes![#(#routes_idents),*]) - } - .into(); -} + let mount: TokenStream2 = RocketMount::from(&python_files).into(); -fn gen_py_file_loader(input: &Vec) -> TokenStream2 { - let loaders = input + let forms: Vec = python_files .iter() - .map(|py_file| { - let module_name = py_file.module_name(); - let file_name = &py_file.file_name; - - return quote! { - pyo3::types::PyModule::from_code( - py, - include_str!(#file_name), - #file_name, - #module_name, - ) - .unwrap(); - } - .into(); - }) - .collect::>(); - - return quote! { - #(#loaders)* - } - .into(); -} - -#[proc_macro] -pub fn apify(item: TokenStream) -> TokenStream { - let args: Vec = Punctuated::::parse_terminated - .parse(item) - .expect("invalid arguments") - .into_iter() - .map(|e| e.value()) + .map(|file| Form::from(file).into()) .collect(); - let python_files = get_py_files(args); - - let routes = gen_rocket_requests_handlers(&python_files); - let mount = gen_rocket_mount(&python_files); - let loader = gen_py_file_loader(&python_files); - let hooks = gen_hooks(&python_files); - return quote! { - pyo3::prepare_freethreaded_python(); + use pyo3::prelude::*; + pyo3::prepare_freethreaded_python(); pyo3::Python::with_gil(|py| { - #loader - #routes - #hooks + #(#forms)* + #(#loaders)* + #(#routes)* + #(#hooks)* #mount }) } diff --git a/py-apify-macro/src/mount.rs b/py-apify-macro/src/mount.rs new file mode 100644 index 0000000..7727cf0 --- /dev/null +++ b/py-apify-macro/src/mount.rs @@ -0,0 +1,63 @@ +use crate::python_file::PythonFile; +use crate::request_handler::RequestHandlerIdent; +use proc_macro2::{Ident, TokenStream as TokenStream2}; +use quote::quote; + +pub struct RocketMount { + routes: Vec, +} + +impl From<&Vec> for RocketMount { + fn from(python_files: &Vec) -> RocketMount { + RocketMount { + routes: python_files.iter().map(|file| file.into()).collect(), + } + } +} + +impl Into for RocketMount { + fn into(self) -> TokenStream2 { + let idents = self + .routes + .into_iter() + .map(|e| e.into()) + .collect::>(); + + quote! { + rocket::build().mount("/", routes![#(#idents),*]) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::path::PathBuf; + + #[test] + fn test_mount() { + let py_file_1 = PythonFile { + file_name: "test.py".into(), + file_stem: "test".into(), + uuid: "27597466".into(), + main_func_args: vec![], + path: PathBuf::from("/test.py"), + }; + + let py_file_2 = PythonFile { + file_name: "test-1.py".into(), + file_stem: "test-1".into(), + uuid: "41198456".into(), + main_func_args: vec![], + path: PathBuf::from("/test-1.py"), + }; + + let token_stream: TokenStream2 = RocketMount::from(&vec![py_file_1, py_file_2]).into(); + + let target_ts = quote! { + rocket::build().mount("/", routes![route_27597466, route_41198456]) + }; + + assert_eq!(token_stream.to_string(), target_ts.to_string()); + } +} diff --git a/py-apify-macro/src/py_arg.rs b/py-apify-macro/src/py_arg.rs new file mode 100644 index 0000000..b5cc6a8 --- /dev/null +++ b/py-apify-macro/src/py_arg.rs @@ -0,0 +1,175 @@ +use proc_macro2::{Ident, Span, TokenStream as TokenStream2}; +use quote::quote; +use rustpython_parser::ast::{ExpressionType, Located, Parameter, Program, StatementType, Varargs}; +use rustpython_parser::parser; + +#[derive(Debug, Clone)] +pub enum PyPrimitiveDataType { + Str, + Float, + Int, + Bool, +} + +impl From for PyPrimitiveDataType { + fn from(input: String) -> PyPrimitiveDataType { + match input.as_ref() { + "str" => PyPrimitiveDataType::Str, + "int" => PyPrimitiveDataType::Int, + "float" => PyPrimitiveDataType::Float, + "bool" => PyPrimitiveDataType::Bool, + _ => panic!("This datatype is not supported by py apify"), + } + } +} + +impl Into for PyPrimitiveDataType { + fn into(self) -> Ident { + match self { + Self::Str => Ident::new("String", Span::call_site()), + Self::Float => Ident::new("f64", Span::call_site()), + Self::Int => Ident::new("usize", Span::call_site()), + Self::Bool => Ident::new("bool", Span::call_site()), + } + } +} + +#[derive(Debug, Clone)] +pub struct PyArg { + pub name: String, + pub data_type: PyPrimitiveDataType, + pub optional: bool, +} + +impl Into for PyArg { + fn into(self) -> Ident { + Ident::new(&self.name, Span::call_site()) + } +} + +impl Into for PyArg { + fn into(self) -> TokenStream2 { + let struct_field_ident: Ident = self.clone().into(); + let data_type_ident: Ident = self.data_type.into(); + + let data_type = if self.optional { + quote! { + Option<#data_type_ident> + } + } else { + quote! { + #data_type_ident + } + }; + + quote! { + #struct_field_ident: #data_type + } + } +} + +pub fn get_func_by_name<'a>( + program: &'a Program, + func_name: &'a str, +) -> Option<&'a Located> { + program.statements.iter().find(|statement| { + if let StatementType::FunctionDef { name, .. } = &statement.node { + return name == func_name; + } + + false + }) +} + +pub fn collect_func_args_default_values( + func: &Located, +) -> Vec<&Located> { + let mut defaults = vec![]; + + if let StatementType::FunctionDef { args, .. } = &func.node { + defaults.extend(&args.defaults); + defaults.extend( + &args + .kw_defaults + .iter() + .flat_map(|e| e) + .collect::>>(), + ); + + defaults.sort_by(|a, b| { + (a.location.column()) + .partial_cmp(&b.location.column()) + .unwrap() + }); + } + + defaults +} + +pub fn collect_func_args(func: &Located) -> Vec<&Parameter> { + let mut func_args = vec![]; + + if let StatementType::FunctionDef { args, .. } = &func.node { + func_args.extend(&args.args); + func_args.extend(&args.kwonlyargs); + + match &args.vararg { + Varargs::Named(param) => { + func_args.push(param); + } + _ => {} + }; + + func_args.sort_by(|a, b| { + (a.location.column()) + .partial_cmp(&b.location.column()) + .unwrap() + }); + } + + func_args +} + +pub fn get_func_args(py_code: String, func_name: &str) -> Vec { + let program = parser::parse_program(&py_code).unwrap(); + + let call_func = get_func_by_name(&program, &func_name).expect("call function not found"); + let defaults_values = collect_func_args_default_values(&call_func); + let func_args = collect_func_args(&call_func); + + let mut args = vec![]; + + let mut peekable_args = func_args.iter().peekable(); + + while let Some(arg) = peekable_args.next() { + let mut data_type = None; + + if let Some(location) = &arg.annotation { + if let ExpressionType::Identifier { name, .. } = &location.node { + data_type = Some(name.to_string()); + } + } + + let optional = defaults_values.iter().find(|e| { + if let Some(next_elem) = peekable_args.peek() { + e.location.column() > arg.location.column() + && e.location.column() < next_elem.location.column() + } else { + e.location.column() > arg.location.column() + } + }); + + let data_type: PyPrimitiveDataType = match data_type { + Some(d_type) => PyPrimitiveDataType::from(d_type), + None => PyPrimitiveDataType::Str, + }; + + args.push(PyArg { + name: arg.arg.to_string(), + data_type: data_type, + optional: optional.is_some(), + }); + } + + args +} diff --git a/py-apify-macro/src/python_file.rs b/py-apify-macro/src/python_file.rs new file mode 100644 index 0000000..f253f32 --- /dev/null +++ b/py-apify-macro/src/python_file.rs @@ -0,0 +1,63 @@ +use glob::glob; +use std::fs::read_to_string; +use std::path::PathBuf; +use uuid::Uuid; + +use crate::py_arg::{get_func_args, PyArg}; + +#[derive(Debug, Clone)] +pub struct PythonFile { + pub path: PathBuf, + pub file_name: String, + pub file_stem: String, + pub uuid: String, + pub main_func_args: Vec, +} + +impl PythonFile { + pub fn new(input: PathBuf) -> Self { + PythonFile { + file_name: input.file_name().expect("").to_str().expect("").to_string(), + file_stem: input.file_stem().expect("").to_str().expect("").to_string(), + main_func_args: get_func_args( + read_to_string(&input).expect("failed to read file"), + "call", + ), + path: input, + uuid: Uuid::new_v4().to_simple().to_string(), + } + } +} + +pub fn get_py_files(input: Vec) -> Vec { + let mut files_name: Vec = if !input.is_empty() { + input + .iter() + .flat_map(|elem| { + glob(elem) + .expect("Failed to read glob pattern") + .filter_map(|e| e.ok()) + }) + .collect::>() + } else { + glob("./src/*.py") + .expect("Failed to read glob pattern") + .filter_map(|e| e.ok()) + .collect::>() + }; + + files_name.sort(); + files_name.dedup(); + + let python_files = files_name + .into_iter() + .map(|elem| PythonFile::new(elem)) + .collect::>(); + + #[cfg(not(feature = "no-check"))] + { + // python_files.iter().for_each(|file| file.check()); + } + + return python_files; +} diff --git a/py-apify-macro/src/request_handler.rs b/py-apify-macro/src/request_handler.rs new file mode 100644 index 0000000..96d26d9 --- /dev/null +++ b/py-apify-macro/src/request_handler.rs @@ -0,0 +1,133 @@ +use crate::hook::HookFunctionIdent; +use crate::python_file::PythonFile; +use proc_macro2::{Ident, Literal, Span, TokenStream as TokenStream2}; +use quote::quote; + +use crate::form::FormIdent; + +pub struct RouteAttribute<'a> { + route_name: &'a str, +} + +impl<'a> From<&'a PythonFile> for RouteAttribute<'a> { + fn from(python_file: &'a PythonFile) -> RouteAttribute<'a> { + RouteAttribute { + route_name: &python_file.file_stem, + } + } +} + +impl<'a> Into for RouteAttribute<'a> { + fn into(self) -> TokenStream2 { + let route_attribute = Literal::string(&format!("/{}?", self.route_name)); + + quote! { + #[get(#route_attribute)] + } + } +} + +pub struct RequestHandlerIdent { + ident: Ident, +} + +impl From<&PythonFile> for RequestHandlerIdent { + fn from(python_file: &PythonFile) -> RequestHandlerIdent { + RequestHandlerIdent { + ident: Ident::new(&format!("route_{}", python_file.uuid), Span::call_site()), + } + } +} + +impl Into for RequestHandlerIdent { + fn into(self) -> Ident { + self.ident + } +} + +pub struct RequestHandler<'a> { + ident: RequestHandlerIdent, + route_attribute: RouteAttribute<'a>, + hook_function_ident: HookFunctionIdent, + form_ident: FormIdent, +} + +impl<'a> From<&'a PythonFile> for RequestHandler<'a> { + fn from(python_file: &'a PythonFile) -> RequestHandler<'a> { + RequestHandler { + ident: python_file.into(), + route_attribute: RouteAttribute::from(python_file), + hook_function_ident: python_file.into(), + form_ident: FormIdent::from(python_file).into(), + } + } +} + +impl<'a> Into for RequestHandler<'a> { + fn into(self) -> TokenStream2 { + let route_attribute: TokenStream2 = self.route_attribute.into(); + let route_ident: Ident = self.ident.into(); + let hook_function_ident: Ident = self.hook_function_ident.into(); + let form_ident: Ident = self.form_ident.into(); + + quote! { + #route_attribute + fn #route_ident(query: #form_ident) -> rocket::response::content::Json { + return rocket::response::content::Json(format!( + "{}", + #hook_function_ident(pyo3::Python::acquire_gil().python(), query) + )); + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::path::PathBuf; + + #[test] + fn test_route_attribute() { + let py_file = PythonFile { + file_name: "test.py".into(), + file_stem: "test".into(), + uuid: "27597466".into(), + main_func_args: vec![], + path: PathBuf::from("/test.py"), + }; + + let token_stream: TokenStream2 = RouteAttribute::from(&py_file).into(); + + let target_ts = quote! { + #[get("/test?")] + }; + + assert_eq!(token_stream.to_string(), target_ts.to_string()); + } + + #[test] + fn test_route_handler() { + let py_file = PythonFile { + file_name: "test.py".into(), + file_stem: "test".into(), + uuid: "27597466".into(), + main_func_args: vec![], + path: PathBuf::from("/test.py"), + }; + + let token_stream: TokenStream2 = RequestHandler::from(&py_file).into(); + + let target_ts = quote! { + #[get("/test?")] + fn route_27597466(query: Form_27597466) -> rocket::response::content::Json { + return rocket::response::content::Json(format!( + "{}", + hook_27597466(pyo3::Python::acquire_gil().python(), query) + )); + } + }; + + assert_eq!(token_stream.to_string(), target_ts.to_string()); + } +} diff --git a/py-apify-macro/test_py/test.py b/py-apify-macro/test_py/test.py new file mode 100644 index 0000000..e69de29 diff --git a/src/camembert-masked-lm.py b/src/camembert-masked-lm.py index 20cc3be..060fe80 100644 --- a/src/camembert-masked-lm.py +++ b/src/camembert-masked-lm.py @@ -11,9 +11,8 @@ camembert_fill_mask = pipeline("fill-mask", model="camembert/camembert-large", tokenizer="camembert/camembert-large", top_k = 10) - -def call(input): - entities = camembert_fill_mask(input) +def call(input, top_k: int = 5): + entities = camembert_fill_mask(input, top_k = top_k) # converts digit to str (for json export) converted_entities = [{k: str(v) for (k,v) in i.items()} for i in entities] diff --git a/src/main.rs b/src/main.rs index 677a4f9..935da61 100644 --- a/src/main.rs +++ b/src/main.rs @@ -10,5 +10,5 @@ use logger::setup_logger; #[launch] fn rocket() -> _ { setup_logger(); - apify! {"src/*.py"} + apify! {} }