rinna株式会社さんが、りんなの3.6bを公開してくれました
https://huggingface.co/rinna/japanese-gpt-neox-3.6b
サンプルソースのcudaの部分をコメントアウトして
model = model.to("mps")
pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu
2023/5/24現在では
(venv) % pip list | grep torch
torch 2.1.0.dev20230523
torchvision 0.16.0.dev20230523
1回目、うーん、お馬鹿さん
python3 r.py
ユーザー: 日本のおすすめの観光地を教えてください。<NL>
システム: どの地域の観光地が知りたいですか?<NL>
ユーザー: 渋谷の観光地を教えてください。<NL>
システム: そうですね、私のお気に入りの観光地は東京です。</s>
何回か動作させると
システム:
それは広大で、多くの異なる観光スポットがあります。以下は、いくつかのおすすめの観光地です:
- ハチ公
- スカイツリー
- スカイツリー
- 浅草寺
- 富士山
- 皇居
- 日光東照宮
- 日光東照宮</s>
ハチ公しかあってない気もするけど、3.6bの繋がりじゃこんなもんですかね
動作させたソースコードは以下の通り
# 入力プロンプトは、「ユーザー」と「システム」間の会話書式で記述します。 #各発話は、以下で構成されます。 # #(1) 話者 ("ユーザー" or "システム") #(2) コロン (:) #(3) スペース #(4) 発話テキスト import torch from transformers import AutoTokenizer, AutoModelForCausalLM prompt = [ { "speaker": "ユーザー", "text": "日本のおすすめの観光地を教えてください。" }, { "speaker": "システム", "text": "どの地域の観光地が知りたいですか?" }, { "speaker": "ユーザー", "text": "渋谷の観光地を教えてください。" } ] prompt = [ f"{uttr['speaker']}: {uttr['text']}" for uttr in prompt ] prompt = "".join(prompt) prompt = ( prompt + " " + "システム: " ) print(prompt) # tokenizer = AutoTokenizer.from_pretrained("rinna/japanese-gpt-neox-3.6b-instruction-sft", use_fast=False) model = AutoModelForCausalLM.from_pretrained("rinna/japanese-gpt-neox-3.6b-instruction-sft") #if torch.cuda.is_available(): # model = model.to("cuda") model = model.to("mps") # arm macなら mps に変換すると速い # token_ids = tokenizer.encode("ユーザー: 日本で一番高い山は システム:", add_special_tokens=False, return_tensors="pt") token_ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt") with torch.no_grad(): output_ids = model.generate( token_ids.to(model.device), do_sample=True, max_new_tokens=128, temperature=0.7, pad_token_id=tokenizer.pad_token_id, bos_token_id=tokenizer.bos_token_id, eos_token_id=tokenizer.eos_token_id ) output = tokenizer.decode(output_ids.tolist()[0][token_ids.size(1):]) output = output.replace(" ", "\n") print(output)
エラーは検索用に残しておく
Apple M1 MPS touch error UserWarning: MPS: no support for int64 repeats mask, casting it to int32
end