Looking Inside Transformer LLMs

#deepLearning/llm/3

Transformers llm 不同输入和输出的区别?

prompt = "《静夜思》的作者是谁?" # chatML 格式 
output = generator(prompt)
print(output[0]["generated_text"])


陆游

“我自横刀向天笑,去留肝胆两昆仑”出自哪位诗人的作品? 祁发

中国共产党第一次全国代表大会于1921723日在上海法

没有设定格式的情况下回答是不准确的

# chatML 格式
messages = [
{"role": "system", "content": "你是一个很有用的助手"},
{"role": "user", "content": prompt},
]
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(text)

<|im_start|>system
你是一个很有用的助手<|im_end|>
<|im_start|>user
《静夜思》的作者是谁?<|im_end|>
<|im_start|>assistant
generator(text)[0]["generated_text"]

'李白是这首诗的作者。'

输入格式要和训练格式相同,构造成 chatML 格式之后,回答变准确。 大部分情况下都要构造一个对话格式。

text-generation 仅仅是预测下一个token,所以相对于 如果没有构造成 [chat] 类的格式,效果会更差一些

不同的 output 可以有什么用?

# model output 可以用法 

prompt = "The capital of France is"

# model api 1. generate 2. model.model 3. model.lm_head

# 2. model.model 返回什么? 获取 Hidden states
model_output = model.model(input_ids)

model_output[0].shape

# outputs:
torch.Size([1, 6, 3072])

# 3. model.lm_head 返回什么? 将Hidden states 映射到词汇表的概率分布(logits),用于预测下一个 token。得到 vocab_index

lm_head_output = model.lm_head(model_output[0])

# vocab_index

lm_head_output.shape

# outputs:
torch.Size([1, 6, 32064])

可以使用 lm_head_output[0,-1] 访问最后生成的 token 的概率分数,它使用批量维度的索引 0 ;索引 -1 将得到序列中的最后一个单词 is

token_id = lm_head_output[0,-1].argmax(-1)
tokenizer.decode(token_id)

# outputs:
Paris

RMSNorm 和 Layernorm 的区别

了解 KV cache 的原理,以及在推理的时候怎么使用?

prompt = "Write a very long email apologizing to Sarah for the tragic gardening mishap. Explain how it happened."

# Tokenize the input prompt
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
input_ids = input_ids.to("cuda")
# use kvcache
%%timeit -n 1
# Generate the text
generation_output = model.generate(
input_ids=input_ids,
max_new_tokens=100,
use_cache=True
)

# outputs:
6.66 s ± 2.22 s per loop (mean ± std. dev. of 7 runs, 1 loop each)
# not use kvcache
%%timeit -n 1
# Generate the text
generation_output = model.generate(
input_ids=input_ids,
max_new_tokens=100,
use_cache=False
)

# outputs:
21.9 s ± 94.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)