---
library_name: transformers
tags: []
---

## Inference

```python
# Load the model and tokenizer directly
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("suriya7/conversational-gpt-1")
model = AutoModelForCausalLM.from_pretrained("suriya7/conversational-gpt-1")
```
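
For a quick single-turn check before setting up the chat loop, the loaded model can be queried once. This is a minimal sketch: the sample question is made up, and it reuses the ChatML-style tags and the pad/eos token id `50259` from the chat loop below.

```python
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

# One-shot prompt in the same ChatML-style format used by the chat loop
single_prompt = (
    "<|im_start|>system\nYou are a helpful AI assistant named Securitron, trained by Aquilax.<|im_end|>\n"
    "<|im_start|>user\nWhat can you help me with?<|im_end|>"
)

input_ids = tokenizer(single_prompt, return_tensors="pt", truncation=True).input_ids.to(device)
output_ids = model.generate(
    input_ids,
    max_new_tokens=256,
    pad_token_id=50259,
    eos_token_id=50259,
)

# Everything after the assistant tag is the reply
reply = tokenizer.decode(output_ids[0]).split("<|im_start|>assistant")[-1].replace("<|im_end|>", "").strip()
print(reply)
```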

## Chatting

```python
prompt = """
<|im_start|>system\nYou are a helpful AI assistant named Securitron, trained by Aquilax.<|im_end|>
"""

# Keep a rolling window of the most recent conversation turns
conversation_history = []

while True:
    user_prompt = input("User Question: ")
    if user_prompt.lower() == 'break':
        break

    # Format the user's input
    user = f"""<|im_start|>user
{user_prompt}<|im_end|>"""

    # Add the user's question to the conversation history
    conversation_history.append(user)

    # Keep only the current question plus the last two exchanges (5 messages),
    # so the window always starts with a user turn
    conversation_history = conversation_history[-5:]

    # Build the full prompt: system prompt followed by the retained history
    full_prompt = prompt + "\n".join(conversation_history)

    # Tokenize the prompt
    encodeds = tokenizer(full_prompt, return_tensors="pt", truncation=True).input_ids

    # Move model and inputs to the appropriate device
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)
    inputs = encodeds.to(device)

    # Generate the model's response
    generated_ids = model.generate(
        inputs,
        max_new_tokens=512,
        pad_token_id=50259,
        eos_token_id=50259,
        num_return_sequences=1,
    )

    # Decode and process the model's response
    ans = tokenizer.decode(generated_ids[0])
    assistant_response = ans.split("<|im_start|>assistant")[-1].replace("<|im_end|>", "").strip()
    print(f"Assistant: {assistant_response}")

    # Add the assistant's response to the conversation history
    conversation_history.append(f"<|im_start|>assistant\n{assistant_response}<|im_end|>")
```
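
For reference, after one exchange the string handed to the tokenizer looks roughly like this (the user questions and the assistant line are placeholders, not real model output):

```
<|im_start|>system
You are a helpful AI assistant named Securitron, trained by Aquilax.<|im_end|>
<|im_start|>user
Hello!<|im_end|>
<|im_start|>assistant
...previous model reply...<|im_end|>
<|im_start|>user
What can you do?<|im_end|>
```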