class ToolingGemma:
    """Stateful multi-turn wrapper around a tool-calling Gemma checkpoint.

    Each call to :meth:`generate` appends a ``USER:`` turn to the running
    transcript and asks the model to continue after ``ASSISTANT:``. The
    decoded transcript is kept as history for the next turn.
    """

    def __init__(
        self,
        system_instructions,
        model_name='SauravP97/tooling-gemma-270M-inst',
        tokenizer_name='google/gemma-3-270m',
        device_map="cpu",
    ):
        """Load the model and tokenizer.

        Args:
            system_instructions: Text placed at the top of the first prompt
                (e.g. a description of the functions the model may call).
            model_name: Hugging Face id or local path of the causal-LM checkpoint.
            tokenizer_name: Hugging Face id or local path of the tokenizer.
            device_map: Forwarded to ``from_pretrained``; defaults to ``"cpu"``.
        """
        # Transcript (prompt + model output) from the previous turn; an empty
        # string means no turn has happened yet.
        self.chat_history = ''
        self.model = AutoModelForCausalLM.from_pretrained(model_name, device_map=device_map)
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self.system_instructions = system_instructions
        # Textual end-of-turn marker; generate() stops as soon as it appears.
        self.stop_word = "<|endoftext|>"

    def generate(self, user_query, max_new_tokens=1000):
        """Run one chat turn and return the decoded transcript.

        Args:
            user_query: The user's message for this turn.
            max_new_tokens: Upper bound on the number of tokens generated.

        Returns:
            The full decoded output sequence, with special tokens left in. For a
            decoder-only model that is the prompt (system instructions or prior
            history, plus this turn) followed by the new assistant reply.
        """
        turn = 'USER: ' + user_query + '\n' + 'ASSISTANT:'
        if self.chat_history:
            prompt = self.chat_history + '\n' + turn
        else:
            # First turn: system instructions, a blank line, then the user turn.
            prompt = self.system_instructions + '\n\n' + turn
        inputs = self.tokenizer(prompt, return_tensors="pt")
        output_ids = self.model.generate(
            **inputs,
            generation_config=GenerationConfig.from_dict({"max_new_tokens": max_new_tokens}),
            stop_strings=[self.stop_word],
            # Required by transformers whenever stop_strings is used.
            tokenizer=self.tokenizer,
        )
        # Batch size is 1, so row 0 is the only sequence.
        decoded_agent_response = self.tokenizer.decode(output_ids[0])
        # The history is re-tokenized on the next turn, and the tokenizer adds
        # its own special tokens (e.g. <bos>) at that point. Storing the raw
        # decode would feed a literal <bos> (and any <eos>) back in, giving a
        # doubled BOS and an EOS mid-prompt, so store a copy with special
        # tokens stripped.
        # NOTE(review): if stop_word is ever registered as a special token, it
        # is stripped from the history as well.
        self.chat_history = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
        return decoded_agent_response
# System prompt used as the start of the first turn (ToolingGemma.generate puts
# it before the first "USER:" line). It declares the single function the model
# may call, calculate_discount, with its parameter schema written as JSON.
# The text is sent to the model verbatim, including whitespace and the leading
# newline, so edit with care.
# NOTE(review): presumably this layout mirrors the model's fine-tuning data -- confirm.
system_instructions = '''
SYSTEM: You are a helpful assistant with access to the following functions. Use them if required -
{
"name": "calculate_discount",
"description": "Calculate the discounted price of a product",
"parameters": {
"type": "object",
"properties": {
"original_price": {
"type": "number",
"description": "The original price of the product"
},
"discount_percentage": {
"type": "number",
"description": "The discount percentage"
}
},
"required": [
"original_price",
"discount_percentage"
]
}
}
'''
# Demo turn: build the agent from the prompt above, then send it a request
# (booking a flight) that the only declared function, calculate_discount, does
# not cover -- presumably to show how the model handles an unsupported request.
tooling_gemma_model = ToolingGemma(system_instructions)
agent_response = tooling_gemma_model.generate(
    user_query='Can you please book a flight for me from New York to London?',
)