"""Stream a reasoning model's output while enforcing a minimum thinking length.

Whenever the model tries to close its ``<think>`` block (or stop outright)
before ``--min-thinking-tokens`` tokens have been produced, the end token is
discarded and one of the ``--replacements`` strings is injected instead, so
generation continues.
"""

import argparse
import random
import sys

from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache
import torch

parser = argparse.ArgumentParser()
parser.add_argument("question", type=str)
parser.add_argument(
    "-m", "--model-name", default="deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"
)
parser.add_argument("-d", "--device", default="auto")
parser.add_argument(
    "-r", "--replacements", nargs="+", default=["\nWait, but", "\nHmm", "\nSo"]
)
parser.add_argument("-t", "--min-thinking-tokens", type=int, default=128)
parser.add_argument("-p", "--prefill", default="")
args = parser.parse_args()

tokenizer = AutoTokenizer.from_pretrained(args.model_name)

# Tokenize every replacement once, before the (expensive) model load, so an
# unusable replacement is reported immediately.
# add_special_tokens=False matters: the three-way unpack of
# encode("<think></think>") below shows that a default encode() prepends a
# special token (presumably BOS), which must not be injected mid-sequence.
replacement_choices = []
for _text in args.replacements:
    _ids = tokenizer.encode(_text, add_special_tokens=False)
    if not _ids:
        # An empty replacement would produce an empty input_ids tensor.
        parser.error(f"replacement {_text!r} encodes to zero tokens")
    replacement_choices.append((_text, _ids))

model = AutoModelForCausalLM.from_pretrained(
    args.model_name, torch_dtype=torch.bfloat16, device_map=args.device
)

# encode() is expected to return exactly three ids here: a leading special
# token added by the tokenizer, then <think> and </think>. A tokenizer that
# behaves differently makes this unpack raise ValueError.
_, _start_think_token, end_think_token = tokenizer.encode("<think></think>")


@torch.inference_mode
def reasoning_effort(question: str, min_thinking_tokens: int):
    """Generate an answer to *question*, yielding decoded text chunks.

    The first chunk is the decoded prompt (chat template plus the ``<think>``
    prefill); every later chunk is either one sampled token or an injected
    replacement string.

    Args:
        question: The user message placed in the chat template.
        min_thinking_tokens: Number of tokens that must be generated before
            an end-of-thinking or end-of-sequence token is accepted.

    Yields:
        str: Text to append to the output stream.
    """
    tokens = tokenizer.apply_chat_template(
        [
            {"role": "user", "content": question},
            {"role": "assistant", "content": "<think>\n" + args.prefill},
        ],
        continue_final_message=True,
        return_tensors="pt",
    )
    tokens = tokens.to(model.device)
    device = tokens.device
    eos_token_id = model.config.eos_token_id
    kv = DynamicCache()
    n_thinking_tokens = 0

    yield tokenizer.decode(list(tokens[0]))
    while True:
        # Only the not-yet-cached tokens are fed; the KV cache holds the rest.
        out = model(input_ids=tokens, past_key_values=kv, use_cache=True)
        # Sample from the full next-token distribution (temperature 1).
        next_token = torch.multinomial(
            torch.softmax(out.logits[0, -1, :], dim=-1), 1
        ).item()
        kv = out.past_key_values

        if (
            next_token in (end_think_token, eos_token_id)
            and n_thinking_tokens < min_thinking_tokens
        ):
            # Too early to stop: drop the sampled end token and feed a
            # replacement phrase so the model keeps reasoning.
            replacement, replacement_tokens = random.choice(replacement_choices)
            yield replacement
            n_thinking_tokens += len(replacement_tokens)
            tokens = torch.tensor([replacement_tokens], device=device)
        elif next_token == eos_token_id:
            break
        else:
            # NOTE(review): decoding token-by-token may split characters that
            # span several tokens -- confirm this is acceptable for streaming.
            yield tokenizer.decode([next_token])
            n_thinking_tokens += 1
            tokens = torch.tensor([[next_token]], device=device)


for chunk in reasoning_effort(args.question, args.min_thinking_tokens):
    # Flush per chunk so the text streams to the terminal as it is generated.
    print(chunk, end="", flush=True)