rinna株式会社さんが、りんなの3.6bを公開してくれました
https://huggingface.co/rinna/japanese-gpt-neox-3.6b
サンプルソースのcudaの部分をコメントアウトして
model = model.to("mps")
pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu2023/5/24現在では
(venv) % pip list | grep torch
torch 2.1.0.dev20230523
torchvision 0.16.0.dev20230523
1回目、うーん、お馬鹿さん
python3 r.py
ユーザー: 日本のおすすめの観光地を教えてください。<NL>
システム: どの地域の観光地が知りたいですか?<NL>
ユーザー: 渋谷の観光地を教えてください。<NL>
システム: そうですね、私のお気に入りの観光地は東京です。</s>
何回か動作させると
システム:
それは広大で、多くの異なる観光スポットがあります。以下は、いくつかのおすすめの観光地です:
- ハチ公
- スカイツリー
- スカイツリー
- 浅草寺
- 富士山
- 皇居
- 日光東照宮
- 日光東照宮</s>
ハチ公しかあってない気もするけど、3.6bの繋がりじゃこんなもんですかね
動作させたソースコードは以下の通り
# 入力プロンプトは、「ユーザー」と「システム」間の会話書式で記述します。
#各発話は、以下で構成されます。
#
#(1) 話者 ("ユーザー" or "システム")
#(2) コロン (:)
#(3) スペース
#(4) 発話テキスト
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
prompt = [
{
"speaker": "ユーザー",
"text": "日本のおすすめの観光地を教えてください。"
},
{
"speaker": "システム",
"text": "どの地域の観光地が知りたいですか?"
},
{
"speaker": "ユーザー",
"text": "渋谷の観光地を教えてください。"
}
]
prompt = [
f"{uttr['speaker']}: {uttr['text']}"
for uttr in prompt
]
prompt = "".join(prompt)
prompt = (
prompt
+ ""
+ "システム: "
)
print(prompt)
#
tokenizer = AutoTokenizer.from_pretrained("rinna/japanese-gpt-neox-3.6b-instruction-sft", use_fast=False)
model = AutoModelForCausalLM.from_pretrained("rinna/japanese-gpt-neox-3.6b-instruction-sft")
#if torch.cuda.is_available():
# model = model.to("cuda")
model = model.to("mps") # arm macなら mps に変換すると速い
# token_ids = tokenizer.encode("ユーザー: 日本で一番高い山はシステム:", add_special_tokens=False, return_tensors="pt")
token_ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")
with torch.no_grad():
output_ids = model.generate(
token_ids.to(model.device),
do_sample=True,
max_new_tokens=128,
temperature=0.7,
pad_token_id=tokenizer.pad_token_id,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id
)
output = tokenizer.decode(output_ids.tolist()[0][token_ids.size(1):])
output = output.replace("", "\n")
print(output)
エラーは検索用に残しておく
Apple M1 MPS touch error UserWarning: MPS: no support for int64 repeats mask, casting it to int32
end