
Commit

Make sophon case in scope and use case all use closure
meloalright committed Apr 6, 2024
1 parent 137b27b commit 6112fa6
Showing 2 changed files with 111 additions and 107 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/Sophon.yml
@@ -28,4 +28,4 @@ jobs:
run: |
./target/release/3body -V
git clone https://huggingface.co/huantian2415/vicuna-13b-chinese-4bit-ggml
./target/release/3body -c 'let 智子 = 智子工程({ "type": "llama", "path": "./vicuna-13b-chinese-4bit-ggml/Vicuna-13B-chinese.bin", "prompt": "你是三体文明的智子" }); 智子.infer(智子, "中国最佳科幻小说是哪个")'
./target/release/3body -c 'let 智子 = fn () { let instance = 智子工程({ "type": "llama", "path": "./vicuna-13b-chinese-4bit-ggml/Vicuna-13B-chinese.bin", "prompt": "你是三体文明的智子" }); return { "回答": fn (问题) { instance.infer(instance, 问题) } } }(); 智子.回答("中国最佳科幻小说是哪个")'
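
The new one-liner builds the sophon inside an immediately-invoked function and only exposes a 回答 method that closes over the constructed instance. As a rough Rust analogue of that factory-of-closures pattern (names such as Sophon, sophon_engineering, and answer are illustrative only and not part of the interpreter):

struct Sophon {
    // The 回答 ("answer") capability, closing over the constructed instance.
    answer: Box<dyn Fn(&str) -> String>,
}

fn sophon_engineering(prompt: &str) -> Sophon {
    // Stand-in for the real model construction performed by 智子工程.
    let instance = format!("instance initialised with prompt: {prompt}");
    Sophon {
        answer: Box::new(move |question| format!("{instance} answers: {question}")),
    }
}

fn main() {
    let sophon = sophon_engineering("你是三体文明的智子");
    println!("{}", (sophon.answer)("中国最佳科幻小说是哪个"));
}

Callers never touch the raw model handle; they only hold the returned object, which is what the updated workflow command mirrors in 3body syntax.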
216 changes: 110 additions & 106 deletions interpreter/src/evaluator/builtins.rs
@@ -147,110 +147,6 @@ fn three_body_deep_equal(args: Vec<Object>) -> Object {
}
}

fn three_body_sophon_infer(args: Vec<Object>) -> Object {
match &args[0] {
Object::Hash(hash) => {
let model_ptr = match hash.get(&Object::String("model".to_owned())).unwrap() {
Object::Native(native_object) => {
match **native_object {
NativeObject::LLMModel(model_ptr) => {
model_ptr.clone()
},
_ => panic!()
}
},
_ => panic!()
};
let character = hash.get(&Object::String("character".to_owned())).unwrap();
let model = unsafe { & *model_ptr };

let mut session = model.start_session(Default::default());
let message = format!("{}", &args[1]);
let prompt = &format!("
下面是描述一项任务的说明。需要适当地完成请求的响应。
### 角色设定:
{}
### 提问:
{}
### 回答:
", character, meessage);

let sp = spinoff::Spinner::new(spinoff::spinners::Arc, "".to_string(), None);

if let Err(llm::InferenceError::ContextFull) = session.feed_prompt::<Infallible>(
model,
&InferenceParameters {
..Default::default()
},
prompt,
&mut Default::default(),
|t| {
Ok(())
},
) {
println!("Prompt exceeds context window length.")
};
sp.clear();

let res = session.infer::<Infallible>(
model,
&mut thread_rng(),
&InferenceRequest {
prompt: "",
..Default::default()
},
// OutputRequest
&mut Default::default(),
|t| {
print!("{t}");
std::io::stdout().flush().unwrap();

Ok(())
},
);

match res {
Err(err) => println!("\n{err}"),
_ => ()
}
Object::Null
},
_ => panic!()
}
}



fn three_body_sophon_close(args: Vec<Object>) -> Object {
match &args[0] {
Object::Hash(hash) => {
let model_ptr = match hash.get(&Object::String("model".to_owned())).unwrap() {
Object::Native(native_object) => {
match **native_object {
NativeObject::LLMModel(model_ptr) => {
model_ptr.clone()
},
_ => panic!()
}
},
_ => panic!()
};
// let model = unsafe { & *model_ptr };
unsafe { Box::from_raw(model_ptr) };
// std::mem::drop(model);
Object::Null
},
_ => panic!()
}
}


fn three_body_sophon_engineering(args: Vec<Object>) -> Object {
match &args[0] {
Object::Hash(o) => {
@@ -314,13 +210,121 @@ fn three_body_sophon_engineering(args: Vec<Object>) -> Object {
now.elapsed().as_millis()
);



let model_ptr = &mut *model as *mut dyn Model;

let mut session_hash = HashMap::new();
session_hash.insert(Object::String("model".to_owned()), Object::Native(Box::new(NativeObject::LLMModel(model_ptr))));
session_hash.insert(Object::String("character".to_owned()), Object::String(character.to_string()));
session_hash.insert(Object::String("infer".to_owned()), Object::Builtin(2, three_body_sophon_infer));
session_hash.insert(Object::String("close".to_owned()), Object::Builtin(1, three_body_sophon_close));

{

fn three_body_sophon_infer(args: Vec<Object>) -> Object {
match &args[0] {
Object::Hash(hash) => {
let model_ptr = match hash.get(&Object::String("model".to_owned())).unwrap() {
Object::Native(native_object) => {
match **native_object {
NativeObject::LLMModel(model_ptr) => {
model_ptr.clone()
},
_ => panic!()
}
},
_ => panic!()
};
let character = hash.get(&Object::String("character".to_owned())).unwrap();
let model = unsafe { & *model_ptr };

let mut session = model.start_session(Default::default());
let message = format!("{}", &args[1]);
let prompt = &format!("
下面是描述一项任务的说明。需要适当地完成请求的响应。
### 角色设定:
{}
### 提问:
{}
### 回答:
", character, meessage);

let sp = spinoff::Spinner::new(spinoff::spinners::Arc, "".to_string(), None);

if let Err(llm::InferenceError::ContextFull) = session.feed_prompt::<Infallible>(
model,
&InferenceParameters {
..Default::default()
},
prompt,
&mut Default::default(),
|t| {
Ok(())
},
) {
println!("Prompt exceeds context window length.")
};
sp.clear();

let res = session.infer::<Infallible>(
model,
&mut thread_rng(),
&InferenceRequest {
prompt: "",
..Default::default()
},
// OutputRequest
&mut Default::default(),
|t| {
print!("{t}");
std::io::stdout().flush().unwrap();

Ok(())
},
);

match res {
Err(err) => println!("\n{err}"),
_ => ()
}
Object::Null
},
_ => panic!()
}
}



fn three_body_sophon_close(args: Vec<Object>) -> Object {
match &args[0] {
Object::Hash(hash) => {
let model_ptr = match hash.get(&Object::String("model".to_owned())).unwrap() {
Object::Native(native_object) => {
match **native_object {
NativeObject::LLMModel(model_ptr) => {
model_ptr.clone()
},
_ => panic!()
}
},
_ => panic!()
};
// let model = unsafe { & *model_ptr };
unsafe { Box::from_raw(model_ptr) };
// std::mem::drop(model);
Object::Null
},
_ => panic!()
}
}
session_hash.insert(Object::String("infer".to_owned()), Object::Builtin(2, three_body_sophon_infer));
session_hash.insert(Object::String("close".to_owned()), Object::Builtin(1, three_body_sophon_close));
}
Object::Hash(session_hash)
}
_ => Object::Null,
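
The second half of the change relies on the fact that Rust fn items declared inside an inner block are ordinary function pointers: they can be registered into the session hash from within that block and remain callable afterwards, while their names stay invisible to the rest of builtins.rs. A minimal standalone sketch of that pattern, using hypothetical names unrelated to the sophon code:

use std::collections::HashMap;

fn build_builtins() -> HashMap<String, fn(i32) -> i32> {
    let mut table: HashMap<String, fn(i32) -> i32> = HashMap::new();
    {
        // Items defined in this block are only nameable here...
        fn double(x: i32) -> i32 { x * 2 }
        fn halve(x: i32) -> i32 { x / 2 }
        // ...but the function pointers stored in the map outlive the block.
        table.insert("double".to_owned(), double);
        table.insert("halve".to_owned(), halve);
    }
    table
}

fn main() {
    let builtins = build_builtins();
    assert_eq!(builtins["double"](21), 42);
}

This mirrors how three_body_sophon_infer and three_body_sophon_close are now defined and inserted into session_hash inside the inner block of three_body_sophon_engineering instead of living at module level.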
